mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-08 17:48:00 +00:00
update barkshark-sql to 0.2.0-rc1 and create row classes
This commit is contained in:
parent
45b0de26c7
commit
bdc7d41d7a
|
@ -20,7 +20,7 @@ dependencies = [
|
|||
"aiohttp >= 3.9.5",
|
||||
"aiohttp-swagger[performance] == 1.0.16",
|
||||
"argon2-cffi == 23.1.0",
|
||||
"barkshark-lib >= 0.1.3-1",
|
||||
"barkshark-lib >= 0.2.0-rc1",
|
||||
"barkshark-sql == 0.1.4-1",
|
||||
"click >= 8.1.2",
|
||||
"hiredis == 2.3.2",
|
||||
|
@ -104,7 +104,3 @@ implicit_reexport = true
|
|||
[[tool.mypy.overrides]]
|
||||
module = "blib"
|
||||
implicit_reexport = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "bsql"
|
||||
implicit_reexport = true
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '0.3.2'
|
||||
__version__ = '0.3.3'
|
||||
|
|
|
@ -4,30 +4,26 @@ import asyncio
|
|||
import multiprocessing
|
||||
import signal
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||
from aiohttp.web import StaticResource
|
||||
from aiohttp_swagger import setup_swagger
|
||||
from aputils.signer import Signer
|
||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||
from bsql import Database, Row
|
||||
from bsql import Database
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from mimetypes import guess_type
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Event, Thread
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from . import logger as logging, workers
|
||||
from .cache import Cache, get_cache
|
||||
from .config import Config
|
||||
from .database import Connection, get_database
|
||||
from .database.schema import Instance
|
||||
from .http_client import HttpClient
|
||||
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource
|
||||
from .misc import Message, Response, check_open_port, get_resource
|
||||
from .template import Template
|
||||
from .views import VIEWS
|
||||
from .views.api import handle_api_path
|
||||
|
@ -142,7 +138,7 @@ class Application(web.Application):
|
|||
return timedelta(seconds=uptime.seconds)
|
||||
|
||||
|
||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
||||
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||
self['workers'].push_message(inbox, message, instance)
|
||||
|
||||
|
||||
|
@ -286,67 +282,6 @@ class CacheCleanupThread(Thread):
|
|||
self.running.clear()
|
||||
|
||||
|
||||
class PushWorker(multiprocessing.Process):
|
||||
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None:
|
||||
if Application.DEFAULT is None:
|
||||
raise RuntimeError('Application not setup yet')
|
||||
|
||||
multiprocessing.Process.__init__(self)
|
||||
|
||||
self.queue = queue
|
||||
self.shutdown = multiprocessing.Event()
|
||||
self.path = Application.DEFAULT.config.path
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
self.shutdown.set()
|
||||
|
||||
|
||||
def run(self) -> None:
|
||||
asyncio.run(self.handle_queue())
|
||||
|
||||
|
||||
async def handle_queue(self) -> None:
|
||||
if IS_WINDOWS:
|
||||
app = Application(self.path)
|
||||
client = app.client
|
||||
|
||||
client.open()
|
||||
app.database.connect()
|
||||
app.cache.setup()
|
||||
|
||||
else:
|
||||
client = HttpClient()
|
||||
client.open()
|
||||
|
||||
while not self.shutdown.is_set():
|
||||
try:
|
||||
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
|
||||
asyncio.create_task(client.post(inbox, message, instance))
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
except ClientSSLError as e:
|
||||
logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e))
|
||||
|
||||
except (AsyncTimeoutError, ClientConnectionError) as e:
|
||||
logging.error(
|
||||
'Failed to connect to %s for message push: %s',
|
||||
urlparse(inbox).netloc, str(e)
|
||||
)
|
||||
|
||||
# make sure an exception doesn't bring down the worker
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
if IS_WINDOWS:
|
||||
app.database.disconnect()
|
||||
app.cache.close()
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def handle_response_headers(
|
||||
request: web.Request,
|
||||
|
|
|
@ -4,7 +4,7 @@ import json
|
|||
import os
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from bsql import Database
|
||||
from bsql import Database, Row
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
@ -172,7 +172,7 @@ class SqlCache(Cache):
|
|||
|
||||
with self._db.session(False) as conn:
|
||||
with conn.run('get-cache-item', params) as cur:
|
||||
if not (row := cur.one()):
|
||||
if not (row := cur.one(Row)):
|
||||
raise KeyError(f'{namespace}:{key}')
|
||||
|
||||
row.pop('id', None)
|
||||
|
@ -211,9 +211,11 @@ class SqlCache(Cache):
|
|||
|
||||
with self._db.session(True) as conn:
|
||||
with conn.run('set-cache-item', params) as cur:
|
||||
row = cur.one()
|
||||
row.pop('id', None) # type: ignore[union-attr]
|
||||
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
|
||||
if (row := cur.one(Row)) is None:
|
||||
raise RuntimeError("Cache item not set")
|
||||
|
||||
row.pop('id', None)
|
||||
return Item.from_data(*tuple(row.values()))
|
||||
|
||||
|
||||
def delete(self, namespace: str, key: str) -> None:
|
||||
|
|
|
@ -16,6 +16,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
|
|||
'tables': TABLES
|
||||
}
|
||||
|
||||
db: Database[Connection]
|
||||
|
||||
if config.db_type == 'sqlite':
|
||||
db = Database.sqlite(config.sqlite_path, **options)
|
||||
|
||||
|
|
|
@ -2,12 +2,13 @@ from __future__ import annotations
|
|||
|
||||
from argon2 import PasswordHasher
|
||||
from bsql import Connection as SqlConnection, Row, Update
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from . import schema
|
||||
from .config import (
|
||||
THEMES,
|
||||
ConfigData
|
||||
|
@ -37,14 +38,14 @@ class Connection(SqlConnection):
|
|||
return get_app()
|
||||
|
||||
|
||||
def distill_inboxes(self, message: Message) -> Iterator[Row]:
|
||||
def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]:
|
||||
src_domains = {
|
||||
message.domain,
|
||||
urlparse(message.object_id).netloc
|
||||
}
|
||||
|
||||
for instance in self.get_inboxes():
|
||||
if instance['domain'] not in src_domains:
|
||||
if instance.domain not in src_domains:
|
||||
yield instance
|
||||
|
||||
|
||||
|
@ -52,7 +53,7 @@ class Connection(SqlConnection):
|
|||
key = key.replace('_', '-')
|
||||
|
||||
with self.run('get-config', {'key': key}) as cur:
|
||||
if not (row := cur.one()):
|
||||
if (row := cur.one(Row)) is None:
|
||||
return ConfigData.DEFAULT(key)
|
||||
|
||||
data = ConfigData()
|
||||
|
@ -61,8 +62,8 @@ class Connection(SqlConnection):
|
|||
|
||||
|
||||
def get_config_all(self) -> ConfigData:
|
||||
with self.run('get-config-all', None) as cur:
|
||||
return ConfigData.from_rows(tuple(cur.all()))
|
||||
rows = tuple(self.run('get-config-all', None).all(schema.Row))
|
||||
return ConfigData.from_rows(rows)
|
||||
|
||||
|
||||
def put_config(self, key: str, value: Any) -> Any:
|
||||
|
@ -99,14 +100,13 @@ class Connection(SqlConnection):
|
|||
return data.get(key)
|
||||
|
||||
|
||||
def get_inbox(self, value: str) -> Row:
|
||||
def get_inbox(self, value: str) -> schema.Instance | None:
|
||||
with self.run('get-inbox', {'value': value}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.Instance)
|
||||
|
||||
|
||||
def get_inboxes(self) -> Sequence[Row]:
|
||||
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
|
||||
return tuple(cur.all())
|
||||
def get_inboxes(self) -> Iterator[schema.Instance]:
|
||||
return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance)
|
||||
|
||||
|
||||
def put_inbox(self,
|
||||
|
@ -115,7 +115,7 @@ class Connection(SqlConnection):
|
|||
actor: str | None = None,
|
||||
followid: str | None = None,
|
||||
software: str | None = None,
|
||||
accepted: bool = True) -> Row:
|
||||
accepted: bool = True) -> schema.Instance:
|
||||
|
||||
params: dict[str, Any] = {
|
||||
'inbox': inbox,
|
||||
|
@ -125,7 +125,7 @@ class Connection(SqlConnection):
|
|||
'accepted': accepted
|
||||
}
|
||||
|
||||
if not self.get_inbox(domain):
|
||||
if self.get_inbox(domain) is None:
|
||||
if not inbox:
|
||||
raise ValueError("Missing inbox")
|
||||
|
||||
|
@ -133,14 +133,20 @@ class Connection(SqlConnection):
|
|||
params['created'] = datetime.now(tz = timezone.utc)
|
||||
|
||||
with self.run('put-inbox', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Instance)) is None:
|
||||
raise RuntimeError(f"Failed to insert instance: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
for key, value in tuple(params.items()):
|
||||
if value is None:
|
||||
del params[key]
|
||||
|
||||
with self.update('inboxes', params, domain = domain) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Instance)) is None:
|
||||
raise RuntimeError(f"Failed to update instance: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_inbox(self, value: str) -> bool:
|
||||
|
@ -151,24 +157,23 @@ class Connection(SqlConnection):
|
|||
return cur.row_count == 1
|
||||
|
||||
|
||||
def get_request(self, domain: str) -> Row:
|
||||
def get_request(self, domain: str) -> schema.Instance | None:
|
||||
with self.run('get-request', {'domain': domain}) as cur:
|
||||
if not (row := cur.one()):
|
||||
raise KeyError(domain)
|
||||
|
||||
return row
|
||||
return cur.one(schema.Instance)
|
||||
|
||||
|
||||
def get_requests(self) -> Sequence[Row]:
|
||||
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
|
||||
return tuple(cur.all())
|
||||
def get_requests(self) -> Iterator[schema.Instance]:
|
||||
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance)
|
||||
|
||||
|
||||
def put_request_response(self, domain: str, accepted: bool) -> Row:
|
||||
instance = self.get_request(domain)
|
||||
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
|
||||
if (instance := self.get_request(domain)) is None:
|
||||
raise KeyError(domain)
|
||||
|
||||
if not accepted:
|
||||
self.del_inbox(domain)
|
||||
if not self.del_inbox(domain):
|
||||
raise RuntimeError(f'Failed to delete request: {domain}')
|
||||
|
||||
return instance
|
||||
|
||||
params = {
|
||||
|
@ -177,21 +182,28 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-inbox-accept', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Instance)) is None:
|
||||
raise RuntimeError(f"Failed to insert response for domain: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def get_user(self, value: str) -> Row:
|
||||
def get_user(self, value: str) -> schema.User | None:
|
||||
with self.run('get-user', {'value': value}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.User)
|
||||
|
||||
|
||||
def get_user_by_token(self, code: str) -> Row:
|
||||
def get_user_by_token(self, code: str) -> schema.User | None:
|
||||
with self.run('get-user-by-token', {'code': code}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.User)
|
||||
|
||||
|
||||
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
|
||||
if self.get_user(username):
|
||||
def get_users(self) -> Iterator[schema.User]:
|
||||
return self.execute("SELECT * FROM users").all(schema.User)
|
||||
|
||||
|
||||
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
|
||||
if self.get_user(username) is not None:
|
||||
data: dict[str, str | datetime | None] = {}
|
||||
|
||||
if password:
|
||||
|
@ -204,7 +216,10 @@ class Connection(SqlConnection):
|
|||
stmt.set_where("username", username)
|
||||
|
||||
with self.query(stmt) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.User)) is None:
|
||||
raise RuntimeError(f"Failed to update user: {username}")
|
||||
|
||||
return row
|
||||
|
||||
if password is None:
|
||||
raise ValueError('Password cannot be empty')
|
||||
|
@ -217,25 +232,36 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-user', data) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.User)) is None:
|
||||
raise RuntimeError(f"Failed to insert user: {username}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_user(self, username: str) -> None:
|
||||
user = self.get_user(username)
|
||||
if (user := self.get_user(username)) is None:
|
||||
raise KeyError(username)
|
||||
|
||||
with self.run('del-user', {'value': user['username']}):
|
||||
with self.run('del-user', {'value': user.username}):
|
||||
pass
|
||||
|
||||
with self.run('del-token-user', {'username': user['username']}):
|
||||
with self.run('del-token-user', {'username': user.username}):
|
||||
pass
|
||||
|
||||
|
||||
def get_token(self, code: str) -> Row:
|
||||
def get_token(self, code: str) -> schema.Token | None:
|
||||
with self.run('get-token', {'code': code}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.Token)
|
||||
|
||||
|
||||
def put_token(self, username: str) -> Row:
|
||||
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
|
||||
if username is not None:
|
||||
return self.select('tokens').all(schema.Token)
|
||||
|
||||
return self.select('tokens', username = username).all(schema.Token)
|
||||
|
||||
|
||||
def put_token(self, username: str) -> schema.Token:
|
||||
data = {
|
||||
'code': uuid4().hex,
|
||||
'user': username,
|
||||
|
@ -243,7 +269,10 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-token', data) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Token)) is None:
|
||||
raise RuntimeError(f"Failed to insert token for user: {username}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_token(self, code: str) -> None:
|
||||
|
@ -251,18 +280,22 @@ class Connection(SqlConnection):
|
|||
pass
|
||||
|
||||
|
||||
def get_domain_ban(self, domain: str) -> Row:
|
||||
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
|
||||
if domain.startswith('http'):
|
||||
domain = urlparse(domain).netloc
|
||||
|
||||
with self.run('get-domain-ban', {'domain': domain}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.DomainBan)
|
||||
|
||||
|
||||
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
|
||||
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
|
||||
|
||||
|
||||
def put_domain_ban(self,
|
||||
domain: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.DomainBan:
|
||||
|
||||
params = {
|
||||
'domain': domain,
|
||||
|
@ -272,13 +305,16 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-domain-ban', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.DomainBan)) is None:
|
||||
raise RuntimeError(f"Failed to insert domain ban: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def update_domain_ban(self,
|
||||
domain: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.DomainBan:
|
||||
|
||||
if not (reason or note):
|
||||
raise ValueError('"reason" and/or "note" must be specified')
|
||||
|
@ -298,7 +334,10 @@ class Connection(SqlConnection):
|
|||
if cur.row_count > 1:
|
||||
raise ValueError('More than one row was modified')
|
||||
|
||||
return self.get_domain_ban(domain)
|
||||
if (row := cur.one(schema.DomainBan)) is None:
|
||||
raise RuntimeError(f"Failed to update domain ban: {domain}")
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_domain_ban(self, domain: str) -> bool:
|
||||
|
@ -309,15 +348,19 @@ class Connection(SqlConnection):
|
|||
return cur.row_count == 1
|
||||
|
||||
|
||||
def get_software_ban(self, name: str) -> Row:
|
||||
def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
|
||||
with self.run('get-software-ban', {'name': name}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one(schema.SoftwareBan)
|
||||
|
||||
|
||||
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
|
||||
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
|
||||
|
||||
|
||||
def put_software_ban(self,
|
||||
name: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.SoftwareBan:
|
||||
|
||||
params = {
|
||||
'name': name,
|
||||
|
@ -327,13 +370,16 @@ class Connection(SqlConnection):
|
|||
}
|
||||
|
||||
with self.run('put-software-ban', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||
raise RuntimeError(f'Failed to insert software ban: {name}')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def update_software_ban(self,
|
||||
name: str,
|
||||
reason: str | None = None,
|
||||
note: str | None = None) -> Row:
|
||||
note: str | None = None) -> schema.SoftwareBan:
|
||||
|
||||
if not (reason or note):
|
||||
raise ValueError('"reason" and/or "note" must be specified')
|
||||
|
@ -353,7 +399,10 @@ class Connection(SqlConnection):
|
|||
if cur.row_count > 1:
|
||||
raise ValueError('More than one row was modified')
|
||||
|
||||
return self.get_software_ban(name)
|
||||
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||
raise RuntimeError(f'Failed to update software ban: {name}')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_software_ban(self, name: str) -> bool:
|
||||
|
@ -364,19 +413,26 @@ class Connection(SqlConnection):
|
|||
return cur.row_count == 1
|
||||
|
||||
|
||||
def get_domain_whitelist(self, domain: str) -> Row:
|
||||
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
|
||||
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
|
||||
return cur.one() # type: ignore
|
||||
return cur.one()
|
||||
|
||||
|
||||
def put_domain_whitelist(self, domain: str) -> Row:
|
||||
def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]:
|
||||
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
|
||||
|
||||
|
||||
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
|
||||
params = {
|
||||
'domain': domain,
|
||||
'created': datetime.now(tz = timezone.utc)
|
||||
}
|
||||
|
||||
with self.run('put-domain-whitelist', params) as cur:
|
||||
return cur.one() # type: ignore
|
||||
if (row := cur.one(schema.Whitelist)) is None:
|
||||
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
|
||||
|
||||
return row
|
||||
|
||||
|
||||
def del_domain_whitelist(self, domain: str) -> bool:
|
||||
|
|
|
@ -1,61 +1,88 @@
|
|||
from bsql import Column, Table, Tables
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from bsql import Column, Row, Tables
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
|
||||
from .config import ConfigData
|
||||
from .connection import Connection
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .connection import Connection
|
||||
|
||||
|
||||
VERSIONS: dict[int, Callable[[Connection], None]] = {}
|
||||
TABLES: Tables = Tables(
|
||||
Table(
|
||||
'config',
|
||||
Column('key', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('value', 'text'),
|
||||
Column('type', 'text', default = 'str')
|
||||
),
|
||||
Table(
|
||||
'inboxes',
|
||||
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('actor', 'text', unique = True),
|
||||
Column('inbox', 'text', unique = True, nullable = False),
|
||||
Column('followid', 'text'),
|
||||
Column('software', 'text'),
|
||||
Column('accepted', 'boolean'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'whitelist',
|
||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
||||
Column('created', 'timestamp')
|
||||
),
|
||||
Table(
|
||||
'domain_bans',
|
||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
||||
Column('reason', 'text'),
|
||||
Column('note', 'text'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'software_bans',
|
||||
Column('name', 'text', primary_key = True, unique = True, nullable = True),
|
||||
Column('reason', 'text'),
|
||||
Column('note', 'text'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'users',
|
||||
Column('username', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('hash', 'text', nullable = False),
|
||||
Column('handle', 'text'),
|
||||
Column('created', 'timestamp', nullable = False)
|
||||
),
|
||||
Table(
|
||||
'tokens',
|
||||
Column('code', 'text', primary_key = True, unique = True, nullable = False),
|
||||
Column('user', 'text', nullable = False),
|
||||
Column('created', 'timestmap', nullable = False)
|
||||
)
|
||||
)
|
||||
TABLES = Tables()
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class Config(Row):
|
||||
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
|
||||
value: Column[str] = Column('value', 'text')
|
||||
type: Column[str] = Column('type', 'text', default = 'str')
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class Instance(Row):
|
||||
table_name: str = 'inboxes'
|
||||
|
||||
domain: Column[str] = Column(
|
||||
'domain', 'text', primary_key = True, unique = True, nullable = False)
|
||||
actor: Column[str] = Column('actor', 'text', unique = True)
|
||||
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
|
||||
followid: Column[str] = Column('followid', 'text')
|
||||
software: Column[str] = Column('software', 'text')
|
||||
accepted: Column[datetime] = Column('accepted', 'boolean')
|
||||
created: Column[datetime] = Column('created', 'timestamp', nullable = False)
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class Whitelist(Row):
|
||||
domain: Column[str] = Column(
|
||||
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||
created: Column[datetime] = Column('created', 'timestamp')
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class DomainBan(Row):
|
||||
table_name: str = 'domain_bans'
|
||||
|
||||
domain: Column[str] = Column(
|
||||
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||
reason: Column[str] = Column('reason', 'text')
|
||||
note: Column[str] = Column('note', 'text')
|
||||
created: Column[datetime] = Column('created', 'timestamp')
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class SoftwareBan(Row):
|
||||
table_name: str = 'software_bans'
|
||||
|
||||
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
|
||||
reason: Column[str] = Column('reason', 'text')
|
||||
note: Column[str] = Column('note', 'text')
|
||||
created: Column[datetime] = Column('created', 'timestamp')
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class User(Row):
|
||||
table_name: str = 'users'
|
||||
|
||||
username: Column[str] = Column(
|
||||
'username', 'text', primary_key = True, unique = True, nullable = False)
|
||||
hash: Column[str] = Column('hash', 'text', nullable = False)
|
||||
handle: Column[str] = Column('handle', 'text')
|
||||
created: Column[datetime] = Column('created', 'timestamp')
|
||||
|
||||
|
||||
@TABLES.add_row
|
||||
class Token(Row):
|
||||
table_name: str = 'tokens'
|
||||
|
||||
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
|
||||
user: Column[str] = Column('user', 'text', nullable = False)
|
||||
created: Column[datetime] = Column('created', 'timestamp')
|
||||
|
||||
|
||||
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
|
||||
|
|
|
@ -5,11 +5,11 @@ import json
|
|||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
|
||||
from blib import JsonBase
|
||||
from bsql import Row
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
|
||||
from . import __version__, logger as logging
|
||||
from .cache import Cache
|
||||
from .database.schema import Instance
|
||||
from .misc import MIMETYPES, Message, get_app
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -184,12 +184,12 @@ class HttpClient:
|
|||
return None
|
||||
|
||||
|
||||
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
|
||||
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
|
||||
if not self._session:
|
||||
raise RuntimeError('Client not open')
|
||||
|
||||
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
|
||||
if instance and instance['software'] in SUPPORTS_HS2019:
|
||||
if instance is not None and instance.software in SUPPORTS_HS2019:
|
||||
algorithm = AlgorithmType.HS2019
|
||||
|
||||
else:
|
||||
|
|
131
relay/manage.py
131
relay/manage.py
|
@ -6,7 +6,6 @@ import click
|
|||
import json
|
||||
import os
|
||||
|
||||
from bsql import Row
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any
|
||||
|
@ -17,7 +16,7 @@ from . import http_client as http
|
|||
from . import logger as logging
|
||||
from .application import Application
|
||||
from .compat import RelayConfig, RelayDatabase
|
||||
from .database import RELAY_SOFTWARE, get_database
|
||||
from .database import RELAY_SOFTWARE, get_database, schema
|
||||
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
|
||||
|
||||
|
||||
|
@ -367,8 +366,8 @@ def cli_user_list(ctx: click.Context) -> None:
|
|||
click.echo('Users:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for user in conn.execute('SELECT * FROM users'):
|
||||
click.echo(f'- {user["username"]}')
|
||||
for row in conn.get_users():
|
||||
click.echo(f'- {row.username}')
|
||||
|
||||
|
||||
@cli_user.command('create')
|
||||
|
@ -379,7 +378,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
|
|||
'Create a new local user'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_user(username):
|
||||
if conn.get_user(username) is not None:
|
||||
click.echo(f'User already exists: {username}')
|
||||
return
|
||||
|
||||
|
@ -406,7 +405,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
|
|||
'Delete a local user'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not conn.get_user(username):
|
||||
if conn.get_user(username) is None:
|
||||
click.echo(f'User does not exist: {username}')
|
||||
return
|
||||
|
||||
|
@ -424,8 +423,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
|
|||
click.echo(f'Tokens for "{username}":')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
|
||||
click.echo(f'- {token["code"]}')
|
||||
for row in conn.get_tokens(username):
|
||||
click.echo(f'- {row.code}')
|
||||
|
||||
|
||||
@cli_user.command('create-token')
|
||||
|
@ -435,13 +434,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
|
|||
'Create a new API token for a user'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not (user := conn.get_user(username)):
|
||||
if (user := conn.get_user(username)) is None:
|
||||
click.echo(f'User does not exist: {username}')
|
||||
return
|
||||
|
||||
token = conn.put_token(user['username'])
|
||||
token = conn.put_token(user.username)
|
||||
|
||||
click.echo(f'New token for "{username}": {token["code"]}')
|
||||
click.echo(f'New token for "{username}": {token.code}')
|
||||
|
||||
|
||||
@cli_user.command('delete-token')
|
||||
|
@ -451,7 +450,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
|
|||
'Delete an API token'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not conn.get_token(code):
|
||||
if conn.get_token(code) is None:
|
||||
click.echo('Token does not exist')
|
||||
return
|
||||
|
||||
|
@ -473,8 +472,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
|||
click.echo('Connected to the following instances or relays:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for inbox in conn.get_inboxes():
|
||||
click.echo(f'- {inbox["inbox"]}')
|
||||
for row in conn.get_inboxes():
|
||||
click.echo(f'- {row.inbox}')
|
||||
|
||||
|
||||
@cli_inbox.command('follow')
|
||||
|
@ -483,19 +482,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
|||
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||
'Follow an actor (Relay must be running)'
|
||||
|
||||
instance: schema.Instance | None = None
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||
return
|
||||
|
||||
if (inbox_data := conn.get_inbox(actor)):
|
||||
inbox = inbox_data['inbox']
|
||||
if (instance := conn.get_inbox(actor)) is not None:
|
||||
inbox = instance.inbox
|
||||
|
||||
else:
|
||||
if not actor.startswith('http'):
|
||||
actor = f'https://{actor}/actor'
|
||||
|
||||
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
|
||||
if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
|
||||
click.echo(f'Failed to fetch actor: {actor}')
|
||||
return
|
||||
|
||||
|
@ -506,7 +507,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
|||
actor = actor
|
||||
)
|
||||
|
||||
asyncio.run(http.post(inbox, message, inbox_data))
|
||||
asyncio.run(http.post(inbox, message, instance))
|
||||
click.echo(f'Sent follow message to actor: {actor}')
|
||||
|
||||
|
||||
|
@ -516,19 +517,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
|||
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
||||
'Unfollow an actor (Relay must be running)'
|
||||
|
||||
inbox_data: Row | None = None
|
||||
instance: schema.Instance | None = None
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_domain_ban(actor):
|
||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||
return
|
||||
|
||||
if (inbox_data := conn.get_inbox(actor)):
|
||||
inbox = inbox_data['inbox']
|
||||
if (instance := conn.get_inbox(actor)):
|
||||
inbox = instance.inbox
|
||||
message = Message.new_unfollow(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = actor,
|
||||
follow = inbox_data['followid']
|
||||
follow = instance.followid
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -552,7 +553,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
|||
}
|
||||
)
|
||||
|
||||
asyncio.run(http.post(inbox, message, inbox_data))
|
||||
asyncio.run(http.post(inbox, message, instance))
|
||||
click.echo(f'Sent unfollow message to: {actor}')
|
||||
|
||||
|
||||
|
@ -632,9 +633,9 @@ def cli_request_list(ctx: click.Context) -> None:
|
|||
click.echo('Follow requests:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for instance in conn.get_requests():
|
||||
date = instance['created'].strftime('%Y-%m-%d')
|
||||
click.echo(f'- [{date}] {instance["domain"]}')
|
||||
for row in conn.get_requests():
|
||||
date = row.created.strftime('%Y-%m-%d')
|
||||
click.echo(f'- [{date}] {row.domain}')
|
||||
|
||||
|
||||
@cli_request.command('accept')
|
||||
|
@ -653,20 +654,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
|
|||
|
||||
message = Message.new_response(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = instance['actor'],
|
||||
followid = instance['followid'],
|
||||
actor = instance.actor,
|
||||
followid = instance.followid,
|
||||
accept = True
|
||||
)
|
||||
|
||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
||||
asyncio.run(http.post(instance.inbox, message, instance))
|
||||
|
||||
if instance['software'] != 'mastodon':
|
||||
if instance.software != 'mastodon':
|
||||
message = Message.new_follow(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = instance['actor']
|
||||
actor = instance.actor
|
||||
)
|
||||
|
||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
||||
asyncio.run(http.post(instance.inbox, message, instance))
|
||||
|
||||
|
||||
@cli_request.command('deny')
|
||||
|
@ -685,12 +686,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
|
|||
|
||||
response = Message.new_response(
|
||||
host = ctx.obj.config.domain,
|
||||
actor = instance['actor'],
|
||||
followid = instance['followid'],
|
||||
actor = instance.actor,
|
||||
followid = instance.followid,
|
||||
accept = False
|
||||
)
|
||||
|
||||
asyncio.run(http.post(instance['inbox'], response, instance))
|
||||
asyncio.run(http.post(instance.inbox, response, instance))
|
||||
|
||||
|
||||
@cli.group('instance')
|
||||
|
@ -706,12 +707,12 @@ def cli_instance_list(ctx: click.Context) -> None:
|
|||
click.echo('Banned domains:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for instance in conn.execute('SELECT * FROM domain_bans'):
|
||||
if instance['reason']:
|
||||
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
|
||||
for row in conn.get_domain_bans():
|
||||
if row.reason is not None:
|
||||
click.echo(f'- {row.domain} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {instance["domain"]}')
|
||||
click.echo(f'- {row.domain}')
|
||||
|
||||
|
||||
@cli_instance.command('ban')
|
||||
|
@ -723,7 +724,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
|
|||
'Ban an instance and remove the associated inbox if it exists'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if conn.get_domain_ban(domain):
|
||||
if conn.get_domain_ban(domain) is not None:
|
||||
click.echo(f'Domain already banned: {domain}')
|
||||
return
|
||||
|
||||
|
@ -739,7 +740,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
|
|||
'Unban an instance'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if not conn.del_domain_ban(domain):
|
||||
if conn.del_domain_ban(domain) is None:
|
||||
click.echo(f'Instance wasn\'t banned: {domain}')
|
||||
return
|
||||
|
||||
|
@ -764,11 +765,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
|
|||
|
||||
click.echo(f'Updated domain ban: {domain}')
|
||||
|
||||
if row['reason']:
|
||||
click.echo(f'- {row["domain"]} ({row["reason"]})')
|
||||
if row.reason:
|
||||
click.echo(f'- {row.domain} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {row["domain"]}')
|
||||
click.echo(f'- {row.domain}')
|
||||
|
||||
|
||||
@cli.group('software')
|
||||
|
@ -784,12 +785,12 @@ def cli_software_list(ctx: click.Context) -> None:
|
|||
click.echo('Banned software:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for software in conn.execute('SELECT * FROM software_bans'):
|
||||
if software['reason']:
|
||||
click.echo(f'- {software["name"]} ({software["reason"]})')
|
||||
for row in conn.get_software_bans():
|
||||
if row.reason:
|
||||
click.echo(f'- {row.name} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {software["name"]}')
|
||||
click.echo(f'- {row.name}')
|
||||
|
||||
|
||||
@cli_software.command('ban')
|
||||
|
@ -811,12 +812,12 @@ def cli_software_ban(ctx: click.Context,
|
|||
|
||||
with ctx.obj.database.session() as conn:
|
||||
if name == 'RELAYS':
|
||||
for software in RELAY_SOFTWARE:
|
||||
if conn.get_software_ban(software):
|
||||
click.echo(f'Relay already banned: {software}')
|
||||
for item in RELAY_SOFTWARE:
|
||||
if conn.get_software_ban(item):
|
||||
click.echo(f'Relay already banned: {item}')
|
||||
continue
|
||||
|
||||
conn.put_software_ban(software, reason or 'relay', note)
|
||||
conn.put_software_ban(item, reason or 'relay', note)
|
||||
|
||||
click.echo('Banned all relay software')
|
||||
return
|
||||
|
@ -893,11 +894,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
|
|||
|
||||
click.echo(f'Updated software ban: {name}')
|
||||
|
||||
if row['reason']:
|
||||
click.echo(f'- {row["name"]} ({row["reason"]})')
|
||||
if row.reason:
|
||||
click.echo(f'- {row.name} ({row.reason})')
|
||||
|
||||
else:
|
||||
click.echo(f'- {row["name"]}')
|
||||
click.echo(f'- {row.name}')
|
||||
|
||||
|
||||
@cli.group('whitelist')
|
||||
|
@ -913,8 +914,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
|
|||
click.echo('Current whitelisted domains:')
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for domain in conn.execute('SELECT * FROM whitelist'):
|
||||
click.echo(f'- {domain["domain"]}')
|
||||
for row in conn.get_domain_whitelist():
|
||||
click.echo(f'- {row.domain}')
|
||||
|
||||
|
||||
@cli_whitelist.command('add')
|
||||
|
@ -953,23 +954,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
|
|||
@cli_whitelist.command('import')
|
||||
@click.pass_context
|
||||
def cli_whitelist_import(ctx: click.Context) -> None:
|
||||
'Add all current inboxes to the whitelist'
|
||||
'Add all current instances to the whitelist'
|
||||
|
||||
with ctx.obj.database.session() as conn:
|
||||
for inbox in conn.execute('SELECT * FROM inboxes').all():
|
||||
if conn.get_domain_whitelist(inbox['domain']):
|
||||
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
|
||||
for row in conn.get_inboxes():
|
||||
if conn.get_domain_whitelist(row.domain) is not None:
|
||||
click.echo(f'Domain already in whitelist: {row.domain}')
|
||||
continue
|
||||
|
||||
conn.put_domain_whitelist(inbox['domain'])
|
||||
conn.put_domain_whitelist(row.domain)
|
||||
|
||||
click.echo('Imported whitelist from inboxes')
|
||||
|
||||
|
||||
|
||||
def main() -> None:
|
||||
cli(prog_name='relay')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')
|
||||
cli(prog_name='activityrelay')
|
||||
|
|
|
@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
|
|||
logging.debug('>> relay: %s', message)
|
||||
|
||||
for instance in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(instance["inbox"], message, instance)
|
||||
view.app.push_message(instance.inbox, message, instance)
|
||||
|
||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||
|
||||
|
@ -52,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
|
|||
logging.debug('>> forward: %s', message)
|
||||
|
||||
for instance in conn.distill_inboxes(view.message):
|
||||
view.app.push_message(instance["inbox"], view.message, instance)
|
||||
view.app.push_message(instance.inbox, view.message, instance)
|
||||
|
||||
view.cache.set('handle-relay', view.message.id, message.id, 'str')
|
||||
|
||||
|
@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
|
|||
return
|
||||
|
||||
# prevent past unfollows from removing an instance
|
||||
if view.instance['followid'] and view.instance['followid'] != view.message.object_id:
|
||||
if view.instance.followid and view.instance.followid != view.message.object_id:
|
||||
return
|
||||
|
||||
with conn.transaction():
|
||||
|
@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
|
|||
|
||||
with view.database.session() as conn:
|
||||
if view.instance:
|
||||
if not view.instance['software']:
|
||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
||||
if not view.instance.software:
|
||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
|
||||
with conn.transaction():
|
||||
view.instance = conn.put_inbox(
|
||||
domain = view.instance['domain'],
|
||||
domain = view.instance.domain,
|
||||
software = nodeinfo.sw_name
|
||||
)
|
||||
|
||||
if not view.instance['actor']:
|
||||
if not view.instance.actor:
|
||||
with conn.transaction():
|
||||
view.instance = conn.put_inbox(
|
||||
domain = view.instance['domain'],
|
||||
domain = view.instance.domain,
|
||||
actor = view.actor.id
|
||||
)
|
||||
|
||||
|
|
|
@ -1,26 +1,22 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import aputils
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
from aiohttp.web import Request
|
||||
|
||||
from .base import View, register_route
|
||||
|
||||
from .. import logger as logging
|
||||
from ..database import schema
|
||||
from ..misc import Message, Response
|
||||
from ..processors import run_processor
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from aiohttp.web import Request
|
||||
from bsql import Row
|
||||
|
||||
|
||||
@register_route('/actor', '/inbox')
|
||||
class ActorView(View):
|
||||
signature: aputils.Signature
|
||||
message: Message
|
||||
actor: Message
|
||||
instancce: Row
|
||||
instance: schema.Instance
|
||||
signer: aputils.Signer
|
||||
|
||||
|
||||
|
@ -47,7 +43,7 @@ class ActorView(View):
|
|||
return response
|
||||
|
||||
with self.database.session() as conn:
|
||||
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
||||
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
|
||||
|
||||
# reject if actor is banned
|
||||
if conn.get_domain_ban(self.actor.domain):
|
||||
|
|
|
@ -90,10 +90,10 @@ class Login(View):
|
|||
|
||||
token = conn.put_token(data['username'])
|
||||
|
||||
resp = Response.new({'token': token['code']}, ctype = 'json')
|
||||
resp = Response.new({'token': token.code}, ctype = 'json')
|
||||
resp.set_cookie(
|
||||
'user-token',
|
||||
token['code'],
|
||||
token.code,
|
||||
max_age = 60 * 60 * 24 * 365,
|
||||
domain = self.config.domain,
|
||||
path = '/',
|
||||
|
@ -117,7 +117,7 @@ class RelayInfo(View):
|
|||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
config = conn.get_config_all()
|
||||
inboxes = [row['domain'] for row in conn.get_inboxes()]
|
||||
inboxes = [row.domain for row in conn.get_inboxes()]
|
||||
|
||||
data = {
|
||||
'domain': self.config.domain,
|
||||
|
@ -188,7 +188,7 @@ class Config(View):
|
|||
class Inbox(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
data = conn.get_inboxes()
|
||||
data = tuple(conn.get_inboxes())
|
||||
|
||||
return Response.new(data, ctype = 'json')
|
||||
|
||||
|
@ -202,7 +202,7 @@ class Inbox(View):
|
|||
data['domain'] = urlparse(data["actor"]).netloc
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_inbox(data['domain']):
|
||||
if conn.get_inbox(data['domain']) is not None:
|
||||
return Response.new_error(404, 'Instance already in database', 'json')
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
@ -225,7 +225,12 @@ class Inbox(View):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
row = conn.put_inbox(**data) # type: ignore[arg-type]
|
||||
row = conn.put_inbox(
|
||||
data['domain'],
|
||||
actor = data.get('actor'),
|
||||
software = data.get('software'),
|
||||
followid = data.get('followid')
|
||||
)
|
||||
|
||||
return Response.new(row, ctype = 'json')
|
||||
|
||||
|
@ -239,10 +244,15 @@ class Inbox(View):
|
|||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not (instance := conn.get_inbox(data['domain'])):
|
||||
if (instance := conn.get_inbox(data['domain'])) is None:
|
||||
return Response.new_error(404, 'Instance with domain not found', 'json')
|
||||
|
||||
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
|
||||
instance = conn.put_inbox(
|
||||
instance.domain,
|
||||
actor = data.get('actor'),
|
||||
software = data.get('software'),
|
||||
followid = data.get('followid')
|
||||
)
|
||||
|
||||
return Response.new(instance, ctype = 'json')
|
||||
|
||||
|
@ -268,7 +278,7 @@ class Inbox(View):
|
|||
class RequestView(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
instances = conn.get_requests()
|
||||
instances = tuple(conn.get_requests())
|
||||
|
||||
return Response.new(instances, ctype = 'json')
|
||||
|
||||
|
@ -291,20 +301,20 @@ class RequestView(View):
|
|||
|
||||
message = Message.new_response(
|
||||
host = self.config.domain,
|
||||
actor = instance['actor'],
|
||||
followid = instance['followid'],
|
||||
actor = instance.actor,
|
||||
followid = instance.followid,
|
||||
accept = data['accept']
|
||||
)
|
||||
|
||||
self.app.push_message(instance['inbox'], message, instance)
|
||||
self.app.push_message(instance.inbox, message, instance)
|
||||
|
||||
if data['accept'] and instance['software'] != 'mastodon':
|
||||
if data['accept'] and instance.software != 'mastodon':
|
||||
message = Message.new_follow(
|
||||
host = self.config.domain,
|
||||
actor = instance['actor']
|
||||
actor = instance.actor
|
||||
)
|
||||
|
||||
self.app.push_message(instance['inbox'], message, instance)
|
||||
self.app.push_message(instance.inbox, message, instance)
|
||||
|
||||
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
|
||||
return Response.new(resp_message, ctype = 'json')
|
||||
|
@ -314,7 +324,7 @@ class RequestView(View):
|
|||
class DomainBan(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
|
||||
bans = tuple(conn.get_domain_bans())
|
||||
|
||||
return Response.new(bans, ctype = 'json')
|
||||
|
||||
|
@ -328,10 +338,14 @@ class DomainBan(View):
|
|||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_domain_ban(data['domain']):
|
||||
if conn.get_domain_ban(data['domain']) is not None:
|
||||
return Response.new_error(400, 'Domain already banned', 'json')
|
||||
|
||||
ban = conn.put_domain_ban(**data)
|
||||
ban = conn.put_domain_ban(
|
||||
domain = data['domain'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -343,15 +357,19 @@ class DomainBan(View):
|
|||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not conn.get_domain_ban(data['domain']):
|
||||
return Response.new_error(404, 'Domain not banned', 'json')
|
||||
|
||||
if not any([data.get('note'), data.get('reason')]):
|
||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||
|
||||
ban = conn.update_domain_ban(**data)
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if conn.get_domain_ban(data['domain']) is None:
|
||||
return Response.new_error(404, 'Domain not banned', 'json')
|
||||
|
||||
ban = conn.update_domain_ban(
|
||||
domain = data['domain'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -365,7 +383,7 @@ class DomainBan(View):
|
|||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
|
||||
if not conn.get_domain_ban(data['domain']):
|
||||
if conn.get_domain_ban(data['domain']) is None:
|
||||
return Response.new_error(404, 'Domain not banned', 'json')
|
||||
|
||||
conn.del_domain_ban(data['domain'])
|
||||
|
@ -377,7 +395,7 @@ class DomainBan(View):
|
|||
class SoftwareBan(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
bans = tuple(conn.execute('SELECT * FROM software_bans').all())
|
||||
bans = tuple(conn.get_software_bans())
|
||||
|
||||
return Response.new(bans, ctype = 'json')
|
||||
|
||||
|
@ -389,10 +407,14 @@ class SoftwareBan(View):
|
|||
return data
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_software_ban(data['name']):
|
||||
if conn.get_software_ban(data['name']) is not None:
|
||||
return Response.new_error(400, 'Domain already banned', 'json')
|
||||
|
||||
ban = conn.put_software_ban(**data)
|
||||
ban = conn.put_software_ban(
|
||||
name = data['name'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -403,14 +425,18 @@ class SoftwareBan(View):
|
|||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
if not any([data.get('note'), data.get('reason')]):
|
||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||
|
||||
with self.database.session() as conn:
|
||||
if not conn.get_software_ban(data['name']):
|
||||
if conn.get_software_ban(data['name']) is None:
|
||||
return Response.new_error(404, 'Software not banned', 'json')
|
||||
|
||||
if not any([data.get('note'), data.get('reason')]):
|
||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||
|
||||
ban = conn.update_software_ban(**data)
|
||||
ban = conn.update_software_ban(
|
||||
name = data['name'],
|
||||
reason = data.get('reason'),
|
||||
note = data.get('note')
|
||||
)
|
||||
|
||||
return Response.new(ban, ctype = 'json')
|
||||
|
||||
|
@ -422,7 +448,7 @@ class SoftwareBan(View):
|
|||
return data
|
||||
|
||||
with self.database.session() as conn:
|
||||
if not conn.get_software_ban(data['name']):
|
||||
if conn.get_software_ban(data['name']) is None:
|
||||
return Response.new_error(404, 'Software not banned', 'json')
|
||||
|
||||
conn.del_software_ban(data['name'])
|
||||
|
@ -436,7 +462,7 @@ class User(View):
|
|||
with self.database.session() as conn:
|
||||
items = []
|
||||
|
||||
for row in conn.execute('SELECT * FROM users'):
|
||||
for row in conn.get_users():
|
||||
del row['hash']
|
||||
items.append(row)
|
||||
|
||||
|
@ -450,12 +476,16 @@ class User(View):
|
|||
return data
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_user(data['username']):
|
||||
if conn.get_user(data['username']) is not None:
|
||||
return Response.new_error(404, 'User already exists', 'json')
|
||||
|
||||
user = conn.put_user(**data)
|
||||
del user['hash']
|
||||
user = conn.put_user(
|
||||
username = data['username'],
|
||||
password = data['password'],
|
||||
handle = data.get('handle')
|
||||
)
|
||||
|
||||
del user['hash']
|
||||
return Response.new(user, ctype = 'json')
|
||||
|
||||
|
||||
|
@ -466,9 +496,13 @@ class User(View):
|
|||
return data
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
user = conn.put_user(**data)
|
||||
del user['hash']
|
||||
user = conn.put_user(
|
||||
username = data['username'],
|
||||
password = data['password'],
|
||||
handle = data.get('handle')
|
||||
)
|
||||
|
||||
del user['hash']
|
||||
return Response.new(user, ctype = 'json')
|
||||
|
||||
|
||||
|
@ -479,7 +513,7 @@ class User(View):
|
|||
return data
|
||||
|
||||
with self.database.session(True) as conn:
|
||||
if not conn.get_user(data['username']):
|
||||
if conn.get_user(data['username']) is None:
|
||||
return Response.new_error(404, 'User does not exist', 'json')
|
||||
|
||||
conn.del_user(data['username'])
|
||||
|
@ -491,7 +525,7 @@ class User(View):
|
|||
class Whitelist(View):
|
||||
async def get(self, request: Request) -> Response:
|
||||
with self.database.session() as conn:
|
||||
items = tuple(conn.execute('SELECT * FROM whitelist').all())
|
||||
items = tuple(conn.get_domains_whitelist())
|
||||
|
||||
return Response.new(items, ctype = 'json')
|
||||
|
||||
|
@ -502,13 +536,13 @@ class Whitelist(View):
|
|||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
domain = data['domain'].encode('idna').decode()
|
||||
|
||||
with self.database.session() as conn:
|
||||
if conn.get_domain_whitelist(data['domain']):
|
||||
if conn.get_domain_whitelist(domain) is not None:
|
||||
return Response.new_error(400, 'Domain already added to whitelist', 'json')
|
||||
|
||||
item = conn.put_domain_whitelist(**data)
|
||||
item = conn.put_domain_whitelist(domain)
|
||||
|
||||
return Response.new(item, ctype = 'json')
|
||||
|
||||
|
@ -519,12 +553,12 @@ class Whitelist(View):
|
|||
if isinstance(data, Response):
|
||||
return data
|
||||
|
||||
data['domain'] = data['domain'].encode('idna').decode()
|
||||
domain = data['domain'].encode('idna').decode()
|
||||
|
||||
with self.database.session() as conn:
|
||||
if not conn.get_domain_whitelist(data['domain']):
|
||||
if conn.get_domain_whitelist(domain) is None:
|
||||
return Response.new_error(404, 'Domain not in whitelist', 'json')
|
||||
|
||||
conn.del_domain_whitelist(data['domain'])
|
||||
conn.del_domain_whitelist(domain)
|
||||
|
||||
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')
|
||||
|
|
|
@ -202,7 +202,7 @@ class AdminConfig(View):
|
|||
'message': message,
|
||||
'desc': {
|
||||
"name": "Name of the relay to be displayed in the header of the pages and in " +
|
||||
"the actor endpoint.",
|
||||
"the actor endpoint.", # noqa: E131
|
||||
"note": "Description of the relay to be displayed on the front page and as the " +
|
||||
"bio in the actor endpoint.",
|
||||
"theme": "Color theme to use on the web pages.",
|
||||
|
|
|
@ -4,7 +4,6 @@ import typing
|
|||
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||
from bsql import Row
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Event, Process, Queue, Value
|
||||
from multiprocessing.synchronize import Event as EventType
|
||||
|
@ -13,6 +12,7 @@ from queue import Empty, Queue as QueueType
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from . import application, logger as logging
|
||||
from .database.schema import Instance
|
||||
from .http_client import HttpClient
|
||||
from .misc import IS_WINDOWS, Message, get_app
|
||||
|
||||
|
@ -29,7 +29,7 @@ class QueueItem:
|
|||
class PostItem(QueueItem):
|
||||
inbox: str
|
||||
message: Message
|
||||
instance: Row | None
|
||||
instance: Instance | None
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
|
@ -122,7 +122,7 @@ class PushWorkers(list[PushWorker]):
|
|||
self.queue.put(item)
|
||||
|
||||
|
||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
||||
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||
self.queue.put(PostItem(inbox, message, instance))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue