240 lines
5.1 KiB
Python
240 lines
5.1 KiB
Python
import logging

from datetime import datetime
from urllib.parse import urlparse

import tinysql

from .base import DEFAULT_CONFIG
from ..misc import DotDict
|
|
|
|
|
|
class Connection(tinysql.ConnectionMixin):
	"""Database connection with relay-specific helpers.

	Wraps the ``instances``, ``bans``, ``config`` and ``whitelist`` tables. All
	table access goes through the ``select``/``insert``/``update``/``delete``/
	``execute`` methods inherited from ``tinysql.ConnectionMixin``.
	"""

	## Misc methods

	def accept_request(self, domain):
		"""Mark the pending follow request from ``domain`` as joined (now).

		:raises KeyError: if no pending request exists for ``domain``
		"""

		# get_request raises KeyError itself when nothing is pending
		row = self.get_request(domain)
		self.update('instances', {'joined': datetime.now()}, id=row.id)


	def distill_inboxes(self, message):
		"""Yield the inbox of every joined instance except the message's sources.

		The sources are ``message.domain`` and the host of ``message.objectid``,
		so a message is not delivered back to either of them.
		"""

		src_domains = {
			message.domain,
			urlparse(message.objectid).netloc
		}

		for instance in self.get_instances():
			if instance.domain not in src_domains:
				yield instance.inbox


	## Delete methods

	def delete_ban(self, type, name):
		"""Delete the ban of kind ``type`` on ``name``.

		:raises ValueError: if ``type`` is not "software" or "domain"
		:raises KeyError: if no such ban exists
		"""

		row = self.get_ban(type, name)

		if not row:
			raise KeyError(name)

		self.delete('bans', id=row.id)


	def delete_instance(self, domain):
		"""Delete the joined instance matching ``domain``.

		:raises KeyError: if no joined instance matches
		"""

		# NOTE(review): get_instance filters out rows with joined IS NULL, so a
		# pending request cannot be removed through this method — confirm intent.
		row = self.get_instance(domain)

		if not row:
			raise KeyError(domain)

		self.delete('instances', id=row.id)


	def delete_whitelist(self, domain):
		"""Remove ``domain`` from the whitelist.

		:raises KeyError: if ``domain`` is not whitelisted
		"""

		row = self.get_whitelist_domain(domain)

		if not row:
			raise KeyError(domain)

		self.delete('whitelist', id=row.id)


	## Get methods

	@staticmethod
	def _check_ban_type(ban_type):
		"""Raise ValueError unless ``ban_type`` is a supported ban kind."""

		if ban_type not in {'software', 'domain'}:
			raise ValueError('Ban type must be "software" or "domain"')


	def get_ban(self, type, name):
		"""Return the ban row of kind ``type`` for ``name``, or a falsy value if none.

		:raises ValueError: if ``type`` is not "software" or "domain"
		"""

		self._check_ban_type(type)
		return self.select('bans', name=name, type=type).one()


	def get_bans(self, type):
		"""Return all ban rows of kind ``type``.

		:raises ValueError: if ``type`` is not "software" or "domain"
		"""

		self._check_ban_type(type)
		return self.select('bans', type=type).all()


	def get_config(self, key):
		"""Return the stored value of config ``key``, or its default when unset.

		:raises KeyError: if ``key`` is not a known config key
		"""

		if key not in DEFAULT_CONFIG:
			raise KeyError(key)

		row = self.select('config', key=key).one()

		if not row:
			# index 1 of a DEFAULT_CONFIG entry holds the default value
			return DEFAULT_CONFIG[key][1]

		return row.value


	def get_config_all(self):
		"""Return every config value as a DotDict, filling unset keys with defaults."""

		rows = self.select('config').all()
		config = DotDict({row.key: row.value for row in rows})

		for key, data in DEFAULT_CONFIG.items():
			if key not in config:
				config[key] = data[1]

		return config


	def get_hostnames(self):
		"""Return a tuple of the domains of all joined instances."""

		return tuple(row.domain for row in self.get_instances())


	def _find_instance(self, data):
		"""Return the instance row whose domain, actor or inbox equals ``data``.

		Unlike ``get_instance`` this does not filter on ``joined``, so pending
		requests are returned too. Returns a falsy value when nothing matches.
		"""

		if data.startswith('http') and '#' in data:
			# drop a trailing '#fragment' so the URL compares against the stored actor/inbox
			data = data.split('#', 1)[0]

		query = 'SELECT * FROM instances WHERE domain = :data OR actor = :data OR inbox = :data'
		return self.execute(query, dict(data=data), table='instances').one()


	def get_instance(self, data):
		"""Return the joined instance matching ``data`` (domain, actor or inbox URL), else None."""

		row = self._find_instance(data)
		return row if row and row.joined else None


	def get_instances(self):
		"""Return all joined instances ordered by domain."""

		query = 'SELECT * FROM instances WHERE joined IS NOT NULL ORDER BY domain ASC'
		return self.execute(query, table='instances').all()


	def get_request(self, domain):
		"""Return the pending (not yet joined) instance row for ``domain``.

		:raises KeyError: if there is no pending request for ``domain``
		"""

		# filter in SQL instead of loading and scanning every pending request
		query = 'SELECT * FROM instances WHERE domain = :domain AND joined IS NULL'
		row = self.execute(query, dict(domain=domain), table='instances').one()

		if not row:
			raise KeyError(domain)

		return row


	def get_requests(self):
		"""Return all pending instances (joined IS NULL) ordered by domain."""

		query = 'SELECT * FROM instances WHERE joined IS NULL ORDER BY domain ASC'
		return self.execute(query, table='instances').all()


	def get_whitelist(self):
		"""Return every whitelist row."""

		return self.select('whitelist').all()


	def get_whitelist_domain(self, domain):
		"""Return the whitelist row for ``domain``, or a falsy value if absent."""

		return self.select('whitelist', domain=domain).one()


	## Put methods

	def put_ban(self, type, name, note=None):
		"""Create a ban, or update the note of an existing one.

		:raises ValueError: if ``type`` is not "software" or "domain"
		:raises KeyError: if the ban already exists and no ``note`` was given
		"""

		# get_ban validates ``type``
		row = self.get_ban(type, name)

		if row:
			if note is None:
				raise KeyError(name)

			self.update('bans', {'note': note}, id=row.id)
			return

		self.insert('bans', {
			'name': name,
			'type': type,
			'note': note,
			'created': datetime.now()
		})


	def put_config(self, key, value='__DEFAULT__'):
		"""Set config ``key`` to ``value``; omitting ``value`` resets it to the default.

		:raises KeyError: if ``key`` is unknown, or if ``key`` is ``log_level`` and
			``value`` is not the name of a ``logging`` level
		"""

		if key not in DEFAULT_CONFIG:
			raise KeyError(key)

		if value == '__DEFAULT__':
			value = DEFAULT_CONFIG[key][1]

		elif key == 'log_level':
			# Level constants on the logging module are ints. Checking the type
			# (not truthiness) accepts NOTSET (0) and rejects non-level
			# attributes such as BASIC_FORMAT.
			if not isinstance(getattr(logging, str(value).upper(), None), int):
				raise KeyError(value)

		if self.select('config', key=key).one():
			self.update('config', {'value': value}, key=key)
			return

		self.insert('config', {
			'key': key,
			'value': value
		})


	def put_instance(self, domain, actor=None, inbox=None, followid=None, software=None, actor_data=None, note=None, accept=True):
		"""Create an instance row, or update the joined instance matching ``domain``.

		Only the arguments that are not None are written. A new row is marked
		as joined now when ``accept`` is true, otherwise it is left pending.

		:returns: the updated or newly inserted row (including a pending one)
		:raises KeyError: if the instance exists but no fields were given to update
		:raises ValueError: if creating a new instance without an ``inbox``
		"""

		new_data = {
			'actor': actor,
			'inbox': inbox,
			'followid': followid,
			'software': software,
			'note': note
		}

		if actor_data:
			new_data['actor_data'] = dict(actor_data)

		new_data = {key: value for key, value in new_data.items() if value is not None}

		# NOTE(review): get_instance ignores pending rows, so a pending request
		# for the same domain is not updated here and a new row is inserted below.
		instance = self.get_instance(domain)

		if instance:
			if not new_data:
				raise KeyError(domain)

			instance.update(new_data)
			self.update('instances', new_data, id=instance.id)
			return instance

		if not inbox:
			raise ValueError('Inbox must be included in instance data')

		if accept:
			new_data['joined'] = datetime.now()

		new_data['domain'] = domain
		self.insert('instances', new_data)

		# get_instance would hide a row inserted with accept=False (joined is NULL)
		return self._find_instance(domain)


	def put_instance_actor(self, actor, nodeinfo=None, accept=True):
		"""Create or update an instance from an actor object via ``put_instance``.

		``software`` is taken from ``nodeinfo.sw_name`` when ``nodeinfo`` is given.
		"""

		data = {
			'domain': actor.domain,
			'actor': actor.id,
			'inbox': actor.shared_inbox,
			'actor_data': actor,
			'accept': accept,
			'software': nodeinfo.sw_name if nodeinfo else None
		}

		return self.put_instance(**data)


	def put_whitelist(self, domain):
		"""Add ``domain`` to the whitelist.

		:raises KeyError: if ``domain`` is already whitelisted
		"""

		if self.get_whitelist_domain(domain):
			raise KeyError(domain)

		self.insert('whitelist', {
			'domain': domain,
			'created': datetime.now()
		})
|