mirror of git://git.yoctoproject.org/poky.git

Adds support for an upstream server to be specified. The upstream server
will be queried for equivalent hashes whenever a miss is found in the
local server. If the server returns a match, it is merged into the local
database. In order to keep the get-stream queries as fast as possible,
since they are the critical path when bitbake is preparing the run queue,
missing tasks provided by the upstream server are not copied into the
local database immediately, but instead are put into a queue to be
backfilled by a worker task later.

(Bitbake rev: e6d6c0b39393e9bdf378c1eba141f815e26b724b)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
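For reference, a minimal sketch of wiring a local server to an upstream
using the classes in this file (the addresses and database path are
hypothetical, and the tasks_v2 schema is assumed to already exist; the
hashserv module's own setup code normally handles both):

    import sqlite3

    db = sqlite3.connect('hashserv.db')
    db.row_factory = sqlite3.Row

    server = Server(db, upstream='upstream.example.com:8687')
    server.start_tcp_server('localhost', 8686)
    server.serve_forever()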
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

from contextlib import closing, contextmanager
from datetime import datetime
import asyncio
import json
import logging
import math
import os
import signal
import socket
import sys
import time
from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS

logger = logging.getLogger('hashserv.server')


class Measurement(object):
    def __init__(self, sample):
        self.sample = sample

    def start(self):
        self.start_time = time.perf_counter()

    def end(self):
        self.sample.add(time.perf_counter() - self.start_time)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        self.end()


class Sample(object):
    def __init__(self, stats):
        self.stats = stats
        self.num_samples = 0
        self.elapsed = 0

    def measure(self):
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        self.num_samples += 1
        self.elapsed += elapsed

    def end(self):
        if self.num_samples:
            self.stats.add(self.elapsed)
            self.num_samples = 0
            self.elapsed = 0


class Stats(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        self.m = 0
        self.s = 0
        self.current_elapsed = None

    def add(self, elapsed):
        self.num += 1
        if self.num == 1:
            self.m = elapsed
            self.s = 0
        else:
            last_m = self.m
            self.m = last_m + (elapsed - last_m) / self.num
            self.s = self.s + (elapsed - last_m) * (elapsed - self.m)

        self.total_time += elapsed

        if self.max_time < elapsed:
            self.max_time = elapsed

    def start_sample(self):
        return Sample(self)

    @property
    def average(self):
        if self.num == 0:
            return 0
        return self.total_time / self.num

    @property
    def stdev(self):
        if self.num <= 1:
            return 0
        return math.sqrt(self.s / (self.num - 1))

    def todict(self):
        return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}


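# A note on the bookkeeping above: Stats.add() maintains the running mean
# and variance with Welford's online algorithm,
#
#   m_k = m_{k-1} + (x_k - m_{k-1}) / k
#   s_k = s_{k-1} + (x_k - m_{k-1}) * (x_k - m_k)
#
# so stdev can report sqrt(s_n / (n - 1)) without storing the samples.
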
class ClientError(Exception):
    pass

def insert_task(cursor, data, ignore=False):
    keys = sorted(data.keys())
    query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
        " OR IGNORE" if ignore else "",
        ', '.join(keys),
        ', '.join(':' + k for k in keys))
    cursor.execute(query, data)

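# For illustration, with keys ('method', 'taskhash') the generated query is:
#
#   INSERT INTO tasks_v2 (method, taskhash) VALUES (:method, :taskhash)
#
# with " OR IGNORE" spliced in after INSERT when ignore=True.
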
async def copy_from_upstream(client, db, method, taskhash):
    d = await client.get_taskhash(method, taskhash, True)
    if d is not None:
        # Filter out unknown columns
        d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}

        with closing(db.cursor()) as cursor:
            insert_task(cursor, d)
            db.commit()

    return d

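# Wire protocol sketch (the canonical client lives alongside this module):
# a client opens the connection and sends a handshake line such as
# "OEHASHEQUIV 1.1", followed by a blank line terminating the (currently
# empty) headers. Each subsequent request is a single line of JSON whose
# one key selects a handler from ServerClient.handlers below, e.g.
#
#   {"get": {"method": "...", "taskhash": "..."}}
#
# and each response is a line of JSON written back by write_message().
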
class ServerClient(object):
    FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
    ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'

    def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream):
        self.reader = reader
        self.writer = writer
        self.db = db
        self.request_stats = request_stats
        self.max_chunk = DEFAULT_MAX_CHUNK
        self.backfill_queue = backfill_queue
        self.upstream = upstream

        self.handlers = {
            'get': self.handle_get,
            'report': self.handle_report,
            'report-equiv': self.handle_equivreport,
            'get-stream': self.handle_get_stream,
            'get-stats': self.handle_get_stats,
            'reset-stats': self.handle_reset_stats,
            'chunk-stream': self.handle_chunk,
            'backfill-wait': self.handle_backfill_wait,
        }

    async def process_requests(self):
        if self.upstream is not None:
            self.upstream_client = await create_async_client(self.upstream)
        else:
            self.upstream_client = None

        try:
            self.addr = self.writer.get_extra_info('peername')
            logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version. readline() returns an empty bytes
            # object at EOF, so test for truthiness rather than None
            protocol = await self.reader.readline()
            if not protocol:
                return

            (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
            if proto_name != 'OEHASHEQUIV':
                return

            proto_version = tuple(int(v) for v in proto_version.split('.'))
            if proto_version < (1, 0) or proto_version > (1, 1):
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if not line:
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            logger.error(str(e))
        finally:
            if self.upstream_client is not None:
                await self.upstream_client.close()

            self.writer.close()

    async def dispatch_message(self, msg):
        for k in self.handlers.keys():
            if k in msg:
                logger.debug('Handling %s' % k)
                if 'stream' in k:
                    await self.handlers[k](msg[k])
                else:
                    with self.request_stats.start_sample() as self.request_sample, \
                            self.request_sample.measure():
                        await self.handlers[k](msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Log the raw bytes; 'message' may be unbound if decoding failed
            logger.error('Bad message from client: %r' % l)
            raise e

    async def handle_chunk(self, request):
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break
                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            logger.error('Bad message from client: %r' % ''.join(lines))
            raise e

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

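    # chunk-stream framing sketch (see chunkify() in this package's
    # __init__.py): a message too large for a single line is announced
    # with a {"chunk-stream": null} line, split across subsequent lines,
    # and terminated by an empty line; the loop above reassembles the
    # pieces and re-dispatches the decoded message.
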
    async def handle_get(self, request):
        method = request['method']
        taskhash = request['taskhash']

        if request.get('all', False):
            row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
        else:
            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)

        if row is not None:
            logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
            d = {k: row[k] for k in row.keys()}
        elif self.upstream_client is not None:
            d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash)
        else:
            d = None

        self.write_message(d)

    async def handle_get_stream(self, request):
        self.write_message('ok')

        while True:
            upstream = None

            l = await self.reader.readline()
            if not l:
                return

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled manually
                # instead of using 'with', and also why logging statements are
                # commented out).
                self.request_sample = self.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                l = l.decode('utf-8').rstrip()
                if l == 'END':
                    self.writer.write('ok\n'.encode('utf-8'))
                    return

                (method, taskhash) = l.split()
                #logger.debug('Looking up %s %s' % (method, taskhash))
                row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
                if row is not None:
                    msg = ('%s\n' % row['unihash']).encode('utf-8')
                    #logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
                elif self.upstream_client is not None:
                    upstream = await self.upstream_client.get_unihash(method, taskhash)
                    if upstream:
                        msg = ("%s\n" % upstream).encode("utf-8")
                    else:
                        msg = "\n".encode("utf-8")
                else:
                    msg = '\n'.encode('utf-8')

                self.writer.write(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            await self.writer.drain()

            # Post to the backfill queue after writing the result to minimize
            # the turn around time on a request
            if upstream is not None:
                await self.backfill_queue.put((method, taskhash))

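    # get-stream transcript sketch: after a {"get-stream": ...} request the
    # server answers "ok", then each "<method> <taskhash>" line from the
    # client is answered with the unihash (or an empty line on a miss) until
    # the client sends "END", which is acknowledged with "ok". Misses served
    # from the upstream are queued for backfill rather than copied inline.
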
    async def handle_report(self, data):
        with closing(self.db.cursor()) as cursor:
            cursor.execute('''
                -- Find tasks with a matching outhash (that is, tasks that
                -- are equivalent)
                SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                -- If there is an exact match on the taskhash, return it.
                -- Otherwise return the oldest matching outhash of any
                -- taskhash
                ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                    created ASC

                -- Only return one row
                LIMIT 1
                ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

            row = cursor.fetchone()

            # If no matching outhash was found, or one *was* found but it
            # wasn't an exact match on the taskhash, a new entry for this
            # taskhash should be added
            if row is None or row['taskhash'] != data['taskhash']:
                # If a row matching the outhash was found, the unihash for
                # the new taskhash should be the same as that one.
                # Otherwise the caller provided unihash is used.
                unihash = data['unihash']
                if row is not None:
                    unihash = row['unihash']

                insert_data = {
                    'method': data['method'],
                    'outhash': data['outhash'],
                    'taskhash': data['taskhash'],
                    'unihash': unihash,
                    'created': datetime.now()
                }

                for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                    if k in data:
                        insert_data[k] = data[k]

                insert_task(cursor, insert_data)
                self.db.commit()

                logger.info('Adding taskhash %s with unihash %s',
                            data['taskhash'], unihash)

                d = {
                    'taskhash': data['taskhash'],
                    'method': data['method'],
                    'unihash': unihash
                }
            else:
                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_equivreport(self, data):
        with closing(self.db.cursor()) as cursor:
            insert_data = {
                'method': data['method'],
                'outhash': "",
                'taskhash': data['taskhash'],
                'unihash': data['unihash'],
                'created': datetime.now()
            }

            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                if k in data:
                    insert_data[k] = data[k]

            insert_task(cursor, insert_data, ignore=True)
            self.db.commit()

            # Fetch the unihash that will be reported for the taskhash. If the
            # unihash matches, it means this row was inserted (or the mapping
            # was already valid)
            row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)

            if row['unihash'] == data['unihash']:
                logger.info('Adding taskhash equivalence for %s with unihash %s',
                            data['taskhash'], row['unihash'])

            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_get_stats(self, request):
        d = {
            'requests': self.request_stats.todict(),
        }

        self.write_message(d)

    async def handle_reset_stats(self, request):
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        self.write_message(d)

    async def handle_backfill_wait(self, request):
        d = {
            'tasks': self.backfill_queue.qsize(),
        }
        await self.backfill_queue.join()
        self.write_message(d)

    def query_equivalent(self, method, taskhash, query):
        # This is part of the inner loop and must be as fast as possible
        try:
            cursor = self.db.cursor()
            cursor.execute(query, {'method': method, 'taskhash': taskhash})
            return cursor.fetchone()
        except:
            # Don't swallow database errors; close the cursor and let the
            # caller see the exception
            cursor.close()
            raise


class Server(object):
    def __init__(self, db, loop=None, upstream=None):
        self.request_stats = Stats()
        self.db = db

        if loop is None:
            self.loop = asyncio.new_event_loop()
            self.close_loop = True
        else:
            self.loop = loop
            self.close_loop = False

        self.upstream = upstream

        self._cleanup_socket = None

    def start_tcp_server(self, host, port):
        self.server = self.loop.run_until_complete(
            asyncio.start_server(self.handle_client, host, port, loop=self.loop)
        )

        for s in self.server.sockets:
            logger.info('Listening on %r' % (s.getsockname(),))
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

    def start_unix_server(self, path):
        def cleanup():
            os.unlink(path)

        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX
            os.chdir(os.path.dirname(path))
            self.server = self.loop.run_until_complete(
                asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
            )
        finally:
            os.chdir(cwd)

        logger.info('Listening on %r' % path)

        self._cleanup_socket = cleanup
        self.address = "unix://%s" % os.path.abspath(path)

    async def handle_client(self, reader, writer):
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream)
            await client.process_requests()
        except Exception as e:
            import traceback
            logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        logger.info('Client disconnected')

    @contextmanager
    def _backfill_worker(self):
        async def backfill_worker_task():
            client = await create_async_client(self.upstream)
            try:
                while True:
                    item = await self.backfill_queue.get()
                    if item is None:
                        # A None item is the sentinel posted by join_worker()
                        # to tell the worker to exit
                        self.backfill_queue.task_done()
                        break
                    method, taskhash = item
                    await copy_from_upstream(client, self.db, method, taskhash)
                    self.backfill_queue.task_done()
            finally:
                await client.close()

        async def join_worker(worker):
            await self.backfill_queue.put(None)
            await worker

        if self.upstream is not None:
            worker = asyncio.ensure_future(backfill_worker_task())
            try:
                yield
            finally:
                self.loop.run_until_complete(join_worker(worker))
        else:
            yield

    def serve_forever(self):
        def signal_handler():
            self.loop.stop()

        asyncio.set_event_loop(self.loop)
        try:
            self.backfill_queue = asyncio.Queue()

            self.loop.add_signal_handler(signal.SIGTERM, signal_handler)

            with self._backfill_worker():
                try:
                    self.loop.run_forever()
                except KeyboardInterrupt:
                    pass

                self.server.close()

                self.loop.run_until_complete(self.server.wait_closed())
                logger.info('Server shutting down')
        finally:
            if self.close_loop:
                if sys.version_info >= (3, 6):
                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
                self.loop.close()

            if self._cleanup_socket is not None:
                self._cleanup_socket()