bitbake: hashserv: Refactor to use asyncrpc
The asyncrpc module can now be used to provide the json & asyncio based
RPC system used by hashserv.

(Bitbake rev: 5afb9586b0a4a23a05efb0e8ff4a97262631ae4a)

Signed-off-by: Paul Barker <pbarker@konsulko.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>

parent 244b044fd6
commit 421e86e7ed
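For orientation, here is a minimal sketch of the client-side pattern this
refactor adopts. It is illustrative only: the 'EXAMPLE' protocol string and
the Example* names are invented, while the real code uses 'OEHASHEQUIV' and
hashserv's AsyncClient/Client. The bb.asyncrpc.AsyncClient base class takes
over connection setup, the handshake line, JSON message framing and
reconnect/retry, so a subclass supplies only its protocol name/version and
its RPC methods; bb.asyncrpc.Client is the blocking wrapper that runs those
coroutines to completion on a private event loop.

    import logging

    import bb.asyncrpc

    logger = logging.getLogger('example.client')


    class ExampleAsyncClient(bb.asyncrpc.AsyncClient):
        def __init__(self):
            # Protocol name, version and logger go to the base class,
            # mirroring hashserv's super().__init__('OEHASHEQUIV', '1.1', logger)
            super().__init__('EXAMPLE', '1.0', logger)

        async def ping(self):
            # send_message() is inherited from the base class; it frames
            # the request as JSON and handles reconnects
            return await self.send_message({'ping': {}})


    class ExampleClient(bb.asyncrpc.Client):
        def __init__(self):
            super().__init__()
            # Each listed coroutine name gets a synchronous wrapper method
            self._add_methods('connect_tcp', 'ping')

        def _get_async_client(self):
            # The sync wrapper asks the subclass which async client to drive
            return ExampleAsyncClient()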
--- a/bitbake/lib/hashserv/client.py
+++ b/bitbake/lib/hashserv/client.py
@@ -8,107 +8,27 @@ import json
 import logging
 import socket
 import os
-from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client
+import bb.asyncrpc
+from . import create_async_client
 
 
 logger = logging.getLogger("hashserv.client")
 
 
-class AsyncClient(object):
+class AsyncClient(bb.asyncrpc.AsyncClient):
     MODE_NORMAL = 0
     MODE_GET_STREAM = 1
 
     def __init__(self):
-        self.reader = None
-        self.writer = None
+        super().__init__('OEHASHEQUIV', '1.1', logger)
         self.mode = self.MODE_NORMAL
-        self.max_chunk = DEFAULT_MAX_CHUNK
-
-    async def connect_tcp(self, address, port):
-        async def connect_sock():
-            return await asyncio.open_connection(address, port)
-
-        self._connect_sock = connect_sock
-
-    async def connect_unix(self, path):
-        async def connect_sock():
-            return await asyncio.open_unix_connection(path)
-
-        self._connect_sock = connect_sock
-
-    async def connect(self):
-        if self.reader is None or self.writer is None:
-            (self.reader, self.writer) = await self._connect_sock()
-
-            self.writer.write("OEHASHEQUIV 1.1\n\n".encode("utf-8"))
-            await self.writer.drain()
-
-            cur_mode = self.mode
-            self.mode = self.MODE_NORMAL
-            await self._set_mode(cur_mode)
-
-    async def close(self):
-        self.reader = None
-
-        if self.writer is not None:
-            self.writer.close()
-            self.writer = None
-
-    async def _send_wrapper(self, proc):
-        count = 0
-        while True:
-            try:
-                await self.connect()
-                return await proc()
-            except (
-                OSError,
-                ConnectionError,
-                json.JSONDecodeError,
-                UnicodeDecodeError,
-            ) as e:
-                logger.warning("Error talking to server: %s" % e)
-                if count >= 3:
-                    if not isinstance(e, ConnectionError):
-                        raise ConnectionError(str(e))
-                    raise e
-                await self.close()
-                count += 1
-
-    async def send_message(self, msg):
-        async def get_line():
-            line = await self.reader.readline()
-            if not line:
-                raise ConnectionError("Connection closed")
-
-            line = line.decode("utf-8")
-
-            if not line.endswith("\n"):
-                raise ConnectionError("Bad message %r" % message)
-
-            return line
-
-        async def proc():
-            for c in chunkify(json.dumps(msg), self.max_chunk):
-                self.writer.write(c.encode("utf-8"))
-            await self.writer.drain()
-
-            l = await get_line()
-
-            m = json.loads(l)
-            if m and "chunk-stream" in m:
-                lines = []
-                while True:
-                    l = (await get_line()).rstrip("\n")
-                    if not l:
-                        break
-                    lines.append(l)
-
-                m = json.loads("".join(lines))
-
-            return m
-
-        return await self._send_wrapper(proc)
+
+    async def setup_connection(self):
+        await super().setup_connection()
+        cur_mode = self.mode
+        self.mode = self.MODE_NORMAL
+        await self._set_mode(cur_mode)
 
     async def send_stream(self, msg):
         async def proc():
             self.writer.write(("%s\n" % msg).encode("utf-8"))
@@ -185,12 +105,10 @@ class AsyncClient(object):
         return (await self.send_message({"backfill-wait": None}))["tasks"]
 
 
-class Client(object):
+class Client(bb.asyncrpc.Client):
     def __init__(self):
-        self.client = AsyncClient()
-        self.loop = asyncio.new_event_loop()
-
-        for call in (
+        super().__init__()
+        self._add_methods(
             "connect_tcp",
             "close",
             "get_unihash",
@@ -200,30 +118,7 @@ class Client(object):
             "get_stats",
             "reset_stats",
             "backfill_wait",
-        ):
-            downcall = getattr(self.client, call)
-            setattr(self, call, self._get_downcall_wrapper(downcall))
-
-    def _get_downcall_wrapper(self, downcall):
-        def wrapper(*args, **kwargs):
-            return self.loop.run_until_complete(downcall(*args, **kwargs))
-
-        return wrapper
-
-    def connect_unix(self, path):
-        # AF_UNIX has path length issues so chdir here to workaround
-        cwd = os.getcwd()
-        try:
-            os.chdir(os.path.dirname(path))
-            self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
-            self.loop.run_until_complete(self.client.connect())
-        finally:
-            os.chdir(cwd)
-
-    @property
-    def max_chunk(self):
-        return self.client.max_chunk
-
-    @max_chunk.setter
-    def max_chunk(self, value):
-        self.client.max_chunk = value
+        )
+
+    def _get_async_client(self):
+        return AsyncClient()
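The Client class keeps its existing blocking API, so call sites need no
changes. A hypothetical session (import path matches the tree above; host,
port and argument values are invented for illustration):

    from hashserv.client import Client

    client = Client()
    client.connect_tcp('localhost', 8686)  # wrapper runs the coroutine to completion
    unihash = client.get_unihash('example.method', '0123abcd')  # blocking RPC round-trip
    client.close()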
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -14,7 +14,9 @@ import signal
 import socket
 import sys
 import time
-from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS
+from . import create_async_client, TABLE_COLUMNS
+import bb.asyncrpc
+
 
 logger = logging.getLogger('hashserv.server')
 
@@ -109,12 +111,6 @@ class Stats(object):
         return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
 
 
-class ClientError(Exception):
-    pass
-
-class ServerError(Exception):
-    pass
-
 def insert_task(cursor, data, ignore=False):
     keys = sorted(data.keys())
     query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
@@ -149,7 +145,7 @@ async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
 
     return d
 
-class ServerClient(object):
+class ServerClient(bb.asyncrpc.AsyncServerConnection):
     FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
     ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
     OUTHASH_QUERY = '''
@@ -168,21 +164,19 @@ class ServerClient(object):
     '''
 
     def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
-        self.reader = reader
-        self.writer = writer
+        super().__init__(reader, writer, 'OEHASHEQUIV', logger)
         self.db = db
         self.request_stats = request_stats
-        self.max_chunk = DEFAULT_MAX_CHUNK
+        self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
         self.backfill_queue = backfill_queue
         self.upstream = upstream
 
-        self.handlers = {
+        self.handlers.update({
             'get': self.handle_get,
             'get-outhash': self.handle_get_outhash,
             'get-stream': self.handle_get_stream,
             'get-stats': self.handle_get_stats,
-            'chunk-stream': self.handle_chunk,
-        }
+        })
 
         if not read_only:
             self.handlers.update({
@@ -192,57 +186,20 @@ class ServerClient(object):
             'backfill-wait': self.handle_backfill_wait,
         })
 
+    def validate_proto_version(self):
+        return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
+
     async def process_requests(self):
         if self.upstream is not None:
             self.upstream_client = await create_async_client(self.upstream)
         else:
             self.upstream_client = None
 
-        try:
-            self.addr = self.writer.get_extra_info('peername')
-            logger.debug('Client %r connected' % (self.addr,))
-
-            # Read protocol and version
-            protocol = await self.reader.readline()
-            if protocol is None:
-                return
-
-            (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
-            if proto_name != 'OEHASHEQUIV':
-                return
-
-            proto_version = tuple(int(v) for v in proto_version.split('.'))
-            if proto_version < (1, 0) or proto_version > (1, 1):
-                return
-
-            # Read headers. Currently, no headers are implemented, so look for
-            # an empty line to signal the end of the headers
-            while True:
-                line = await self.reader.readline()
-                if line is None:
-                    return
-
-                line = line.decode('utf-8').rstrip()
-                if not line:
-                    break
-
-            # Handle messages
-            while True:
-                d = await self.read_message()
-                if d is None:
-                    break
-                await self.dispatch_message(d)
-                await self.writer.drain()
-        except ClientError as e:
-            logger.error(str(e))
-        finally:
-            if self.upstream_client is not None:
-                await self.upstream_client.close()
-
-            self.writer.close()
+        await super().process_requests()
+
+        if self.upstream_client is not None:
+            await self.upstream_client.close()
 
     async def dispatch_message(self, msg):
         for k in self.handlers.keys():
             if k in msg:
@@ -255,47 +212,7 @@ class ServerClient(object):
                 await self.handlers[k](msg[k])
                 return
 
-        raise ClientError("Unrecognized command %r" % msg)
-
-    def write_message(self, msg):
-        for c in chunkify(json.dumps(msg), self.max_chunk):
-            self.writer.write(c.encode('utf-8'))
-
-    async def read_message(self):
-        l = await self.reader.readline()
-        if not l:
-            return None
-
-        try:
-            message = l.decode('utf-8')
-
-            if not message.endswith('\n'):
-                return None
-
-            return json.loads(message)
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            logger.error('Bad message from client: %r' % message)
-            raise e
-
-    async def handle_chunk(self, request):
-        lines = []
-        try:
-            while True:
-                l = await self.reader.readline()
-                l = l.rstrip(b"\n").decode("utf-8")
-                if not l:
-                    break
-                lines.append(l)
-
-            msg = json.loads(''.join(lines))
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            logger.error('Bad message from client: %r' % message)
-            raise e
-
-        if 'chunk-stream' in msg:
-            raise ClientError("Nested chunks are not allowed")
-
-        await self.dispatch_message(msg)
+        raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
 
     async def handle_get(self, request):
         method = request['method']
@@ -499,74 +416,20 @@ class ServerClient(object):
             cursor.close()
 
 
-class Server(object):
+class Server(bb.asyncrpc.AsyncServer):
     def __init__(self, db, loop=None, upstream=None, read_only=False):
         if upstream and read_only:
-            raise ServerError("Read-only hashserv cannot pull from an upstream server")
+            raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
+
+        super().__init__(logger, loop)
 
         self.request_stats = Stats()
         self.db = db
-
-        if loop is None:
-            self.loop = asyncio.new_event_loop()
-            self.close_loop = True
-        else:
-            self.loop = loop
-            self.close_loop = False
-
         self.upstream = upstream
         self.read_only = read_only
 
-        self._cleanup_socket = None
-
-    def start_tcp_server(self, host, port):
-        self.server = self.loop.run_until_complete(
-            asyncio.start_server(self.handle_client, host, port, loop=self.loop)
-        )
-
-        for s in self.server.sockets:
-            logger.info('Listening on %r' % (s.getsockname(),))
-            # Newer python does this automatically. Do it manually here for
-            # maximum compatibility
-            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
-            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
-        name = self.server.sockets[0].getsockname()
-        if self.server.sockets[0].family == socket.AF_INET6:
-            self.address = "[%s]:%d" % (name[0], name[1])
-        else:
-            self.address = "%s:%d" % (name[0], name[1])
-
-    def start_unix_server(self, path):
-        def cleanup():
-            os.unlink(path)
-
-        cwd = os.getcwd()
-        try:
-            # Work around path length limits in AF_UNIX
-            os.chdir(os.path.dirname(path))
-            self.server = self.loop.run_until_complete(
-                asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
-            )
-        finally:
-            os.chdir(cwd)
-
-        logger.info('Listening on %r' % path)
-
-        self._cleanup_socket = cleanup
-        self.address = "unix://%s" % os.path.abspath(path)
-
-    async def handle_client(self, reader, writer):
-        # writer.transport.set_write_buffer_limits(0)
-        try:
-            client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
-            await client.process_requests()
-        except Exception as e:
-            import traceback
-            logger.error('Error from client: %s' % str(e), exc_info=True)
-            traceback.print_exc()
-            writer.close()
-        logger.info('Client disconnected')
+    def accept_client(self, reader, writer):
+        return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
 
     @contextmanager
     def _backfill_worker(self):
@@ -597,31 +460,8 @@ class Server(object):
         else:
             yield
 
-    def serve_forever(self):
-        def signal_handler():
-            self.loop.stop()
-
-        asyncio.set_event_loop(self.loop)
-        try:
-            self.backfill_queue = asyncio.Queue()
-
-            self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
-            with self._backfill_worker():
-                try:
-                    self.loop.run_forever()
-                except KeyboardInterrupt:
-                    pass
-
-                self.server.close()
-
-                self.loop.run_until_complete(self.server.wait_closed())
-                logger.info('Server shutting down')
-        finally:
-            if self.close_loop:
-                if sys.version_info >= (3, 6):
-                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
-                self.loop.close()
-
-            if self._cleanup_socket is not None:
-                self._cleanup_socket()
+    def run_loop_forever(self):
+        self.backfill_queue = asyncio.Queue()
+
+        with self._backfill_worker():
+            super().run_loop_forever()
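On the server side the same division of labour applies. The sketch below is
illustrative only (the 'EXAMPLE' protocol and Example* names are invented):
bb.asyncrpc.AsyncServerConnection now performs the handshake, version check,
header parsing and per-message dispatch loop that ServerClient previously
open-coded, and bb.asyncrpc.AsyncServer owns the listening sockets and event
loop, asking the subclass for a connection object per client.

    import logging

    import bb.asyncrpc

    logger = logging.getLogger('example.server')


    class ExampleConnection(bb.asyncrpc.AsyncServerConnection):
        def __init__(self, reader, writer):
            super().__init__(reader, writer, 'EXAMPLE', logger)
            # The base class pre-populates self.handlers (e.g. chunked
            # message reassembly); subclasses add their own commands
            self.handlers.update({
                'ping': self.handle_ping,
            })

        def validate_proto_version(self):
            # Called by the base class after the handshake; accept 1.0 only
            return self.proto_version == (1, 0)

        async def handle_ping(self, request):
            # write_message() is assumed to be inherited from the base
            # class, as the removal of ServerClient's own copy suggests
            self.write_message({'alive': True})


    class ExampleServer(bb.asyncrpc.AsyncServer):
        def __init__(self, loop=None):
            super().__init__(logger, loop)

        def accept_client(self, reader, writer):
            # Invoked by the base class for each new client connection
            return ExampleConnection(reader, writer)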