poky/bitbake/lib/hashserv/server.py
Joshua Watt 07a02b31fd bitbake: hashserv: Chunkify large messages
The hash equivalence client and server can occasionally send messages
that are too large for the server to fit in the receive buffer (64 KB).
To prevent this, support is added to the protocol to "chunkify" the
stream and break it up into manageable pieces that each side can
reassemble.

Ideally, this would be negotiated by the client and server, but it's
currently hard coded to 32 KB to prevent the round-trip delay.

(Bitbake rev: e27a28c1e40e886ee68ba4b99b537ffc9c3577d4)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
2020-06-28 08:36:56 +01:00

490 lines
15 KiB
Python

# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
from contextlib import closing
from datetime import datetime
import asyncio
import json
import logging
import math
import os
import signal
import socket
import time
from . import chunkify, DEFAULT_MAX_CHUNK
logger = logging.getLogger('hashserv.server')
class Measurement(object):
    """Context manager that times a single operation and reports the
    elapsed wall-clock time to its parent Sample."""

    def __init__(self, sample):
        # Parent Sample that accumulates the measured durations
        self.sample = sample

    def start(self):
        """Record the starting timestamp."""
        self.start_time = time.perf_counter()

    def end(self):
        """Feed the time elapsed since start() into the parent sample."""
        self.sample.add(time.perf_counter() - self.start_time)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        self.end()
class Sample(object):
    """Accumulates elapsed time over one or more Measurements and, on
    end(), commits the total to the parent Stats as a single event."""

    def __init__(self, stats):
        # Parent Stats collector
        self.stats = stats
        # Measurements accumulated since the last end()
        self.num_samples = 0
        # Total elapsed time accumulated since the last end()
        self.elapsed = 0

    def measure(self):
        """Return a Measurement that reports its timing into this sample."""
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        """Accumulate one measured duration."""
        self.num_samples += 1
        self.elapsed += elapsed

    def end(self):
        """Flush the accumulated time to the parent Stats (only if any
        measurements were taken) and reset for reuse."""
        if self.num_samples:
            self.stats.add(self.elapsed)
            self.num_samples = 0
            self.elapsed = 0
class Stats(object):
    """Streaming request statistics (count, total, max, mean, stdev)
    using Welford's online algorithm, so no per-event history is kept."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all counters."""
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        self.m = 0              # running mean (Welford)
        self.s = 0              # running sum of squared deviations
        self.current_elapsed = None

    def add(self, elapsed):
        """Fold one elapsed-time event into the running statistics."""
        self.num += 1
        if self.num == 1:
            self.m, self.s = elapsed, 0
        else:
            prev_mean = self.m
            self.m = prev_mean + (elapsed - prev_mean) / self.num
            self.s += (elapsed - prev_mean) * (elapsed - self.m)
        self.total_time += elapsed
        self.max_time = max(self.max_time, elapsed)

    def start_sample(self):
        """Return a fresh Sample bound to this collector."""
        return Sample(self)

    @property
    def average(self):
        """Mean elapsed time, or 0 when no events were recorded."""
        return self.total_time / self.num if self.num else 0

    @property
    def stdev(self):
        """Sample standard deviation, or 0 with fewer than two events."""
        return math.sqrt(self.s / (self.num - 1)) if self.num > 1 else 0

    def todict(self):
        """Export the summary statistics as a plain dict."""
        return {k: getattr(self, k)
                for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
class ClientError(Exception):
    """Raised when a client sends a malformed or unrecognized request;
    terminates that client's request loop."""
class ServerClient(object):
    """Services a single hash equivalence client connection.

    Requests are newline-delimited JSON objects keyed by command name;
    messages larger than max_chunk are split into 'chunk-stream' pieces
    so each side can reassemble them.
    """

    # Fast path: only the columns needed to answer a lookup
    FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
    # Full row of the oldest matching entry
    ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'

    def __init__(self, reader, writer, db, request_stats):
        # asyncio stream pair for this connection
        self.reader = reader
        self.writer = writer
        # sqlite connection shared with the Server
        self.db = db
        # Stats collector shared across all connections
        self.request_stats = request_stats
        self.max_chunk = DEFAULT_MAX_CHUNK

        # Dispatch table: message key -> handler coroutine
        self.handlers = {
            'get': self.handle_get,
            'report': self.handle_report,
            'report-equiv': self.handle_equivreport,
            'get-stream': self.handle_get_stream,
            'get-stats': self.handle_get_stats,
            'reset-stats': self.handle_reset_stats,
            'chunk-stream': self.handle_chunk,
        }

    async def process_requests(self):
        """Validate the protocol banner, skip headers, then read and
        dispatch messages until the client disconnects."""
        try:
            self.addr = self.writer.get_extra_info('peername')
            logger.debug('Client %r connected' % (self.addr,))

            # Read protocol and version
            protocol = await self.reader.readline()
            if protocol is None:
                return

            (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
            if proto_name != 'OEHASHEQUIV':
                return

            # Only protocol versions 1.0 through 1.1 are understood
            proto_version = tuple(int(v) for v in proto_version.split('.'))
            if proto_version < (1, 0) or proto_version > (1, 1):
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers
            while True:
                line = await self.reader.readline()
                if line is None:
                    return

                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            while True:
                d = await self.read_message()
                if d is None:
                    break
                await self.dispatch_message(d)
                await self.writer.drain()
        except ClientError as e:
            logger.error(str(e))
        finally:
            self.writer.close()

    async def dispatch_message(self, msg):
        """Route a decoded message to its handler; non-streaming
        requests are wrapped in a timing sample.

        Raises ClientError for unrecognized commands.
        """
        for k in self.handlers.keys():
            if k in msg:
                logger.debug('Handling %s' % k)
                if 'stream' in k:
                    # Streaming handlers manage their own per-line timing
                    await self.handlers[k](msg[k])
                else:
                    with self.request_stats.start_sample() as self.request_sample, \
                            self.request_sample.measure():
                        await self.handlers[k](msg[k])
                return

        raise ClientError("Unrecognized command %r" % msg)

    def write_message(self, msg):
        """JSON-encode msg and write it out, chunked to max_chunk."""
        for c in chunkify(json.dumps(msg), self.max_chunk):
            self.writer.write(c.encode('utf-8'))

    async def read_message(self):
        """Read one newline-terminated JSON message.

        Returns the decoded object, or None on EOF or a truncated
        (newline-less) line. Raises json.JSONDecodeError or
        UnicodeDecodeError on a malformed message.
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')

            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Log the raw line: 'message' is unbound when decode() itself
            # failed, so referencing it here would raise NameError
            logger.error('Bad message from client: %r' % (l,))
            raise e

    async def handle_chunk(self, request):
        """Reassemble a chunked message (terminated by an empty line)
        and dispatch it as if it had arrived in one piece.

        Raises ClientError if the reassembled message is itself chunked.
        """
        lines = []
        try:
            while True:
                l = await self.reader.readline()
                l = l.rstrip(b"\n").decode("utf-8")
                if not l:
                    break

                lines.append(l)

            msg = json.loads(''.join(lines))
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Fix: the original referenced an undefined name 'message'
            # here, raising NameError instead of logging the bad payload
            logger.error('Bad message from client: %r' % (''.join(lines),))
            raise e

        if 'chunk-stream' in msg:
            raise ClientError("Nested chunks are not allowed")

        await self.dispatch_message(msg)

    async def handle_get(self, request):
        """Look up the unihash for a (method, taskhash) pair and reply
        with the matching row (all columns when 'all' is set) or None."""
        method = request['method']
        taskhash = request['taskhash']

        if request.get('all', False):
            row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
        else:
            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)

        if row is not None:
            # Fix: pass the values as separate lazy arguments; the
            # original passed one tuple for two '%s' placeholders, which
            # makes the logging call mis-format
            logger.debug('Found equivalent task %s -> %s',
                         row['taskhash'], row['unihash'])
            d = {k: row[k] for k in row.keys()}

            self.write_message(d)
        else:
            self.write_message(None)

    async def handle_get_stream(self, request):
        """Fast streaming lookup: one 'method taskhash' query per input
        line, one unihash (or empty) reply line, until 'END'."""
        self.write_message('ok')

        while True:
            l = await self.reader.readline()
            if not l:
                return

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled manually
                # instead of using 'with', and also why logging statements are
                # commented out.
                self.request_sample = self.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                l = l.decode('utf-8').rstrip()
                if l == 'END':
                    self.writer.write('ok\n'.encode('utf-8'))
                    return

                (method, taskhash) = l.split()
                #logger.debug('Looking up %s %s' % (method, taskhash))
                row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
                if row is not None:
                    msg = ('%s\n' % row['unihash']).encode('utf-8')
                    #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
                else:
                    msg = '\n'.encode('utf-8')

                self.writer.write(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            await self.writer.drain()

    async def handle_report(self, data):
        """Record a (method, taskhash, outhash) report.

        If another task already produced the same outhash, its unihash
        is reused (the tasks are equivalent); otherwise the caller's
        unihash is stored. Replies with the resulting mapping.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute('''
                -- Find tasks with a matching outhash (that is, tasks that
                -- are equivalent)
                SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                -- If there is an exact match on the taskhash, return it.
                -- Otherwise return the oldest matching outhash of any
                -- taskhash
                ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                    created ASC

                -- Only return one row
                LIMIT 1
                ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

            row = cursor.fetchone()

            # If no matching outhash was found, or one *was* found but it
            # wasn't an exact match on the taskhash, a new entry for this
            # taskhash should be added
            if row is None or row['taskhash'] != data['taskhash']:
                # If a row matching the outhash was found, the unihash for
                # the new taskhash should be the same as that one.
                # Otherwise the caller provided unihash is used.
                unihash = data['unihash']
                if row is not None:
                    unihash = row['unihash']

                insert_data = {
                    'method': data['method'],
                    'outhash': data['outhash'],
                    'taskhash': data['taskhash'],
                    'unihash': unihash,
                    'created': datetime.now()
                }

                # Optional metadata columns are stored when supplied
                for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                    if k in data:
                        insert_data[k] = data[k]

                cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                    ', '.join(sorted(insert_data.keys())),
                    ', '.join(':' + k for k in sorted(insert_data.keys()))),
                    insert_data)

                self.db.commit()

                logger.info('Adding taskhash %s with unihash %s',
                            data['taskhash'], unihash)

                d = {
                    'taskhash': data['taskhash'],
                    'method': data['method'],
                    'unihash': unihash
                }
            else:
                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_equivreport(self, data):
        """Record a claimed taskhash -> unihash equivalence (no outhash)
        and reply with whatever mapping is authoritative afterwards."""
        with closing(self.db.cursor()) as cursor:
            insert_data = {
                'method': data['method'],
                'outhash': "",
                'taskhash': data['taskhash'],
                'unihash': data['unihash'],
                'created': datetime.now()
            }

            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                if k in data:
                    insert_data[k] = data[k]

            cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % (
                ', '.join(sorted(insert_data.keys())),
                ', '.join(':' + k for k in sorted(insert_data.keys()))),
                insert_data)

            self.db.commit()

            # Fetch the unihash that will be reported for the taskhash. If the
            # unihash matches, it means this row was inserted (or the mapping
            # was already valid)
            row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)

            if row['unihash'] == data['unihash']:
                logger.info('Adding taskhash equivalence for %s with unihash %s',
                            data['taskhash'], row['unihash'])

            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_get_stats(self, request):
        """Reply with the current request statistics."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.write_message(d)

    async def handle_reset_stats(self, request):
        """Reply with the current request statistics, then reset them."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        self.write_message(d)

    def query_equivalent(self, method, taskhash, query):
        """Run a lookup query; return the matching row or None.

        Database errors are logged and treated as "no match" so a hiccup
        doesn't kill the connection (preserving the original best-effort
        behavior), but the cursor is now always closed — the original
        leaked it on success and hit an unbound 'cursor' NameError if
        db.cursor() itself raised.
        """
        # This is part of the inner loop and must be as fast as possible
        cursor = self.db.cursor()
        try:
            cursor.execute(query, {'method': method, 'taskhash': taskhash})
            return cursor.fetchone()
        except Exception:
            # Fix: narrowed from a bare 'except:' that silently swallowed
            # every error (even KeyboardInterrupt) without any logging
            logger.exception('Equivalence query failed for %s %s', method, taskhash)
            return None
        finally:
            cursor.close()
class Server(object):
    """Owns the asyncio event loop and the listening socket (TCP or
    UNIX domain) and spawns a ServerClient per accepted connection."""

    def __init__(self, db, loop=None):
        self.request_stats = Stats()
        self.db = db

        if loop is None:
            # No loop supplied: create one and take responsibility for
            # closing it at shutdown
            self.loop = asyncio.new_event_loop()
            self.close_loop = True
        else:
            self.loop = loop
            self.close_loop = False

        # Set by start_unix_server() to remove the socket file on exit
        self._cleanup_socket = None

    def start_tcp_server(self, host, port):
        """Begin listening on host:port; sets self.address."""
        self.server = self.loop.run_until_complete(
            asyncio.start_server(self.handle_client, host, port, loop=self.loop)
        )

        for sock in self.server.sockets:
            logger.info('Listening on %r', sock.getsockname())
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)

        primary = self.server.sockets[0]
        sockname = primary.getsockname()
        # IPv6 addresses are bracketed to separate host from port
        if primary.family == socket.AF_INET6:
            self.address = "[%s]:%d" % (sockname[0], sockname[1])
        else:
            self.address = "%s:%d" % (sockname[0], sockname[1])

    def start_unix_server(self, path):
        """Begin listening on a UNIX domain socket at path; sets
        self.address and schedules removal of the socket file."""
        def cleanup():
            os.unlink(path)

        saved_cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX
            os.chdir(os.path.dirname(path))
            self.server = self.loop.run_until_complete(
                asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
            )
        finally:
            os.chdir(saved_cwd)

        logger.info('Listening on %r', path)

        self._cleanup_socket = cleanup
        self.address = "unix://%s" % os.path.abspath(path)

    async def handle_client(self, reader, writer):
        """Per-connection entry point: run a ServerClient until the peer
        disconnects, logging (not propagating) any failure."""
        # writer.transport.set_write_buffer_limits(0)
        try:
            client = ServerClient(reader, writer, self.db, self.request_stats)
            await client.process_requests()
        except Exception as e:
            import traceback
            logger.error('Error from client: %s' % str(e), exc_info=True)
            traceback.print_exc()
            writer.close()
        logger.info('Client disconnected')

    def serve_forever(self):
        """Run the event loop until SIGTERM or Ctrl-C, then shut the
        server (and, if owned, the loop) down and clean up the socket."""
        def signal_handler():
            self.loop.stop()

        self.loop.add_signal_handler(signal.SIGTERM, signal_handler)

        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

        self.server.close()
        self.loop.run_until_complete(self.server.wait_closed())
        logger.info('Server shutting down')

        if self.close_loop:
            self.loop.close()

        if self._cleanup_socket is not None:
            self._cleanup_socket()