bitbake: hashserv: Add SQLAlchemy backend

Adds an SQLAlchemy backend to the server. While this database backend is
slower than the more direct sqlite backend, it easily supports just
about any SQL server, which is useful for large scale deployments.

(Bitbake rev: e0b73466dd7478c77c82f46879246c1b68b228c0)

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
Signed-off-by: Richard Purdie <richard.purdie@linuxfoundation.org>
This commit is contained in:
Joshua Watt 2023-11-03 08:26:26 -06:00 committed by Richard Purdie
parent baa3e5391d
commit cfbb1d2cc0
5 changed files with 362 additions and 5 deletions

View File

@ -69,6 +69,16 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
action="store_true",
help="Disallow write operations from clients ($HASHSERVER_READ_ONLY)",
)
parser.add_argument(
"--db-username",
default=os.environ.get("HASHSERVER_DB_USERNAME", None),
help="Database username ($HASHSERVER_DB_USERNAME)",
)
parser.add_argument(
"--db-password",
default=os.environ.get("HASHSERVER_DB_PASSWORD", None),
help="Database password ($HASHSERVER_DB_PASSWORD)",
)
args = parser.parse_args()
@ -90,6 +100,8 @@ To bind to all addresses, leave the ADDRESS empty, e.g. "--bind :8686" or
args.database,
upstream=args.upstream,
read_only=read_only,
db_username=args.db_username,
db_password=args.db_password,
)
server.serve_forever()
return 0

View File

@ -7,6 +7,7 @@
import asyncio
import itertools
import json
from datetime import datetime
from .exceptions import ClientError, ConnectionClosedError
@ -30,6 +31,12 @@ def chunkify(msg, max_chunk):
yield "\n"
def json_serialize(obj):
    """JSON ``default`` hook used by the connection classes.

    Encodes datetime objects as ISO-8601 strings; any other
    non-serializable type raises TypeError (as json.dumps expects).
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    # Fixed typo in the error message ("serializeable" -> "serializable")
    raise TypeError("Type %s not serializable" % type(obj))
class StreamConnection(object):
def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
self.reader = reader
@ -42,7 +49,7 @@ class StreamConnection(object):
return self.writer.get_extra_info("peername")
async def send_message(self, msg):
for c in chunkify(json.dumps(msg), self.max_chunk):
for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
self.writer.write(c.encode("utf-8"))
await self.writer.drain()
@ -105,7 +112,7 @@ class WebsocketConnection(object):
return ":".join(str(s) for s in self.socket.remote_address)
async def send_message(self, msg):
await self.send(json.dumps(msg))
await self.send(json.dumps(msg, default=json_serialize))
async def recv_message(self):
m = await self.recv()

View File

@ -35,15 +35,32 @@ def parse_address(addr):
return (ADDR_TYPE_TCP, (host, int(port)))
def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
def create_server(
addr,
dbname,
*,
sync=True,
upstream=None,
read_only=False,
db_username=None,
db_password=None
):
def sqlite_engine():
from .sqlite import DatabaseEngine
return DatabaseEngine(dbname, sync)
def sqlalchemy_engine():
from .sqlalchemy import DatabaseEngine
return DatabaseEngine(dbname, db_username, db_password)
from . import server
db_engine = sqlite_engine()
if "://" in dbname:
db_engine = sqlalchemy_engine()
else:
db_engine = sqlite_engine()
s = server.Server(db_engine, upstream=upstream, read_only=read_only)

View File

@ -0,0 +1,304 @@
#! /usr/bin/env python3
#
# Copyright (C) 2023 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
from sqlalchemy import (
MetaData,
Column,
Table,
Text,
Integer,
UniqueConstraint,
DateTime,
Index,
select,
insert,
exists,
literal,
and_,
delete,
)
import sqlalchemy.engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.exc import IntegrityError
logger = logging.getLogger("hashserv.sqlalchemy")
Base = declarative_base()
class UnihashesV2(Base):
    """ORM model mapping a (method, taskhash) pair to its unified hash."""

    __tablename__ = "unihashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        # A given (method, taskhash) pair maps to exactly one unihash
        UniqueConstraint("method", "taskhash"),
        # Queries look rows up by (method, taskhash); index that pair
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
class OuthashesV2(Base):
    """ORM model recording an output hash report for a task.

    Besides the (method, taskhash, outhash) triple, rows carry optional
    reporting metadata (creation time, owner, and task identification
    fields — presumably bitbake recipe name/version/revision, per the
    PN/PV/PR naming; confirm against the reporting client).
    """

    __tablename__ = "outhashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    created = Column(DateTime)  # when the report was made; nullable
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        # The same outhash may be reported only once per (method, taskhash)
        UniqueConstraint("method", "taskhash", "outhash"),
        # Equivalence queries look up rows by (method, outhash)
        Index("outhash_lookup_v3", "method", "outhash"),
    )
class DatabaseEngine(object):
    """Async database engine for any SQLAlchemy-supported server.

    The connection URL may have its username and password overridden via
    the constructor arguments, so credentials can be supplied separately
    from the URL (e.g. from environment variables).
    """

    def __init__(self, url, username=None, password=None):
        self.logger = logger
        self.url = sqlalchemy.engine.make_url(url)

        # Explicitly supplied credentials take precedence over any
        # embedded in the URL string
        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the async engine and ensure all tables exist."""
        # NOTE(review): the URL may embed credentials; consider redacting
        # the password before logging it
        self.logger.info("Using database %s", self.url)
        # NullPool opens a fresh connection per checkout instead of pooling
        self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            # Use the instance logger, consistent with the line above
            # (was the bare module-level "logger")
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

    def connect(self, logger):
        """Return a new Database connection wrapper that logs to *logger*."""
        return Database(self.engine, logger)
def map_row(row):
    """Convert a SQLAlchemy result row into a plain dict; None passes through."""
    return None if row is None else dict(**row._mapping)
class Database(object):
    """A single async connection to the hash equivalence database.

    Used as an async context manager: the underlying connection is
    opened on ``__aenter__`` and closed on ``__aexit__``.
    """

    def __init__(self, engine, logger):
        self.engine = engine
        self.db = None  # live connection; set while the context manager is active
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the underlying connection and drop the reference."""
        await self.db.close()
        self.db = None

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row (with its unihash) for (method, taskhash).

        Returns a dict of the row, or None if no match exists.
        """
        statement = (
            select(
                OuthashesV2,
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.taskhash == taskhash,
            )
            .order_by(
                # Oldest report wins
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest row matching (method, outhash) joined with its unihash.

        Returns a dict of the row, or None if no match exists.
        """
        statement = (
            select(OuthashesV2, UnihashesV2.unihash.label("unihash"))
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash), or None."""
        statement = (
            select(OuthashesV2)
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Return a (taskhash, unihash) dict for another task with the same outhash.

        Excludes rows whose taskhash equals *taskhash*; returns None if no
        other task reported this outhash.
        """
        statement = (
            select(
                OuthashesV2.taskhash.label("taskhash"),
                UnihashesV2.unihash.label("unihash"),
            )
            .join(
                UnihashesV2,
                and_(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                ),
            )
            .where(
                OuthashesV2.method == method,
                OuthashesV2.outhash == outhash,
                # Exclude the task asking the question
                OuthashesV2.taskhash != taskhash,
            )
            .order_by(
                OuthashesV2.created.asc(),
            )
            .limit(1)
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return the (unihash, method, taskhash) dict for a taskhash, or None."""
        statement = select(
            UnihashesV2.unihash,
            UnihashesV2.method,
            UnihashesV2.taskhash,
        ).where(
            UnihashesV2.method == method,
            UnihashesV2.taskhash == taskhash,
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return map_row(result.first())

    async def remove(self, condition):
        """Delete rows matching *condition* from both tables.

        *condition* is a dict of column-name -> value; only keys that are
        actual columns (and non-None) participate in the WHERE clause.
        Returns the total number of deleted rows.
        """

        async def do_remove(table):
            # Build the WHERE clause from the columns this table actually has
            where = {}
            for c in table.__table__.columns:
                if c.key in condition and condition[c.key] is not None:
                    where[c] = condition[c.key]

            if where:
                statement = delete(table).where(*[(k == v) for k, v in where.items()])
                self.logger.debug("%s", statement)
                async with self.db.begin():
                    result = await self.db.execute(statement)
                    return result.rowcount

            return 0

        count = 0
        count += await do_remove(UnihashesV2)
        count += await do_remove(OuthashesV2)

        return count

    async def clean_unused(self, oldest):
        """Delete outhash rows older than *oldest* that have no matching unihash.

        Returns the number of deleted rows.
        """
        statement = delete(OuthashesV2).where(
            OuthashesV2.created < oldest,
            # Keep rows that still have a corresponding unihash entry
            ~(
                select(UnihashesV2.id)
                .where(
                    UnihashesV2.method == OuthashesV2.method,
                    UnihashesV2.taskhash == OuthashesV2.taskhash,
                )
                .limit(1)
                .exists()
            ),
        )
        self.logger.debug("%s", statement)
        async with self.db.begin():
            result = await self.db.execute(statement)
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row; return True on success, False if it already exists."""
        statement = insert(UnihashesV2).values(
            method=method,
            taskhash=taskhash,
            unihash=unihash,
        )
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            # Duplicate (method, taskhash) — not an error, just report failure.
            # Use self.logger for consistency with the rest of the class
            # (was the bare module-level "logger").
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row from a dict of values.

        Unknown keys are dropped; a string "created" value is parsed from
        ISO format. Returns True on success, False if the row already exists.
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        # Only keep keys that are actual table columns
        data = {k: v for k, v in data.items() if k in outhash_columns}

        if "created" in data and not isinstance(data["created"], datetime):
            data["created"] = datetime.fromisoformat(data["created"])

        statement = insert(OuthashesV2).values(**data)
        self.logger.debug("%s", statement)
        try:
            async with self.db.begin():
                await self.db.execute(statement)
            return True
        except IntegrityError:
            # Use self.logger for consistency with the rest of the class
            self.logger.debug(
                "%s, %s already in outhash database", data["method"], data["outhash"]
            )
            return False

View File

@ -33,7 +33,7 @@ class HashEquivalenceTestSetup(object):
def start_server(self, dbpath=None, upstream=None, read_only=False, prefunc=server_prefunc):
self.server_index += 1
if dbpath is None:
dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
dbpath = self.make_dbpath()
def cleanup_server(server):
if server.process.exitcode is not None:
@ -53,6 +53,9 @@ class HashEquivalenceTestSetup(object):
return server
def make_dbpath(self):
    """Build the per-server sqlite database path for the current server index."""
    filename = "db%d.sqlite" % self.server_index
    return os.path.join(self.temp_dir.name, filename)
def start_client(self, server_address):
def cleanup_client(client):
client.close()
@ -517,6 +520,20 @@ class TestHashEquivalenceWebsocketServer(HashEquivalenceTestSetup, HashEquivalen
return "ws://%s:0" % host
class TestHashEquivalenceWebsocketsSQLAlchemyServer(TestHashEquivalenceWebsocketServer):
    """Runs the websocket server test suite against the SQLAlchemy backend."""

    def setUp(self):
        # The SQLAlchemy backend is optional; skip the whole suite when
        # its dependencies are not installed
        try:
            import sqlalchemy
            import aiosqlite
        except ImportError as e:
            self.skipTest(str(e))

        super().setUp()

    def make_dbpath(self):
        # A "://" URL makes create_server pick the SQLAlchemy engine
        dbfile = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index)
        return "sqlite+aiosqlite:///%s" % dbfile
class TestHashEquivalenceExternalServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase):
def start_test_server(self):
if 'BB_TEST_HASHSERV' not in os.environ: