poky/bitbake/lib/hashserv/server.py
Joshua Watt 8f8501ed40 bitbake: asyncrpc: Abstract sockets
Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream-based (e.g. websockets,
which are message-based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms.

(Bitbake rev: 2aaeae53696e4c2f13a169830c3b7089cbad6eca)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
2023-11-09 17:33:02 +00:00

608 lines
20 KiB
Python

# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
from contextlib import closing, contextmanager
from datetime import datetime, timedelta
import enum
import asyncio
import logging
import math
import time
from . import create_async_client, UNIHASH_TABLE_COLUMNS, OUTHASH_TABLE_COLUMNS
import bb.asyncrpc
# Module logger. It is also handed to the asyncrpc base classes by
# ServerClient and Server when they are constructed.
logger = logging.getLogger('hashserv.server')
class Measurement:
    """Time a single interval and report it to a sample.

    The interval is measured with ``time.perf_counter()`` and the elapsed
    time is passed to ``sample.add()``.  Use it explicitly through
    start()/end(), or as a context manager around the timed code.
    """

    def __init__(self, sample):
        # Receiver of the elapsed time; only its add() method is used.
        self.sample = sample

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info, **kwargs):
        # Report the elapsed time even if the body raised; returning None
        # lets any exception propagate unchanged.
        self.end()

    def start(self):
        """Remember the current perf_counter() reading."""
        self.start_time = time.perf_counter()

    def end(self):
        """Report the time elapsed since start() to the sample."""
        elapsed = time.perf_counter() - self.start_time
        self.sample.add(elapsed)
class Sample:
    """Accumulate timed measurements and report their total to ``stats``.

    Each add() contributes to a running total.  end() hands the total to
    ``stats.add()`` (only if something was added) and starts over.  When used
    as a context manager, end() runs automatically on exit.
    """

    def __init__(self, stats):
        self.stats = stats
        self._clear()

    def _clear(self):
        # Reset the accumulator to "nothing measured yet".
        self.num_samples = 0
        self.elapsed = 0

    def measure(self):
        """Return a Measurement whose elapsed time is added to this sample."""
        return Measurement(self)

    def add(self, elapsed):
        """Add one measured duration to the running total."""
        self.num_samples += 1
        self.elapsed += elapsed

    def end(self):
        """Report the accumulated time, if any, then reset the total."""
        if not self.num_samples:
            return
        self.stats.add(self.elapsed)
        self._clear()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info, **kwargs):
        self.end()
class Stats:
    """Running timing statistics over a series of elapsed-time samples.

    Tracks the count, total and maximum of the samples.  It also keeps the
    running mean (``m``) and the sum of squared deviations (``s``) using
    Welford's online algorithm, which is what the sample standard deviation
    is computed from.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        self.m = 0
        self.s = 0
        self.current_elapsed = None

    def add(self, elapsed):
        """Fold one elapsed time into the statistics."""
        self.num += 1
        if self.num == 1:
            # First sample: the mean is the sample itself, no spread yet.
            self.m = elapsed
            self.s = 0
        else:
            delta = elapsed - self.m
            self.m += delta / self.num
            self.s += delta * (elapsed - self.m)

        self.total_time += elapsed
        self.max_time = max(self.max_time, elapsed)

    def start_sample(self):
        """Return a new Sample that reports into these statistics."""
        return Sample(self)

    @property
    def average(self):
        """Mean elapsed time, or 0 when nothing has been added."""
        return self.total_time / self.num if self.num else 0

    @property
    def stdev(self):
        """Sample standard deviation (n - 1 denominator); 0 below two samples."""
        if self.num < 2:
            return 0
        return math.sqrt(self.s / (self.num - 1))

    def todict(self):
        """Return the summary values as a plain dict."""
        names = ('num', 'total_time', 'max_time', 'average', 'stdev')
        return dict((name, getattr(self, name)) for name in names)
@enum.unique
class Resolve(enum.Enum):
    """Conflict-resolution policy used when inserting a row (see insert_table)."""

    # Values match what enum.auto() assigns (starting at 1).
    FAIL = 1
    IGNORE = 2
    REPLACE = 3
def insert_table(cursor, table, data, on_conflict):
    """Insert one row into ``table`` with the given conflict policy.

    :param cursor: database cursor used to execute the statement.
    :param table: table name.  It is interpolated directly into the SQL, so
        it must come from trusted code.
    :param data: mapping of column name -> value.  Column names are
        interpolated into the SQL (in sorted order); values are bound as named
        parameters.
    :param on_conflict: a Resolve member.  FAIL emits a plain INSERT, IGNORE
        emits INSERT OR IGNORE and REPLACE emits INSERT OR REPLACE.
    :returns: ``(rowid, inserted)``.  ``rowid`` is ``cursor.lastrowid`` after
        the statement; ``inserted`` is True when ``lastrowid`` changed.
    :raises KeyError: if ``on_conflict`` is not a Resolve member.
    """
    resolve = {
        Resolve.FAIL: "",
        Resolve.IGNORE: " OR IGNORE",
        Resolve.REPLACE: " OR REPLACE",
    }[on_conflict]

    keys = sorted(data.keys())
    query = 'INSERT{resolve} INTO {table} ({fields}) VALUES({values})'.format(
        resolve=resolve,
        table=table,
        fields=", ".join(keys),
        values=", ".join(":" + k for k in keys),
    )

    # NOTE(review): ``inserted`` relies on Cursor.lastrowid only changing when
    # a row is really added. A new cursor starts with lastrowid None, so
    # confirm the flag is reliable for the first insert on a fresh cursor.
    prevrowid = cursor.lastrowid
    cursor.execute(query, data)
    # Log through the module logger (like the rest of this file) rather than
    # the root logger, so 'hashserv.server' logging configuration applies.
    logger.debug(
        "Inserting %r into %s, %s",
        data,
        table,
        on_conflict
    )
    return (cursor.lastrowid, cursor.lastrowid != prevrowid)
def insert_unihash(cursor, data, on_conflict):
    """Insert ``data`` as a row of the unihashes_v2 table via insert_table()."""
    return insert_table(
        cursor=cursor,
        table="unihashes_v2",
        data=data,
        on_conflict=on_conflict,
    )
def insert_outhash(cursor, data, on_conflict):
    """Insert ``data`` as a row of the outhashes_v2 table via insert_table()."""
    return insert_table(
        cursor=cursor,
        table="outhashes_v2",
        data=data,
        on_conflict=on_conflict,
    )
async def copy_unihash_from_upstream(client, db, method, taskhash):
    """Copy one unihash entry from an upstream client into the local database.

    The upstream response is filtered down to UNIHASH_TABLE_COLUMNS and
    inserted with Resolve.IGNORE, so an existing local row is kept, and the
    insert is committed.

    :returns: the unfiltered upstream response, or None when the upstream
        client returned None (nothing is written in that case).
    """
    upstream_entry = await client.get_taskhash(method, taskhash)
    if upstream_entry is None:
        return None

    columns = {
        name: value
        for name, value in upstream_entry.items()
        if name in UNIHASH_TABLE_COLUMNS
    }
    with closing(db.cursor()) as cursor:
        insert_unihash(cursor, columns, Resolve.IGNORE)
        db.commit()
    return upstream_entry
class ServerCursor:
    """Bundle a database handle, a cursor and an upstream reference.

    It only stores its three arguments as attributes and has no behaviour of
    its own.
    """

    def __init__(self, db, cursor, upstream):
        self.upstream = upstream
        self.cursor = cursor
        self.db = db
class ServerClient(bb.asyncrpc.AsyncServerConnection):
    """Serve the requests of one connected hash equivalence client.

    Registers the 'OEHASHEQUIV' protocol handlers and answers them from the
    shared database connection ``db``.  When ``upstream`` is set, entries
    missing locally are looked up on the upstream server and stored locally.
    Handlers that modify the database are only registered when ``read_only``
    is false.
    """

    def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
        super().__init__(socket, 'OEHASHEQUIV', logger)
        self.db = db
        self.request_stats = request_stats
        self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
        self.backfill_queue = backfill_queue
        self.upstream = upstream

        # Query-only commands are always available
        self.handlers.update({
            'get': self.handle_get,
            'get-outhash': self.handle_get_outhash,
            'get-stream': self.handle_get_stream,
            'get-stats': self.handle_get_stats,
        })

        if not read_only:
            # Commands that modify the database or server state
            self.handlers.update({
                'report': self.handle_report,
                'report-equiv': self.handle_equivreport,
                'reset-stats': self.handle_reset_stats,
                'backfill-wait': self.handle_backfill_wait,
                'remove': self.handle_remove,
                'clean-unused': self.handle_clean_unused,
            })

    def validate_proto_version(self):
        """Accept protocol versions greater than 1.0 and up to 1.1."""
        return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))

    async def process_requests(self):
        """Connect to the upstream server (if any), then serve requests.

        The upstream client is closed when request processing finishes,
        including when it finishes with an exception.
        """
        if self.upstream is not None:
            self.upstream_client = await create_async_client(self.upstream)
        else:
            self.upstream_client = None

        try:
            await super().process_requests()
        finally:
            if self.upstream_client is not None:
                await self.upstream_client.close()

    async def dispatch_message(self, msg):
        """Run the first registered handler whose name is found in ``msg``.

        Non-streaming handlers are timed as one request sample; streaming
        handlers record their own per-item samples.

        :raises bb.asyncrpc.ClientError: if no registered handler matches.
        """
        for k in self.handlers.keys():
            if k in msg:
                logger.debug('Handling %s', k)
                if 'stream' in k:
                    return await self.handlers[k](msg[k])
                else:
                    with self.request_stats.start_sample() as self.request_sample, \
                            self.request_sample.measure():
                        return await self.handlers[k](msg[k])

        raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)

    async def handle_get(self, request):
        """Look up the unihash for request['method'] / request['taskhash'].

        When request['all'] is true the joined outhash/unihash row is
        returned instead of only the unihash mapping.
        """
        method = request['method']
        taskhash = request['taskhash']
        fetch_all = request.get('all', False)
        with closing(self.db.cursor()) as cursor:
            return await self.get_unihash(cursor, method, taskhash, fetch_all)

    async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
        """Return the unihash data for (method, taskhash) as a dict, or None.

        A local miss is forwarded to the upstream server (when configured)
        and any result it returns is stored locally and committed.
        """
        d = None

        if fetch_all:
            cursor.execute(
                '''
                SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                ''',
                {
                    'method': method,
                    'taskhash': taskhash,
                }
            )
            row = cursor.fetchone()

            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash, True)
                self.update_unified(cursor, d)
                self.db.commit()
        else:
            row = self.query_equivalent(cursor, method, taskhash)

            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash)
                # get_taskhash() can return None (copy_unihash_from_upstream
                # checks for it too); only filter and store a real result.
                if d is not None:
                    d = {k: v for k, v in d.items() if k in UNIHASH_TABLE_COLUMNS}
                    insert_unihash(cursor, d, Resolve.IGNORE)
                    self.db.commit()

        return d

    async def handle_get_outhash(self, request):
        """Look up the oldest entry for request['method'] / request['outhash'].

        request['with_unihash'] (default True) selects whether the matching
        unihash row is joined in.
        """
        method = request['method']
        outhash = request['outhash']
        taskhash = request['taskhash']
        with_unihash = request.get("with_unihash", True)

        with closing(self.db.cursor()) as cursor:
            return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)

    async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
        """Return the oldest outhash row for (method, outhash) as a dict, or None.

        A local miss is forwarded to the upstream server (when configured)
        and any result is stored locally through update_unified().
        """
        d = None
        if with_unihash:
            cursor.execute(
                '''
                SELECT *, unihashes_v2.unihash AS unihash FROM outhashes_v2
                INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                ''',
                {
                    'method': method,
                    'outhash': outhash,
                }
            )
        else:
            cursor.execute(
                """
                SELECT * FROM outhashes_v2
                WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
                ORDER BY outhashes_v2.created ASC
                LIMIT 1
                """,
                {
                    'method': method,
                    'outhash': outhash,
                }
            )
        row = cursor.fetchone()

        if row is not None:
            d = {k: row[k] for k in row.keys()}
        elif self.upstream_client is not None:
            d = await self.upstream_client.get_outhash(method, outhash, taskhash)
            self.update_unified(cursor, d)
            self.db.commit()

        return d

    def update_unified(self, cursor, data):
        """Store an upstream result in both the unihash and outhash tables.

        Each table receives only its own columns; existing rows are kept
        (Resolve.IGNORE).  Does nothing when ``data`` is None and does not
        commit.
        """
        if data is None:
            return

        insert_unihash(
            cursor,
            {k: v for k, v in data.items() if k in UNIHASH_TABLE_COLUMNS},
            Resolve.IGNORE
        )
        insert_outhash(
            cursor,
            {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS},
            Resolve.IGNORE
        )

    async def handle_get_stream(self, request):
        """Answer a stream of "<method> <taskhash>" lookups.

        Each received item gets the matching unihash, or "" when none is
        known.  The stream ends on an empty item or 'END', after which "ok"
        is sent.  Upstream hits are queued for backfilling.
        """
        await self.socket.send_message("ok")

        while True:
            upstream = None

            line = await self.socket.recv()
            if not line:
                break

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled manually
                # instead of using 'with', and also why logging statements are
                # commented out).
                self.request_sample = self.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                if line == 'END':
                    break

                (method, taskhash) = line.split()
                #logger.debug('Looking up %s %s' % (method, taskhash))
                cursor = self.db.cursor()
                try:
                    row = self.query_equivalent(cursor, method, taskhash)
                finally:
                    cursor.close()

                if row is not None:
                    msg = row['unihash']
                    #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
                elif self.upstream_client is not None:
                    upstream = await self.upstream_client.get_unihash(method, taskhash)
                    if upstream:
                        msg = upstream
                    else:
                        msg = ""
                else:
                    msg = ""

                await self.socket.send(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            # Post to the backfill queue after writing the result to minimize
            # the turn around time on a request
            if upstream is not None:
                await self.backfill_queue.put((method, taskhash))

        await self.socket.send("ok")
        return self.NO_RESPONSE

    async def handle_report(self, data):
        """Record a (method, taskhash, outhash) report and return its unihash.

        If the outhash row is new, the unihash is copied from the oldest
        other taskhash with the same outhash.  Failing that, the upstream's
        unihash for the outhash is used, or the reported unihash otherwise.
        The unihash actually stored for the taskhash is returned.
        """
        with closing(self.db.cursor()) as cursor:
            outhash_data = {
                'method': data['method'],
                'outhash': data['outhash'],
                'taskhash': data['taskhash'],
                'created': datetime.now()
            }

            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                if k in data:
                    outhash_data[k] = data[k]

            # Insert the new entry, unless it already exists
            (rowid, inserted) = insert_outhash(cursor, outhash_data, Resolve.IGNORE)

            if inserted:
                # If this row is new, check if it is equivalent to another
                # output hash
                cursor.execute(
                    '''
                    SELECT outhashes_v2.taskhash AS taskhash, unihashes_v2.unihash AS unihash FROM outhashes_v2
                    INNER JOIN unihashes_v2 ON unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash
                    -- Select any matching output hash except the one we just inserted
                    WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
                    -- Pick the oldest hash
                    ORDER BY outhashes_v2.created ASC
                    LIMIT 1
                    ''',
                    {
                        'method': data['method'],
                        'outhash': data['outhash'],
                        'taskhash': data['taskhash'],
                    }
                )
                row = cursor.fetchone()

                if row is not None:
                    # A matching output hash was found. Set our taskhash to the
                    # same unihash since they are equivalent
                    unihash = row['unihash']
                    resolve = Resolve.IGNORE
                else:
                    # No matching output hash was found. This is probably the
                    # first outhash to be added.
                    unihash = data['unihash']
                    resolve = Resolve.IGNORE

                    # Query upstream to see if it has a unihash we can use
                    if self.upstream_client is not None:
                        upstream_data = await self.upstream_client.get_outhash(data['method'], data['outhash'], data['taskhash'])
                        if upstream_data is not None:
                            unihash = upstream_data['unihash']

                insert_unihash(
                    cursor,
                    {
                        'method': data['method'],
                        'taskhash': data['taskhash'],
                        'unihash': unihash,
                    },
                    resolve
                )

            unihash_data = await self.get_unihash(cursor, data['method'], data['taskhash'])
            if unihash_data is not None:
                unihash = unihash_data['unihash']
            else:
                unihash = data['unihash']

            self.db.commit()

            d = {
                'taskhash': data['taskhash'],
                'method': data['method'],
                'unihash': unihash,
            }

        return d

    async def handle_equivreport(self, data):
        """Record a taskhash -> unihash mapping unless one already exists.

        Returns the mapping that is actually stored for the taskhash.
        """
        with closing(self.db.cursor()) as cursor:
            insert_data = {
                'method': data['method'],
                'taskhash': data['taskhash'],
                'unihash': data['unihash'],
            }
            insert_unihash(cursor, insert_data, Resolve.IGNORE)
            self.db.commit()

            # Fetch the unihash that will be reported for the taskhash. If the
            # unihash matches, it means this row was inserted (or the mapping
            # was already valid)
            row = self.query_equivalent(cursor, data['method'], data['taskhash'])

            if row['unihash'] == data['unihash']:
                logger.info('Adding taskhash equivalence for %s with unihash %s',
                            data['taskhash'], row['unihash'])

            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        return d

    async def handle_get_stats(self, request):
        """Return the current request statistics."""
        return {
            'requests': self.request_stats.todict(),
        }

    async def handle_reset_stats(self, request):
        """Return the request statistics, then reset them."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        return d

    async def handle_backfill_wait(self, request):
        """Wait until every queued backfill item has been processed.

        Returns the number of items that were waiting in the queue when the
        request arrived.  A server without a backfill queue (one without an
        upstream server) has nothing to wait for, so it returns 0 immediately.
        """
        if self.backfill_queue is None:
            return {'tasks': 0}

        d = {
            'tasks': self.backfill_queue.qsize(),
        }
        await self.backfill_queue.join()
        return d

    async def handle_remove(self, request):
        """Delete rows matching request['where'] from both hash tables.

        Each table is filtered on the non-None entries of the condition that
        name one of its columns.  A table with no such entries is left
        untouched.  Returns the total number of rows deleted.

        :raises TypeError: if request['where'] is not a dict.
        """
        condition = request["where"]
        if not isinstance(condition, dict):
            raise TypeError("Bad condition type %s" % type(condition))

        def do_remove(columns, table_name, cursor):
            # Column names come from the fixed column lists, values are bound
            # as parameters, so client input never reaches the SQL text.
            where = {}
            for c in columns:
                if c in condition and condition[c] is not None:
                    where[c] = condition[c]

            if where:
                query = ('DELETE FROM %s WHERE ' % table_name) + ' AND '.join("%s=:%s" % (k, k) for k in where.keys())
                cursor.execute(query, where)
                return cursor.rowcount

            return 0

        count = 0
        with closing(self.db.cursor()) as cursor:
            count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
            count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
            self.db.commit()

        return {"count": count}

    async def handle_clean_unused(self, request):
        """Delete outhash rows with no unihash that are older than a limit.

        The limit is request["max_age_seconds"].  Returns the number of rows
        deleted.
        """
        max_age = request["max_age_seconds"]
        # Only rows created more than max_age seconds ago are eligible.
        # (Subtracting a *negative* timedelta would move the cut-off into
        # the future and delete every unused row regardless of age.)
        oldest = datetime.now() - timedelta(seconds=max_age)
        with closing(self.db.cursor()) as cursor:
            cursor.execute(
                """
                DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
                    SELECT unihashes_v2.id FROM unihashes_v2 WHERE unihashes_v2.method=outhashes_v2.method AND unihashes_v2.taskhash=outhashes_v2.taskhash LIMIT 1
                )
                """,
                {
                    "oldest": oldest
                }
            )
            count = cursor.rowcount
            # Persist the deletion, as handle_remove() does.
            self.db.commit()

        return {"count": count}

    def query_equivalent(self, cursor, method, taskhash):
        """Return the unihashes_v2 row (taskhash, method, unihash) or None."""
        # This is part of the inner loop and must be as fast as possible
        cursor.execute(
            'SELECT taskhash, method, unihash FROM unihashes_v2 WHERE method=:method AND taskhash=:taskhash',
            {
                'method': method,
                'taskhash': taskhash,
            }
        )
        return cursor.fetchone()
class Server(bb.asyncrpc.AsyncServer):
    """Hash equivalence server.

    Every accepted connection is served by a ServerClient.  All clients share
    the same database handle and request statistics.  When an upstream
    server is configured, a background task copies unihashes that were
    served from the upstream into the local database.
    """

    def __init__(self, db, upstream=None, read_only=False):
        # A read-only server cannot store what it would pull from upstream.
        if read_only and upstream:
            raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")

        super().__init__(logger)

        self.request_stats = Stats()
        self.db = db
        self.upstream = upstream
        self.read_only = read_only
        # Created by start(), and only when an upstream server is configured.
        self.backfill_queue = None

    def accept_client(self, socket):
        """Create the connection handler for a newly accepted client."""
        return ServerClient(
            socket,
            db=self.db,
            request_stats=self.request_stats,
            backfill_queue=self.backfill_queue,
            upstream=self.upstream,
            read_only=self.read_only,
        )

    async def backfill_worker_task(self):
        """Copy queued (method, taskhash) entries from the upstream server.

        Runs until a None item is taken from the queue.  Every item taken
        (including the None) is marked done once handled.
        """
        client = await create_async_client(self.upstream)
        try:
            while True:
                entry = await self.backfill_queue.get()
                finished = entry is None
                if not finished:
                    method, taskhash = entry
                    await copy_unihash_from_upstream(client, self.db, method, taskhash)
                self.backfill_queue.task_done()
                if finished:
                    break
        finally:
            await client.close()

    def start(self):
        """Start the base server, plus the backfill worker when upstream is set."""
        tasks = super().start()
        if not self.upstream:
            return tasks

        self.backfill_queue = asyncio.Queue()
        tasks += [self.backfill_worker_task()]
        return tasks

    async def stop(self):
        """Tell the backfill worker (if any) to exit, then stop the server."""
        queue = self.backfill_queue
        if queue is not None:
            # A None item makes backfill_worker_task() leave its loop.
            await queue.put(None)
        await super().stop()