bitbake: asyncrpc: Abstract sockets

Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream based (e.g. websockets
which are message based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms

(Bitbake rev: 2aaeae53696e4c2f13a169830c3b7089cbad6eca)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
This commit is contained in:
Joshua Watt 2023-11-03 08:26:19 -06:00 committed by Richard Purdie
parent f97b686884
commit 8f8501ed40
10 changed files with 396 additions and 362 deletions

View File

@ -4,30 +4,12 @@
# SPDX-License-Identifier: GPL-2.0-only
#
import itertools
import json
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024
def chunkify(msg, max_chunk):
if len(msg) < max_chunk - 1:
yield ''.join((msg, "\n"))
else:
yield ''.join((json.dumps({
'chunk-stream': None
}), "\n"))
args = [iter(msg)] * (max_chunk - 1)
for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
yield ''.join(itertools.chain(m, "\n"))
yield "\n"
from .client import AsyncClient, Client
from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError
from .serv import AsyncServer, AsyncServerConnection
from .connection import DEFAULT_MAX_CHUNK
from .exceptions import (
ClientError,
ServerError,
ConnectionClosedError,
)

View File

@ -10,13 +10,13 @@ import json
import os
import socket
import sys
from . import chunkify, DEFAULT_MAX_CHUNK
from .connection import StreamConnection, DEFAULT_MAX_CHUNK
from .exceptions import ConnectionClosedError
class AsyncClient(object):
def __init__(self, proto_name, proto_version, logger, timeout=30):
self.reader = None
self.writer = None
self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
@ -25,7 +25,8 @@ class AsyncClient(object):
async def connect_tcp(self, address, port):
async def connect_sock():
return await asyncio.open_connection(address, port)
reader, writer = await asyncio.open_connection(address, port)
return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
@ -40,27 +41,27 @@ class AsyncClient(object):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.connect(os.path.basename(path))
finally:
os.chdir(cwd)
return await asyncio.open_unix_connection(sock=sock)
os.chdir(cwd)
reader, writer = await asyncio.open_unix_connection(sock=sock)
return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def setup_connection(self):
s = '%s %s\n\n' % (self.proto_name, self.proto_version)
self.writer.write(s.encode("utf-8"))
await self.writer.drain()
# Send headers
await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
# End of headers
await self.socket.send("")
async def connect(self):
if self.reader is None or self.writer is None:
(self.reader, self.writer) = await self._connect_sock()
if self.socket is None:
self.socket = await self._connect_sock()
await self.setup_connection()
async def close(self):
self.reader = None
if self.writer is not None:
self.writer.close()
self.writer = None
if self.socket is not None:
await self.socket.close()
self.socket = None
async def _send_wrapper(self, proc):
count = 0
@ -71,6 +72,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@ -82,49 +84,15 @@ class AsyncClient(object):
await self.close()
count += 1
async def send_message(self, msg):
async def get_line():
try:
line = await asyncio.wait_for(self.reader.readline(), self.timeout)
except asyncio.TimeoutError:
raise ConnectionError("Timed out waiting for server")
if not line:
raise ConnectionError("Connection closed")
line = line.decode("utf-8")
if not line.endswith("\n"):
raise ConnectionError("Bad message %r" % (line))
return line
async def invoke(self, msg):
async def proc():
for c in chunkify(json.dumps(msg), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
l = await get_line()
m = json.loads(l)
if m and "chunk-stream" in m:
lines = []
while True:
l = (await get_line()).rstrip("\n")
if not l:
break
lines.append(l)
m = json.loads("".join(lines))
return m
await self.socket.send_message(msg)
return await self.socket.recv_message()
return await self._send_wrapper(proc)
async def ping(self):
return await self.send_message(
{'ping': {}}
)
return await self.invoke({"ping": {}})
class Client(object):
@ -142,7 +110,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
self._add_methods('connect_tcp', 'ping')
self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):

View File

@ -0,0 +1,95 @@
#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#
import asyncio
import itertools
import json
from .exceptions import ClientError, ConnectionClosedError
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024
def chunkify(msg, max_chunk):
if len(msg) < max_chunk - 1:
yield "".join((msg, "\n"))
else:
yield "".join((json.dumps({"chunk-stream": None}), "\n"))
args = [iter(msg)] * (max_chunk - 1)
for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
yield "".join(itertools.chain(m, "\n"))
yield "\n"
class StreamConnection(object):
def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
self.reader = reader
self.writer = writer
self.timeout = timeout
self.max_chunk = max_chunk
@property
def address(self):
return self.writer.get_extra_info("peername")
async def send_message(self, msg):
for c in chunkify(json.dumps(msg), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
async def recv_message(self):
l = await self.recv()
m = json.loads(l)
if not m:
return m
if "chunk-stream" in m:
lines = []
while True:
l = await self.recv()
if not l:
break
lines.append(l)
m = json.loads("".join(lines))
return m
async def send(self, msg):
self.writer.write(("%s\n" % msg).encode("utf-8"))
await self.writer.drain()
async def recv(self):
if self.timeout < 0:
line = await self.reader.readline()
else:
try:
line = await asyncio.wait_for(self.reader.readline(), self.timeout)
except asyncio.TimeoutError:
raise ConnectionError("Timed out waiting for data")
if not line:
raise ConnectionClosedError("Connection closed")
line = line.decode("utf-8")
if not line.endswith("\n"):
raise ConnectionError("Bad message %r" % (line))
return line.rstrip()
async def close(self):
self.reader = None
if self.writer is not None:
self.writer.close()
self.writer = None

View File

@ -0,0 +1,17 @@
#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#
class ClientError(Exception):
pass
class ServerError(Exception):
pass
class ConnectionClosedError(Exception):
pass

View File

@ -12,241 +12,248 @@ import signal
import socket
import sys
import multiprocessing
from . import chunkify, DEFAULT_MAX_CHUNK
class ClientError(Exception):
pass
class ServerError(Exception):
pass
from .connection import StreamConnection
from .exceptions import ClientError, ServerError, ConnectionClosedError
class AsyncServerConnection(object):
def __init__(self, reader, writer, proto_name, logger):
self.reader = reader
self.writer = writer
# If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
# return message will be automatically be sent back to the client
NO_RESPONSE = object()
def __init__(self, socket, proto_name, logger):
self.socket = socket
self.proto_name = proto_name
self.max_chunk = DEFAULT_MAX_CHUNK
self.handlers = {
'chunk-stream': self.handle_chunk,
'ping': self.handle_ping,
"ping": self.handle_ping,
}
self.logger = logger
async def close(self):
await self.socket.close()
async def process_requests(self):
try:
self.addr = self.writer.get_extra_info('peername')
self.logger.debug('Client %r connected' % (self.addr,))
self.logger.info("Client %r connected" % (self.socket.address,))
# Read protocol and version
client_protocol = await self.reader.readline()
client_protocol = await self.socket.recv()
if not client_protocol:
return
(client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
(client_proto_name, client_proto_version) = client_protocol.split()
if client_proto_name != self.proto_name:
self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
return
self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
if not self.validate_proto_version():
self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
self.logger.debug(
"Rejecting invalid protocol version %s" % (client_proto_version)
)
return
# Read headers. Currently, no headers are implemented, so look for
# an empty line to signal the end of the headers
while True:
line = await self.reader.readline()
if not line:
return
line = line.decode('utf-8').rstrip()
if not line:
header = await self.socket.recv()
if not header:
break
# Handle messages
while True:
d = await self.read_message()
d = await self.socket.recv_message()
if d is None:
break
await self.dispatch_message(d)
await self.writer.drain()
except ClientError as e:
response = await self.dispatch_message(d)
if response is not self.NO_RESPONSE:
await self.socket.send_message(response)
except ConnectionClosedError as e:
self.logger.info(str(e))
except (ClientError, ConnectionError) as e:
self.logger.error(str(e))
finally:
self.writer.close()
await self.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
self.logger.debug('Handling %s' % k)
await self.handlers[k](msg[k])
return
self.logger.debug("Handling %s" % k)
return await self.handlers[k](msg[k])
raise ClientError("Unrecognized command %r" % msg)
def write_message(self, msg):
for c in chunkify(json.dumps(msg), self.max_chunk):
self.writer.write(c.encode('utf-8'))
async def read_message(self):
l = await self.reader.readline()
if not l:
return None
try:
message = l.decode('utf-8')
if not message.endswith('\n'):
return None
return json.loads(message)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
self.logger.error('Bad message from client: %r' % message)
raise e
async def handle_chunk(self, request):
lines = []
try:
while True:
l = await self.reader.readline()
l = l.rstrip(b"\n").decode("utf-8")
if not l:
break
lines.append(l)
msg = json.loads(''.join(lines))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
self.logger.error('Bad message from client: %r' % lines)
raise e
if 'chunk-stream' in msg:
raise ClientError("Nested chunks are not allowed")
await self.dispatch_message(msg)
async def handle_ping(self, request):
response = {'alive': True}
self.write_message(response)
return {"alive": True}
class StreamServer(object):
def __init__(self, handler, logger):
self.handler = handler
self.logger = logger
self.closed = False
async def handle_stream_client(self, reader, writer):
# writer.transport.set_write_buffer_limits(0)
socket = StreamConnection(reader, writer, -1)
if self.closed:
await socket.close()
return
await self.handler(socket)
async def stop(self):
self.closed = True
class TCPStreamServer(StreamServer):
def __init__(self, host, port, handler, logger):
super().__init__(handler, logger)
self.host = host
self.port = port
def start(self, loop):
self.server = loop.run_until_complete(
asyncio.start_server(self.handle_stream_client, self.host, self.port)
)
for s in self.server.sockets:
self.logger.debug("Listening on %r" % (s.getsockname(),))
# Newer python does this automatically. Do it manually here for
# maximum compatibility
s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
# Enable keep alives. This prevents broken client connections
# from persisting on the server for long periods of time.
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
name = self.server.sockets[0].getsockname()
if self.server.sockets[0].family == socket.AF_INET6:
self.address = "[%s]:%d" % (name[0], name[1])
else:
self.address = "%s:%d" % (name[0], name[1])
return [self.server.wait_closed()]
async def stop(self):
await super().stop()
self.server.close()
def cleanup(self):
pass
class UnixStreamServer(StreamServer):
def __init__(self, path, handler, logger):
super().__init__(handler, logger)
self.path = path
def start(self, loop):
cwd = os.getcwd()
try:
# Work around path length limits in AF_UNIX
os.chdir(os.path.dirname(self.path))
self.server = loop.run_until_complete(
asyncio.start_unix_server(
self.handle_stream_client, os.path.basename(self.path)
)
)
finally:
os.chdir(cwd)
self.logger.debug("Listening on %r" % self.path)
self.address = "unix://%s" % os.path.abspath(self.path)
return [self.server.wait_closed()]
async def stop(self):
await super().stop()
self.server.close()
def cleanup(self):
os.unlink(self.path)
class AsyncServer(object):
def __init__(self, logger):
self._cleanup_socket = None
self.logger = logger
self.start = None
self.address = None
self.loop = None
self.run_tasks = []
def start_tcp_server(self, host, port):
def start_tcp():
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client, host, port)
)
for s in self.server.sockets:
self.logger.debug('Listening on %r' % (s.getsockname(),))
# Newer python does this automatically. Do it manually here for
# maximum compatibility
s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
# Enable keep alives. This prevents broken client connections
# from persisting on the server for long periods of time.
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
name = self.server.sockets[0].getsockname()
if self.server.sockets[0].family == socket.AF_INET6:
self.address = "[%s]:%d" % (name[0], name[1])
else:
self.address = "%s:%d" % (name[0], name[1])
self.start = start_tcp
self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
def start_unix_server(self, path):
def cleanup():
os.unlink(path)
self.server = UnixStreamServer(path, self._client_handler, self.logger)
def start_unix():
cwd = os.getcwd()
try:
# Work around path length limits in AF_UNIX
os.chdir(os.path.dirname(path))
self.server = self.loop.run_until_complete(
asyncio.start_unix_server(self.handle_client, os.path.basename(path))
)
finally:
os.chdir(cwd)
self.logger.debug('Listening on %r' % path)
self._cleanup_socket = cleanup
self.address = "unix://%s" % os.path.abspath(path)
self.start = start_unix
@abc.abstractmethod
def accept_client(self, reader, writer):
pass
async def handle_client(self, reader, writer):
# writer.transport.set_write_buffer_limits(0)
async def _client_handler(self, socket):
try:
client = self.accept_client(reader, writer)
client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
self.logger.error('Error from client: %s' % str(e), exc_info=True)
traceback.print_exc()
writer.close()
self.logger.debug('Client disconnected')
def run_loop_forever(self):
try:
self.loop.run_forever()
except KeyboardInterrupt:
pass
self.logger.error("Error from client: %s" % str(e), exc_info=True)
traceback.print_exc()
await socket.close()
self.logger.debug("Client disconnected")
@abc.abstractmethod
def accept_client(self, socket):
pass
async def stop(self):
self.logger.debug("Stopping server")
await self.server.stop()
def start(self):
tasks = self.server.start(self.loop)
self.address = self.server.address
return tasks
def signal_handler(self):
self.logger.debug("Got exit signal")
self.loop.stop()
self.loop.create_task(self.stop())
def _serve_forever(self):
def _serve_forever(self, tasks):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
self.run_loop_forever()
self.server.close()
self.loop.run_until_complete(asyncio.gather(*tasks))
self.loop.run_until_complete(self.server.wait_closed())
self.logger.debug('Server shutting down')
self.logger.debug("Server shutting down")
finally:
if self._cleanup_socket is not None:
self._cleanup_socket()
self.server.cleanup()
def serve_forever(self):
"""
Serve requests in the current process
"""
self._create_loop()
tasks = self.start()
self._serve_forever(tasks)
self.loop.close()
def _create_loop(self):
# Create loop and override any loop that may have existed in
# a parent process. It is possible that the usecases of
# serve_forever might be constrained enough to allow using
# get_event_loop here, but better safe than sorry for now.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.start()
self._serve_forever()
def serve_as_process(self, *, prefunc=None, args=()):
"""
Serve requests in a child process
"""
def run(queue):
# Create loop and override any loop that may have existed
# in a parent process. Without doing this and instead
@ -259,18 +266,19 @@ class AsyncServer(object):
# more general, though, as any potential use of asyncio in
# Cooker could create a loop that needs to replaced in this
# new process.
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self._create_loop()
try:
self.start()
self.address = None
tasks = self.start()
finally:
# Always put the server address to wake up the parent task
queue.put(self.address)
queue.close()
if prefunc is not None:
prefunc(self, *args)
self._serve_forever()
self._serve_forever(tasks)
if sys.version_info >= (3, 6):
self.loop.run_until_complete(self.loop.shutdown_asyncgens())

View File

@ -15,13 +15,6 @@ UNIX_PREFIX = "unix://"
ADDR_TYPE_UNIX = 0
ADDR_TYPE_TCP = 1
# The Python async server defaults to a 64K receive buffer, so we hardcode our
# maximum chunk size. It would be better if the client and server reported to
# each other what the maximum chunk sizes were, but that will slow down the
# connection setup with a round trip delay so I'd rather not do that unless it
# is necessary
DEFAULT_MAX_CHUNK = 32 * 1024
UNIHASH_TABLE_DEFINITION = (
("method", "TEXT NOT NULL", "UNIQUE"),
("taskhash", "TEXT NOT NULL", "UNIQUE"),
@ -102,20 +95,6 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
def chunkify(msg, max_chunk):
if len(msg) < max_chunk - 1:
yield ''.join((msg, "\n"))
else:
yield ''.join((json.dumps({
'chunk-stream': None
}), "\n"))
args = [iter(msg)] * (max_chunk - 1)
for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
yield ''.join(itertools.chain(m, "\n"))
yield "\n"
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
from . import server
db = setup_database(dbname, sync=sync)

View File

@ -28,24 +28,24 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
async def send_stream(self, msg):
async def proc():
self.writer.write(("%s\n" % msg).encode("utf-8"))
await self.writer.drain()
l = await self.reader.readline()
if not l:
raise ConnectionError("Connection closed")
return l.decode("utf-8").rstrip()
await self.socket.send(msg)
return await self.socket.recv()
return await self._send_wrapper(proc)
async def _set_mode(self, new_mode):
async def stream_to_normal():
await self.socket.send("END")
return await self.socket.recv()
if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
r = await self.send_stream("END")
r = await self._send_wrapper(stream_to_normal)
if r != "ok":
raise ConnectionError("Bad response from server %r" % r)
raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
r = await self.send_message({"get-stream": None})
r = await self.invoke({"get-stream": None})
if r != "ok":
raise ConnectionError("Bad response from server %r" % r)
raise ConnectionError("Unable to transition to stream mode: Bad response from server %r" % r)
elif new_mode != self.mode:
raise Exception(
"Undefined mode transition %r -> %r" % (self.mode, new_mode)
@ -67,7 +67,7 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["method"] = method
m["outhash"] = outhash
m["unihash"] = unihash
return await self.send_message({"report": m})
return await self.invoke({"report": m})
async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
await self._set_mode(self.MODE_NORMAL)
@ -75,39 +75,39 @@ class AsyncClient(bb.asyncrpc.AsyncClient):
m["taskhash"] = taskhash
m["method"] = method
m["unihash"] = unihash
return await self.send_message({"report-equiv": m})
return await self.invoke({"report-equiv": m})
async def get_taskhash(self, method, taskhash, all_properties=False):
await self._set_mode(self.MODE_NORMAL)
return await self.send_message(
return await self.invoke(
{"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
)
async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
await self._set_mode(self.MODE_NORMAL)
return await self.send_message(
return await self.invoke(
{"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
)
async def get_stats(self):
await self._set_mode(self.MODE_NORMAL)
return await self.send_message({"get-stats": None})
return await self.invoke({"get-stats": None})
async def reset_stats(self):
await self._set_mode(self.MODE_NORMAL)
return await self.send_message({"reset-stats": None})
return await self.invoke({"reset-stats": None})
async def backfill_wait(self):
await self._set_mode(self.MODE_NORMAL)
return (await self.send_message({"backfill-wait": None}))["tasks"]
return (await self.invoke({"backfill-wait": None}))["tasks"]
async def remove(self, where):
await self._set_mode(self.MODE_NORMAL)
return await self.send_message({"remove": {"where": where}})
return await self.invoke({"remove": {"where": where}})
async def clean_unused(self, max_age):
await self._set_mode(self.MODE_NORMAL)
return await self.send_message({"clean-unused": {"max_age_seconds": max_age}})
return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
class Client(bb.asyncrpc.Client):

View File

@ -165,8 +165,8 @@ class ServerCursor(object):
class ServerClient(bb.asyncrpc.AsyncServerConnection):
def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
super().__init__(reader, writer, 'OEHASHEQUIV', logger)
def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
super().__init__(socket, 'OEHASHEQUIV', logger)
self.db = db
self.request_stats = request_stats
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
@ -209,12 +209,11 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
if k in msg:
logger.debug('Handling %s' % k)
if 'stream' in k:
await self.handlers[k](msg[k])
return await self.handlers[k](msg[k])
else:
with self.request_stats.start_sample() as self.request_sample, \
self.request_sample.measure():
await self.handlers[k](msg[k])
return
return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
@ -224,9 +223,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
fetch_all = request.get('all', False)
with closing(self.db.cursor()) as cursor:
d = await self.get_unihash(cursor, method, taskhash, fetch_all)
self.write_message(d)
return await self.get_unihash(cursor, method, taskhash, fetch_all)
async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
d = None
@ -274,9 +271,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
with_unihash = request.get("with_unihash", True)
with closing(self.db.cursor()) as cursor:
d = await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
self.write_message(d)
return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
d = None
@ -334,14 +329,14 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
async def handle_get_stream(self, request):
self.write_message('ok')
await self.socket.send_message("ok")
while True:
upstream = None
l = await self.reader.readline()
l = await self.socket.recv()
if not l:
return
break
try:
# This inner loop is very sensitive and must be as fast as
@ -352,10 +347,8 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
request_measure = self.request_sample.measure()
request_measure.start()
l = l.decode('utf-8').rstrip()
if l == 'END':
self.writer.write('ok\n'.encode('utf-8'))
return
break
(method, taskhash) = l.split()
#logger.debug('Looking up %s %s' % (method, taskhash))
@ -366,29 +359,30 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
cursor.close()
if row is not None:
msg = ('%s\n' % row['unihash']).encode('utf-8')
msg = row['unihash']
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
elif self.upstream_client is not None:
upstream = await self.upstream_client.get_unihash(method, taskhash)
if upstream:
msg = ("%s\n" % upstream).encode("utf-8")
msg = upstream
else:
msg = "\n".encode("utf-8")
msg = ""
else:
msg = '\n'.encode('utf-8')
msg = ""
self.writer.write(msg)
await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
await self.writer.drain()
# Post to the backfill queue after writing the result to minimize
# the turn around time on a request
if upstream is not None:
await self.backfill_queue.put((method, taskhash))
await self.socket.send("ok")
return self.NO_RESPONSE
async def handle_report(self, data):
with closing(self.db.cursor()) as cursor:
outhash_data = {
@ -468,7 +462,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
'unihash': unihash,
}
self.write_message(d)
return d
async def handle_equivreport(self, data):
with closing(self.db.cursor()) as cursor:
@ -491,30 +485,28 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
self.write_message(d)
return d
async def handle_get_stats(self, request):
d = {
return {
'requests': self.request_stats.todict(),
}
self.write_message(d)
async def handle_reset_stats(self, request):
d = {
'requests': self.request_stats.todict(),
}
self.request_stats.reset()
self.write_message(d)
return d
async def handle_backfill_wait(self, request):
d = {
'tasks': self.backfill_queue.qsize(),
}
await self.backfill_queue.join()
self.write_message(d)
return d
async def handle_remove(self, request):
condition = request["where"]
@ -541,7 +533,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
self.db.commit()
self.write_message({"count": count})
return {"count": count}
async def handle_clean_unused(self, request):
max_age = request["max_age_seconds"]
@ -558,7 +550,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
)
count = cursor.rowcount
self.write_message({"count": count})
return {"count": count}
def query_equivalent(self, cursor, method, taskhash):
# This is part of the inner loop and must be as fast as possible
@ -583,41 +575,33 @@ class Server(bb.asyncrpc.AsyncServer):
self.db = db
self.upstream = upstream
self.read_only = read_only
self.backfill_queue = None
def accept_client(self, reader, writer):
return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
def accept_client(self, socket):
return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
@contextmanager
def _backfill_worker(self):
async def backfill_worker_task():
client = await create_async_client(self.upstream)
try:
while True:
item = await self.backfill_queue.get()
if item is None:
self.backfill_queue.task_done()
break
method, taskhash = item
await copy_unihash_from_upstream(client, self.db, method, taskhash)
async def backfill_worker_task(self):
client = await create_async_client(self.upstream)
try:
while True:
item = await self.backfill_queue.get()
if item is None:
self.backfill_queue.task_done()
finally:
await client.close()
break
method, taskhash = item
await copy_unihash_from_upstream(client, self.db, method, taskhash)
self.backfill_queue.task_done()
finally:
await client.close()
async def join_worker(worker):
def start(self):
tasks = super().start()
if self.upstream:
self.backfill_queue = asyncio.Queue()
tasks += [self.backfill_worker_task()]
return tasks
async def stop(self):
if self.backfill_queue is not None:
await self.backfill_queue.put(None)
await worker
if self.upstream is not None:
worker = asyncio.ensure_future(backfill_worker_task())
try:
yield
finally:
self.loop.run_until_complete(join_worker(worker))
else:
yield
def run_loop_forever(self):
self.backfill_queue = asyncio.Queue()
with self._backfill_worker():
super().run_loop_forever()
await super().stop()

View File

@ -14,28 +14,28 @@ class PRAsyncClient(bb.asyncrpc.AsyncClient):
super().__init__('PRSERVICE', '1.0', logger)
async def getPR(self, version, pkgarch, checksum):
response = await self.send_message(
response = await self.invoke(
{'get-pr': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum}}
)
if response:
return response['value']
async def importone(self, version, pkgarch, checksum, value):
response = await self.send_message(
response = await self.invoke(
{'import-one': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'value': value}}
)
if response:
return response['value']
async def export(self, version, pkgarch, checksum, colinfo):
response = await self.send_message(
response = await self.invoke(
{'export': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'colinfo': colinfo}}
)
if response:
return (response['metainfo'], response['datainfo'])
async def is_readonly(self):
response = await self.send_message(
response = await self.invoke(
{'is-readonly': {}}
)
if response:

View File

@ -20,8 +20,8 @@ PIDPREFIX = "/tmp/PRServer_%s_%s.pid"
singleton = None
class PRServerClient(bb.asyncrpc.AsyncServerConnection):
def __init__(self, reader, writer, table, read_only):
super().__init__(reader, writer, 'PRSERVICE', logger)
def __init__(self, socket, table, read_only):
super().__init__(socket, 'PRSERVICE', logger)
self.handlers.update({
'get-pr': self.handle_get_pr,
'import-one': self.handle_import_one,
@ -36,12 +36,12 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
async def dispatch_message(self, msg):
try:
await super().dispatch_message(msg)
return await super().dispatch_message(msg)
except:
self.table.sync()
raise
self.table.sync_if_dirty()
else:
self.table.sync_if_dirty()
async def handle_get_pr(self, request):
version = request['version']
@ -57,7 +57,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
except sqlite3.Error as exc:
logger.error(str(exc))
self.write_message(response)
return response
async def handle_import_one(self, request):
response = None
@ -71,7 +71,7 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
if value is not None:
response = {'value': value}
self.write_message(response)
return response
async def handle_export(self, request):
version = request['version']
@ -85,12 +85,10 @@ class PRServerClient(bb.asyncrpc.AsyncServerConnection):
logger.error(str(exc))
metainfo = datainfo = None
response = {'metainfo': metainfo, 'datainfo': datainfo}
self.write_message(response)
return {'metainfo': metainfo, 'datainfo': datainfo}
async def handle_is_readonly(self, request):
response = {'readonly': self.read_only}
self.write_message(response)
return {'readonly': self.read_only}
class PRServer(bb.asyncrpc.AsyncServer):
def __init__(self, dbfile, read_only=False):
@ -99,20 +97,23 @@ class PRServer(bb.asyncrpc.AsyncServer):
self.table = None
self.read_only = read_only
def accept_client(self, reader, writer):
return PRServerClient(reader, writer, self.table, self.read_only)
def accept_client(self, socket):
return PRServerClient(socket, self.table, self.read_only)
def _serve_forever(self):
def start(self):
tasks = super().start()
self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only)
self.table = self.db["PRMAIN"]
logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" %
(self.dbfile, self.address, str(os.getpid())))
super()._serve_forever()
return tasks
async def stop(self):
self.table.sync_if_dirty()
self.db.disconnect()
await super().stop()
def signal_handler(self):
super().signal_handler()