sedi-relay/relay/database.py

341 lines
7.5 KiB
Python

import json
from datetime import datetime
from urllib.parse import urlparse

import tinysql
from tinysql import Column, Table

from .logger import set_level
from .misc import AppBase, DotDict, boolean
# Table definitions handed to `create_database` by `Database.create`.
TABLES = [
    # Key/value settings. Values are stored as text and converted back to
    # their declared type by `ConfigRow.get_value`.
    Table('config',
        Column('key', 'text', unique=True, nullable=False, primary_key=True),
        Column('value', 'text')
    ),
    # One row per remote instance, whether it has joined or is still waiting
    # for approval.
    Table('instances',
        Column('id', 'serial'),
        Column('domain', 'text', unique=True, nullable=False),
        Column('actor', 'text'),
        Column('inbox', 'text', nullable=False),
        Column('followid', 'text'),
        Column('software', 'text'),
        Column('actor_data', 'json'),
        Column('note', 'text'),
        # Must be nullable: a pending follow request is stored with no
        # `joined` value (`put_instance(accept=False)` omits it) and the
        # request queries select on `joined IS NULL`. `accept_request` fills
        # it in later.
        Column('joined', 'datetime'),
        Column('updated', 'datetime')
    ),
    # Domains managed through the `*_whitelist*` methods.
    Table('whitelist',
        Column('id', 'serial'),
        Column('domain', 'text', unique=True),
        Column('created', 'datetime', nullable=False)
    ),
    # Bans; `type` is validated as "software" or "domain" by the ban methods.
    Table('bans',
        Column('id', 'serial'),
        Column('name', 'text', unique=True),
        Column('note', 'text'),
        Column('type', 'text', nullable=False),
        Column('created', 'datetime', nullable=False)
    ),
    Table('users',
        Column('id', 'serial'),
        Column('handle', 'text', unique=True, nullable=False),
        Column('domain', 'text', nullable=False),
        Column('api_token', 'text'),
        Column('created', 'datetime', nullable=False),
        Column('updated', 'datetime')
    ),
    Table('tokens',
        Column('id', 'text', unique=True, nullable=False, primary_key=True),
        Column('userid', 'integer', nullable=False),
        Column('created', 'datetime', nullable=False),
        Column('updated', 'datetime')
    ),
]
# Known config keys. Each entry maps the key to a `(type name, default value)`
# pair: the type name drives the conversion in `ConfigRow.get_value`, and the
# default is what `get_config` / `get_config_all` return when no row is stored.
DEFAULT_CONFIG = {
    'description': ('str', 'Make a note about your relay here'),
    'http_timeout': ('int', 10),
    'json_cache': ('int', 1024),
    'log_level': ('str', 'INFO'),
    'name': ('str', 'ActivityRelay'),
    'privkey': ('str', ''),
    'push_limit': ('int', 512),
    'require_approval': ('bool', False),
    'version': ('int', 20221211),
    'whitelist': ('bool', False),
    'workers': ('int', 8),
}
# Software names of relay implementations, each with its project URL.
RELAY_SOFTWARE = [
    'activityrelay',    # https://git.pleroma.social/pleroma/relay
    'aoderelay',        # https://git.asonix.dog/asonix/relay
    'feditools-relay',  # https://git.ptzo.gdn/feditools/relay
]
class Database(AppBase, tinysql.Database):
    """tinysql database wired up with the relay's connection and row classes."""

    def __init__(self, **config):
        """Initialise the underlying `tinysql.Database`.

        Every keyword argument is forwarded unchanged; `connection_class` and
        `row_classes` are always supplied by this class.
        """
        relay_classes = {
            'connection_class': Connection,
            'row_classes': [ConfigRow],
        }
        tinysql.Database.__init__(self, **config, **relay_classes)

    def create(self):
        """Create the tables described by `TABLES`."""
        self.create_database(TABLES)
class Connection(tinysql.ConnectionMixin):
    """Relay-specific queries layered on top of a tinysql connection.

    An instance row whose `joined` column is NULL is treated as a pending
    follow request; rows with a `joined` timestamp are accepted instances.
    """

    @staticmethod
    def _check_ban_type(ban_type):
        """Raise ValueError unless *ban_type* is "software" or "domain"."""
        if ban_type not in {'software', 'domain'}:
            raise ValueError('Ban type must be "software" or "domain"')

    ## Misc methods

    def accept_request(self, domain):
        """Accept the pending follow request from *domain*.

        Sets the row's `joined` column to the current time.

        Raises:
            KeyError: no pending request exists for *domain* (raised by
                `get_request`).
        """
        row = self.get_request(domain)
        self.update('instances', {'joined': datetime.now()}, id=row.id)

    def distill_inboxes(self, message):
        """Yield the inbox of every joined instance that should get *message*.

        Instances are skipped when their domain equals `message.domain` or the
        host part of `message.objectid`.
        """
        src_domains = {
            message.domain,
            urlparse(message.objectid).netloc,
        }

        for instance in self.get_instances():
            if instance.domain not in src_domains:
                yield instance.inbox

    ## Delete methods

    def delete_ban(self, type, name):
        """Delete the ban of the given *type* and *name*.

        Raises:
            ValueError: *type* is not "software" or "domain".
            KeyError: no such ban exists.
        """
        row = self.get_ban(type, name)

        if not row:
            raise KeyError(name)

        self.delete('bans', id=row.id)

    def delete_instance(self, domain):
        """Delete the joined instance matching *domain*.

        Raises:
            KeyError: no joined instance matches *domain*.
        """
        row = self.get_instance(domain)

        if not row:
            # Previously referenced an undefined `name`, which turned the
            # intended KeyError into a NameError.
            raise KeyError(domain)

        self.delete('instances', id=row.id)

    def delete_whitelist(self, domain):
        """Remove *domain* from the whitelist.

        Raises:
            KeyError: *domain* is not whitelisted.
        """
        row = self.get_whitelist_domain(domain)

        if not row:
            raise KeyError(domain)

        self.delete('whitelist', id=row.id)

    ## Get methods

    def get_ban(self, type, name):
        """Return the ban row for *name* of the given *type*.

        Raises:
            ValueError: *type* is not "software" or "domain".
        """
        self._check_ban_type(type)
        return self.select('bans', name=name, type=type).one()

    def get_bans(self, type):
        """Return all ban rows of the given *type*.

        Raises:
            ValueError: *type* is not "software" or "domain".
        """
        self._check_ban_type(type)
        return self.select('bans', type=type).all()

    def get_config(self, key):
        """Return the value of config option *key*.

        Falls back to the default from `DEFAULT_CONFIG` when no row is stored.

        Raises:
            KeyError: *key* is not a known config option.
        """
        if key not in DEFAULT_CONFIG:
            raise KeyError(key)

        row = self.select('config', key=key).one()

        if not row:
            return DEFAULT_CONFIG[key][1]

        return row.get_value()

    def get_config_all(self):
        """Return a `DotDict` of every config option.

        Stored values come first; any key from `DEFAULT_CONFIG` that has no
        stored row is filled in with its default.
        """
        rows = self.select('config').all()
        config = DotDict({row.key: row.get_value() for row in rows})

        for key, (_, default) in DEFAULT_CONFIG.items():
            if key not in config:
                config[key] = default

        return config

    def get_hostnames(self):
        """Return a tuple with the domain of every joined instance."""
        return tuple(row.domain for row in self.get_instances())

    def get_instance(self, data):
        """Return the joined instance whose domain, actor or inbox is *data*.

        If *data* starts with "http" and contains a "#", everything from the
        "#" onward is dropped before matching. Returns None when nothing
        matches or the matching row has no `joined` value.
        """
        if data.startswith('http') and '#' in data:
            data = data.split('#', 1)[0]

        query = 'SELECT * FROM instances WHERE domain = :data OR actor = :data OR inbox = :data'
        row = self.execute(query, dict(data=data), table='instances').one()
        return row if row and row.joined else None

    def get_instances(self):
        """Return every instance row that has a `joined` value."""
        query = 'SELECT * FROM instances WHERE joined IS NOT NULL'
        return self.execute(query, table='instances').all()

    def get_request(self, domain):
        """Return the pending follow request from *domain*.

        Raises:
            KeyError: no pending request exists for *domain*.
        """
        for instance in self.get_requests():
            if instance.domain == domain:
                return instance

        raise KeyError(domain)

    def get_requests(self):
        """Return every instance row whose `joined` column is NULL."""
        query = 'SELECT * FROM instances WHERE joined IS NULL'
        return self.execute(query, table='instances').all()

    def get_whitelist(self):
        """Return all whitelist rows."""
        return self.select('whitelist').all()

    def get_whitelist_domain(self, domain):
        """Return the whitelist row for *domain*."""
        return self.select('whitelist', domain=domain).one()

    ## Put methods

    def put_ban(self, type, name, note=None):
        """Create a ban, or update the note of an existing one.

        Raises:
            ValueError: *type* is not "software" or "domain".
            KeyError: the ban already exists and no *note* was given, so
                there is nothing to change.
        """
        self._check_ban_type(type)

        row = self.select('bans', name=name, type=type).one()

        if row:
            if note is None:
                raise KeyError(name)

            self.update('bans', {'note': note}, id=row.id)
            return

        self.insert('bans', {
            'name': name,
            'type': type,
            'note': note,
            'created': datetime.now(),
        })

    def put_config(self, key, value='__DEFAULT__'):
        """Store *value* for config option *key*.

        Passing the string "__DEFAULT__" (the default) stores the option's
        default from `DEFAULT_CONFIG`. Setting `log_level` also applies the
        level through `set_level` before the row is written.

        Raises:
            KeyError: *key* is not a known config option.
        """
        if key not in DEFAULT_CONFIG:
            raise KeyError(key)

        if value == '__DEFAULT__':
            value = DEFAULT_CONFIG[key][1]

        if key == 'log_level':
            set_level(value)

        row = self.select('config', key=key).one()

        if row:
            self.update('config', {'value': value}, key=key)
            return

        self.insert('config', {
            'key': key,
            'value': value,
        })

    def put_instance(self, domain, actor=None, inbox=None, followid=None, software=None, actor_data=None, note=None, accept=True):
        """Insert a new instance row, or update the joined one for *domain*.

        Only arguments that are not None are written. When a joined instance
        already exists it is updated in place and returned. Otherwise a new
        row is inserted; it gets a `joined` timestamp only when *accept* is
        true.

        Raises:
            KeyError: the instance exists but no new data was given.
            ValueError: a new instance is being created without an *inbox*.
        """
        new_data = {
            'actor': actor,
            'inbox': inbox,
            'followid': followid,
            'software': software,
            'note': note,
        }

        if actor_data:
            new_data['actor_data'] = dict(actor_data)

        new_data = {key: value for key, value in new_data.items() if value is not None}

        # NOTE(review): get_instance only returns joined rows, so a pending
        # request for the same domain is not seen here and the insert below
        # would hit the unique `domain` column - confirm intended handling.
        instance = self.get_instance(domain)

        if instance:
            if not new_data:
                raise KeyError(domain)

            instance.update(new_data)
            self.update('instances', new_data, id=instance.id)
            return instance

        if not inbox:
            raise ValueError('Inbox must be included in instance data')

        if accept:
            new_data['joined'] = datetime.now()

        new_data['domain'] = domain

        self.insert('instances', new_data)

        # NOTE(review): with accept=False the new row has no `joined` value,
        # so this lookup returns None - verify callers expect that.
        return self.get_instance(domain)

    def put_instance_actor(self, actor, nodeinfo=None, accept=True):
        """Call `put_instance` with fields taken from an *actor* object.

        Uses `actor.domain`, `actor.id` and `actor.shared_inbox`, passes the
        actor itself as `actor_data`, and takes the software name from
        `nodeinfo.sw_name` when *nodeinfo* is given.
        """
        data = {
            'domain': actor.domain,
            'actor': actor.id,
            'inbox': actor.shared_inbox,
            'actor_data': actor,
            'accept': accept,
            'software': nodeinfo.sw_name if nodeinfo else None,
        }

        return self.put_instance(**data)

    def put_whitelist(self, domain):
        """Add *domain* to the whitelist.

        Raises:
            KeyError: *domain* is already whitelisted.
        """
        if self.get_whitelist_domain(domain):
            raise KeyError(domain)

        self.insert('whitelist', {
            'domain': domain,
            'created': datetime.now(),
        })
class ConfigRow(tinysql.Row):
    """Row of the `config` table that can decode its stored value."""

    __table__ = 'config'

    def get_value(self):
        """Return `self.value` converted to the type declared for `self.key`.

        The type name is the first element of the key's `DEFAULT_CONFIG`
        entry. "int", "bool", "list" and "json" are converted; any other
        type name returns the stored value unchanged.

        Raises:
            KeyError: `self.key` is not present in `DEFAULT_CONFIG`.
        """
        kind = DEFAULT_CONFIG[self.key][0]

        if kind == 'int':
            return int(self.value)

        if kind == 'bool':
            # NOTE(review): the value is encoded to bytes before being passed
            # to `boolean` - confirm that helper expects bytes, not str.
            return boolean(self.value.encode('utf-8'))

        if kind == 'list':
            # Previously read an undefined local `value` (NameError).
            return json.loads(self.value)

        if kind == 'json':
            # Previously read an undefined local `value` (NameError).
            return DotDict.parse(self.value)

        return self.value