mirror of
https://git.pleroma.social/pleroma/relay.git
synced 2024-11-22 14:38:00 +00:00
update barkshark-sql to 0.2.0-rc1 and create row classes
This commit is contained in:
parent
45b0de26c7
commit
bdc7d41d7a
|
@ -20,7 +20,7 @@ dependencies = [
|
||||||
"aiohttp >= 3.9.5",
|
"aiohttp >= 3.9.5",
|
||||||
"aiohttp-swagger[performance] == 1.0.16",
|
"aiohttp-swagger[performance] == 1.0.16",
|
||||||
"argon2-cffi == 23.1.0",
|
"argon2-cffi == 23.1.0",
|
||||||
"barkshark-lib >= 0.1.3-1",
|
"barkshark-lib >= 0.2.0-rc1",
|
||||||
"barkshark-sql == 0.1.4-1",
|
"barkshark-sql == 0.1.4-1",
|
||||||
"click >= 8.1.2",
|
"click >= 8.1.2",
|
||||||
"hiredis == 2.3.2",
|
"hiredis == 2.3.2",
|
||||||
|
@ -104,7 +104,3 @@ implicit_reexport = true
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = "blib"
|
module = "blib"
|
||||||
implicit_reexport = true
|
implicit_reexport = true
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "bsql"
|
|
||||||
implicit_reexport = true
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
__version__ = '0.3.2'
|
__version__ = '0.3.3'
|
||||||
|
|
|
@ -4,30 +4,26 @@ import asyncio
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import signal
|
import signal
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
|
||||||
from aiohttp.web import StaticResource
|
from aiohttp.web import StaticResource
|
||||||
from aiohttp_swagger import setup_swagger
|
from aiohttp_swagger import setup_swagger
|
||||||
from aputils.signer import Signer
|
from aputils.signer import Signer
|
||||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
from bsql import Database
|
||||||
from bsql import Database, Row
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty
|
|
||||||
from threading import Event, Thread
|
from threading import Event, Thread
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from . import logger as logging, workers
|
from . import logger as logging, workers
|
||||||
from .cache import Cache, get_cache
|
from .cache import Cache, get_cache
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .database import Connection, get_database
|
from .database import Connection, get_database
|
||||||
|
from .database.schema import Instance
|
||||||
from .http_client import HttpClient
|
from .http_client import HttpClient
|
||||||
from .misc import IS_WINDOWS, Message, Response, check_open_port, get_resource
|
from .misc import Message, Response, check_open_port, get_resource
|
||||||
from .template import Template
|
from .template import Template
|
||||||
from .views import VIEWS
|
from .views import VIEWS
|
||||||
from .views.api import handle_api_path
|
from .views.api import handle_api_path
|
||||||
|
@ -142,7 +138,7 @@ class Application(web.Application):
|
||||||
return timedelta(seconds=uptime.seconds)
|
return timedelta(seconds=uptime.seconds)
|
||||||
|
|
||||||
|
|
||||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||||
self['workers'].push_message(inbox, message, instance)
|
self['workers'].push_message(inbox, message, instance)
|
||||||
|
|
||||||
|
|
||||||
|
@ -286,67 +282,6 @@ class CacheCleanupThread(Thread):
|
||||||
self.running.clear()
|
self.running.clear()
|
||||||
|
|
||||||
|
|
||||||
class PushWorker(multiprocessing.Process):
|
|
||||||
def __init__(self, queue: multiprocessing.Queue[tuple[str, Message, Row]]) -> None:
|
|
||||||
if Application.DEFAULT is None:
|
|
||||||
raise RuntimeError('Application not setup yet')
|
|
||||||
|
|
||||||
multiprocessing.Process.__init__(self)
|
|
||||||
|
|
||||||
self.queue = queue
|
|
||||||
self.shutdown = multiprocessing.Event()
|
|
||||||
self.path = Application.DEFAULT.config.path
|
|
||||||
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
self.shutdown.set()
|
|
||||||
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
asyncio.run(self.handle_queue())
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_queue(self) -> None:
|
|
||||||
if IS_WINDOWS:
|
|
||||||
app = Application(self.path)
|
|
||||||
client = app.client
|
|
||||||
|
|
||||||
client.open()
|
|
||||||
app.database.connect()
|
|
||||||
app.cache.setup()
|
|
||||||
|
|
||||||
else:
|
|
||||||
client = HttpClient()
|
|
||||||
client.open()
|
|
||||||
|
|
||||||
while not self.shutdown.is_set():
|
|
||||||
try:
|
|
||||||
inbox, message, instance = self.queue.get(block=True, timeout=0.1)
|
|
||||||
asyncio.create_task(client.post(inbox, message, instance))
|
|
||||||
|
|
||||||
except Empty:
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
except ClientSSLError as e:
|
|
||||||
logging.error('SSL error when pushing to %s: %s', urlparse(inbox).netloc, str(e))
|
|
||||||
|
|
||||||
except (AsyncTimeoutError, ClientConnectionError) as e:
|
|
||||||
logging.error(
|
|
||||||
'Failed to connect to %s for message push: %s',
|
|
||||||
urlparse(inbox).netloc, str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
# make sure an exception doesn't bring down the worker
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if IS_WINDOWS:
|
|
||||||
app.database.disconnect()
|
|
||||||
app.cache.close()
|
|
||||||
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def handle_response_headers(
|
async def handle_response_headers(
|
||||||
request: web.Request,
|
request: web.Request,
|
||||||
|
|
|
@ -4,7 +4,7 @@ import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from bsql import Database
|
from bsql import Database, Row
|
||||||
from collections.abc import Callable, Iterator
|
from collections.abc import Callable, Iterator
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
@ -172,7 +172,7 @@ class SqlCache(Cache):
|
||||||
|
|
||||||
with self._db.session(False) as conn:
|
with self._db.session(False) as conn:
|
||||||
with conn.run('get-cache-item', params) as cur:
|
with conn.run('get-cache-item', params) as cur:
|
||||||
if not (row := cur.one()):
|
if not (row := cur.one(Row)):
|
||||||
raise KeyError(f'{namespace}:{key}')
|
raise KeyError(f'{namespace}:{key}')
|
||||||
|
|
||||||
row.pop('id', None)
|
row.pop('id', None)
|
||||||
|
@ -211,9 +211,11 @@ class SqlCache(Cache):
|
||||||
|
|
||||||
with self._db.session(True) as conn:
|
with self._db.session(True) as conn:
|
||||||
with conn.run('set-cache-item', params) as cur:
|
with conn.run('set-cache-item', params) as cur:
|
||||||
row = cur.one()
|
if (row := cur.one(Row)) is None:
|
||||||
row.pop('id', None) # type: ignore[union-attr]
|
raise RuntimeError("Cache item not set")
|
||||||
return Item.from_data(*tuple(row.values())) # type: ignore[union-attr]
|
|
||||||
|
row.pop('id', None)
|
||||||
|
return Item.from_data(*tuple(row.values()))
|
||||||
|
|
||||||
|
|
||||||
def delete(self, namespace: str, key: str) -> None:
|
def delete(self, namespace: str, key: str) -> None:
|
||||||
|
|
|
@ -16,6 +16,8 @@ def get_database(config: Config, migrate: bool = True) -> Database[Connection]:
|
||||||
'tables': TABLES
|
'tables': TABLES
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db: Database[Connection]
|
||||||
|
|
||||||
if config.db_type == 'sqlite':
|
if config.db_type == 'sqlite':
|
||||||
db = Database.sqlite(config.sqlite_path, **options)
|
db = Database.sqlite(config.sqlite_path, **options)
|
||||||
|
|
||||||
|
|
|
@ -2,12 +2,13 @@ from __future__ import annotations
|
||||||
|
|
||||||
from argon2 import PasswordHasher
|
from argon2 import PasswordHasher
|
||||||
from bsql import Connection as SqlConnection, Row, Update
|
from bsql import Connection as SqlConnection, Row, Update
|
||||||
from collections.abc import Iterator, Sequence
|
from collections.abc import Iterator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from . import schema
|
||||||
from .config import (
|
from .config import (
|
||||||
THEMES,
|
THEMES,
|
||||||
ConfigData
|
ConfigData
|
||||||
|
@ -37,14 +38,14 @@ class Connection(SqlConnection):
|
||||||
return get_app()
|
return get_app()
|
||||||
|
|
||||||
|
|
||||||
def distill_inboxes(self, message: Message) -> Iterator[Row]:
|
def distill_inboxes(self, message: Message) -> Iterator[schema.Instance]:
|
||||||
src_domains = {
|
src_domains = {
|
||||||
message.domain,
|
message.domain,
|
||||||
urlparse(message.object_id).netloc
|
urlparse(message.object_id).netloc
|
||||||
}
|
}
|
||||||
|
|
||||||
for instance in self.get_inboxes():
|
for instance in self.get_inboxes():
|
||||||
if instance['domain'] not in src_domains:
|
if instance.domain not in src_domains:
|
||||||
yield instance
|
yield instance
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ class Connection(SqlConnection):
|
||||||
key = key.replace('_', '-')
|
key = key.replace('_', '-')
|
||||||
|
|
||||||
with self.run('get-config', {'key': key}) as cur:
|
with self.run('get-config', {'key': key}) as cur:
|
||||||
if not (row := cur.one()):
|
if (row := cur.one(Row)) is None:
|
||||||
return ConfigData.DEFAULT(key)
|
return ConfigData.DEFAULT(key)
|
||||||
|
|
||||||
data = ConfigData()
|
data = ConfigData()
|
||||||
|
@ -61,8 +62,8 @@ class Connection(SqlConnection):
|
||||||
|
|
||||||
|
|
||||||
def get_config_all(self) -> ConfigData:
|
def get_config_all(self) -> ConfigData:
|
||||||
with self.run('get-config-all', None) as cur:
|
rows = tuple(self.run('get-config-all', None).all(schema.Row))
|
||||||
return ConfigData.from_rows(tuple(cur.all()))
|
return ConfigData.from_rows(rows)
|
||||||
|
|
||||||
|
|
||||||
def put_config(self, key: str, value: Any) -> Any:
|
def put_config(self, key: str, value: Any) -> Any:
|
||||||
|
@ -99,14 +100,13 @@ class Connection(SqlConnection):
|
||||||
return data.get(key)
|
return data.get(key)
|
||||||
|
|
||||||
|
|
||||||
def get_inbox(self, value: str) -> Row:
|
def get_inbox(self, value: str) -> schema.Instance | None:
|
||||||
with self.run('get-inbox', {'value': value}) as cur:
|
with self.run('get-inbox', {'value': value}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.Instance)
|
||||||
|
|
||||||
|
|
||||||
def get_inboxes(self) -> Sequence[Row]:
|
def get_inboxes(self) -> Iterator[schema.Instance]:
|
||||||
with self.execute("SELECT * FROM inboxes WHERE accepted = 1") as cur:
|
return self.execute("SELECT * FROM inboxes WHERE accepted = 1").all(schema.Instance)
|
||||||
return tuple(cur.all())
|
|
||||||
|
|
||||||
|
|
||||||
def put_inbox(self,
|
def put_inbox(self,
|
||||||
|
@ -115,7 +115,7 @@ class Connection(SqlConnection):
|
||||||
actor: str | None = None,
|
actor: str | None = None,
|
||||||
followid: str | None = None,
|
followid: str | None = None,
|
||||||
software: str | None = None,
|
software: str | None = None,
|
||||||
accepted: bool = True) -> Row:
|
accepted: bool = True) -> schema.Instance:
|
||||||
|
|
||||||
params: dict[str, Any] = {
|
params: dict[str, Any] = {
|
||||||
'inbox': inbox,
|
'inbox': inbox,
|
||||||
|
@ -125,7 +125,7 @@ class Connection(SqlConnection):
|
||||||
'accepted': accepted
|
'accepted': accepted
|
||||||
}
|
}
|
||||||
|
|
||||||
if not self.get_inbox(domain):
|
if self.get_inbox(domain) is None:
|
||||||
if not inbox:
|
if not inbox:
|
||||||
raise ValueError("Missing inbox")
|
raise ValueError("Missing inbox")
|
||||||
|
|
||||||
|
@ -133,14 +133,20 @@ class Connection(SqlConnection):
|
||||||
params['created'] = datetime.now(tz = timezone.utc)
|
params['created'] = datetime.now(tz = timezone.utc)
|
||||||
|
|
||||||
with self.run('put-inbox', params) as cur:
|
with self.run('put-inbox', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Instance)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert instance: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
for key, value in tuple(params.items()):
|
for key, value in tuple(params.items()):
|
||||||
if value is None:
|
if value is None:
|
||||||
del params[key]
|
del params[key]
|
||||||
|
|
||||||
with self.update('inboxes', params, domain = domain) as cur:
|
with self.update('inboxes', params, domain = domain) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Instance)) is None:
|
||||||
|
raise RuntimeError(f"Failed to update instance: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_inbox(self, value: str) -> bool:
|
def del_inbox(self, value: str) -> bool:
|
||||||
|
@ -151,24 +157,23 @@ class Connection(SqlConnection):
|
||||||
return cur.row_count == 1
|
return cur.row_count == 1
|
||||||
|
|
||||||
|
|
||||||
def get_request(self, domain: str) -> Row:
|
def get_request(self, domain: str) -> schema.Instance | None:
|
||||||
with self.run('get-request', {'domain': domain}) as cur:
|
with self.run('get-request', {'domain': domain}) as cur:
|
||||||
if not (row := cur.one()):
|
return cur.one(schema.Instance)
|
||||||
raise KeyError(domain)
|
|
||||||
|
|
||||||
return row
|
|
||||||
|
|
||||||
|
|
||||||
def get_requests(self) -> Sequence[Row]:
|
def get_requests(self) -> Iterator[schema.Instance]:
|
||||||
with self.execute('SELECT * FROM inboxes WHERE accepted = 0') as cur:
|
return self.execute('SELECT * FROM inboxes WHERE accepted = 0').all(schema.Instance)
|
||||||
return tuple(cur.all())
|
|
||||||
|
|
||||||
|
|
||||||
def put_request_response(self, domain: str, accepted: bool) -> Row:
|
def put_request_response(self, domain: str, accepted: bool) -> schema.Instance:
|
||||||
instance = self.get_request(domain)
|
if (instance := self.get_request(domain)) is None:
|
||||||
|
raise KeyError(domain)
|
||||||
|
|
||||||
if not accepted:
|
if not accepted:
|
||||||
self.del_inbox(domain)
|
if not self.del_inbox(domain):
|
||||||
|
raise RuntimeError(f'Failed to delete request: {domain}')
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
|
@ -177,21 +182,28 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-inbox-accept', params) as cur:
|
with self.run('put-inbox-accept', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Instance)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert response for domain: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def get_user(self, value: str) -> Row:
|
def get_user(self, value: str) -> schema.User | None:
|
||||||
with self.run('get-user', {'value': value}) as cur:
|
with self.run('get-user', {'value': value}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.User)
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_token(self, code: str) -> Row:
|
def get_user_by_token(self, code: str) -> schema.User | None:
|
||||||
with self.run('get-user-by-token', {'code': code}) as cur:
|
with self.run('get-user-by-token', {'code': code}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.User)
|
||||||
|
|
||||||
|
|
||||||
def put_user(self, username: str, password: str | None, handle: str | None = None) -> Row:
|
def get_users(self) -> Iterator[schema.User]:
|
||||||
if self.get_user(username):
|
return self.execute("SELECT * FROM users").all(schema.User)
|
||||||
|
|
||||||
|
|
||||||
|
def put_user(self, username: str, password: str | None, handle: str | None = None) -> schema.User:
|
||||||
|
if self.get_user(username) is not None:
|
||||||
data: dict[str, str | datetime | None] = {}
|
data: dict[str, str | datetime | None] = {}
|
||||||
|
|
||||||
if password:
|
if password:
|
||||||
|
@ -204,7 +216,10 @@ class Connection(SqlConnection):
|
||||||
stmt.set_where("username", username)
|
stmt.set_where("username", username)
|
||||||
|
|
||||||
with self.query(stmt) as cur:
|
with self.query(stmt) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.User)) is None:
|
||||||
|
raise RuntimeError(f"Failed to update user: {username}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
if password is None:
|
if password is None:
|
||||||
raise ValueError('Password cannot be empty')
|
raise ValueError('Password cannot be empty')
|
||||||
|
@ -217,25 +232,36 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-user', data) as cur:
|
with self.run('put-user', data) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.User)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert user: {username}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_user(self, username: str) -> None:
|
def del_user(self, username: str) -> None:
|
||||||
user = self.get_user(username)
|
if (user := self.get_user(username)) is None:
|
||||||
|
raise KeyError(username)
|
||||||
|
|
||||||
with self.run('del-user', {'value': user['username']}):
|
with self.run('del-user', {'value': user.username}):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
with self.run('del-token-user', {'username': user['username']}):
|
with self.run('del-token-user', {'username': user.username}):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_token(self, code: str) -> Row:
|
def get_token(self, code: str) -> schema.Token | None:
|
||||||
with self.run('get-token', {'code': code}) as cur:
|
with self.run('get-token', {'code': code}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.Token)
|
||||||
|
|
||||||
|
|
||||||
def put_token(self, username: str) -> Row:
|
def get_tokens(self, username: str | None = None) -> Iterator[schema.Token]:
|
||||||
|
if username is not None:
|
||||||
|
return self.select('tokens').all(schema.Token)
|
||||||
|
|
||||||
|
return self.select('tokens', username = username).all(schema.Token)
|
||||||
|
|
||||||
|
|
||||||
|
def put_token(self, username: str) -> schema.Token:
|
||||||
data = {
|
data = {
|
||||||
'code': uuid4().hex,
|
'code': uuid4().hex,
|
||||||
'user': username,
|
'user': username,
|
||||||
|
@ -243,7 +269,10 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-token', data) as cur:
|
with self.run('put-token', data) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Token)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert token for user: {username}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_token(self, code: str) -> None:
|
def del_token(self, code: str) -> None:
|
||||||
|
@ -251,18 +280,22 @@ class Connection(SqlConnection):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_domain_ban(self, domain: str) -> Row:
|
def get_domain_ban(self, domain: str) -> schema.DomainBan | None:
|
||||||
if domain.startswith('http'):
|
if domain.startswith('http'):
|
||||||
domain = urlparse(domain).netloc
|
domain = urlparse(domain).netloc
|
||||||
|
|
||||||
with self.run('get-domain-ban', {'domain': domain}) as cur:
|
with self.run('get-domain-ban', {'domain': domain}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.DomainBan)
|
||||||
|
|
||||||
|
|
||||||
|
def get_domain_bans(self) -> Iterator[schema.DomainBan]:
|
||||||
|
return self.execute("SELECT * FROM domain_bans").all(schema.DomainBan)
|
||||||
|
|
||||||
|
|
||||||
def put_domain_ban(self,
|
def put_domain_ban(self,
|
||||||
domain: str,
|
domain: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.DomainBan:
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'domain': domain,
|
'domain': domain,
|
||||||
|
@ -272,13 +305,16 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-domain-ban', params) as cur:
|
with self.run('put-domain-ban', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.DomainBan)) is None:
|
||||||
|
raise RuntimeError(f"Failed to insert domain ban: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def update_domain_ban(self,
|
def update_domain_ban(self,
|
||||||
domain: str,
|
domain: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.DomainBan:
|
||||||
|
|
||||||
if not (reason or note):
|
if not (reason or note):
|
||||||
raise ValueError('"reason" and/or "note" must be specified')
|
raise ValueError('"reason" and/or "note" must be specified')
|
||||||
|
@ -298,7 +334,10 @@ class Connection(SqlConnection):
|
||||||
if cur.row_count > 1:
|
if cur.row_count > 1:
|
||||||
raise ValueError('More than one row was modified')
|
raise ValueError('More than one row was modified')
|
||||||
|
|
||||||
return self.get_domain_ban(domain)
|
if (row := cur.one(schema.DomainBan)) is None:
|
||||||
|
raise RuntimeError(f"Failed to update domain ban: {domain}")
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_domain_ban(self, domain: str) -> bool:
|
def del_domain_ban(self, domain: str) -> bool:
|
||||||
|
@ -309,15 +348,19 @@ class Connection(SqlConnection):
|
||||||
return cur.row_count == 1
|
return cur.row_count == 1
|
||||||
|
|
||||||
|
|
||||||
def get_software_ban(self, name: str) -> Row:
|
def get_software_ban(self, name: str) -> schema.SoftwareBan | None:
|
||||||
with self.run('get-software-ban', {'name': name}) as cur:
|
with self.run('get-software-ban', {'name': name}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one(schema.SoftwareBan)
|
||||||
|
|
||||||
|
|
||||||
|
def get_software_bans(self) -> Iterator[schema.SoftwareBan,]:
|
||||||
|
return self.execute('SELECT * FROM software_bans').all(schema.SoftwareBan)
|
||||||
|
|
||||||
|
|
||||||
def put_software_ban(self,
|
def put_software_ban(self,
|
||||||
name: str,
|
name: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.SoftwareBan:
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'name': name,
|
'name': name,
|
||||||
|
@ -327,13 +370,16 @@ class Connection(SqlConnection):
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-software-ban', params) as cur:
|
with self.run('put-software-ban', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||||
|
raise RuntimeError(f'Failed to insert software ban: {name}')
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def update_software_ban(self,
|
def update_software_ban(self,
|
||||||
name: str,
|
name: str,
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
note: str | None = None) -> Row:
|
note: str | None = None) -> schema.SoftwareBan:
|
||||||
|
|
||||||
if not (reason or note):
|
if not (reason or note):
|
||||||
raise ValueError('"reason" and/or "note" must be specified')
|
raise ValueError('"reason" and/or "note" must be specified')
|
||||||
|
@ -353,7 +399,10 @@ class Connection(SqlConnection):
|
||||||
if cur.row_count > 1:
|
if cur.row_count > 1:
|
||||||
raise ValueError('More than one row was modified')
|
raise ValueError('More than one row was modified')
|
||||||
|
|
||||||
return self.get_software_ban(name)
|
if (row := cur.one(schema.SoftwareBan)) is None:
|
||||||
|
raise RuntimeError(f'Failed to update software ban: {name}')
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_software_ban(self, name: str) -> bool:
|
def del_software_ban(self, name: str) -> bool:
|
||||||
|
@ -364,19 +413,26 @@ class Connection(SqlConnection):
|
||||||
return cur.row_count == 1
|
return cur.row_count == 1
|
||||||
|
|
||||||
|
|
||||||
def get_domain_whitelist(self, domain: str) -> Row:
|
def get_domain_whitelist(self, domain: str) -> schema.Whitelist | None:
|
||||||
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
|
with self.run('get-domain-whitelist', {'domain': domain}) as cur:
|
||||||
return cur.one() # type: ignore
|
return cur.one()
|
||||||
|
|
||||||
|
|
||||||
def put_domain_whitelist(self, domain: str) -> Row:
|
def get_domains_whitelist(self) -> Iterator[schema.Whitelist,]:
|
||||||
|
return self.execute("SELECT * FROM whitelist").all(schema.Whitelist)
|
||||||
|
|
||||||
|
|
||||||
|
def put_domain_whitelist(self, domain: str) -> schema.Whitelist:
|
||||||
params = {
|
params = {
|
||||||
'domain': domain,
|
'domain': domain,
|
||||||
'created': datetime.now(tz = timezone.utc)
|
'created': datetime.now(tz = timezone.utc)
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.run('put-domain-whitelist', params) as cur:
|
with self.run('put-domain-whitelist', params) as cur:
|
||||||
return cur.one() # type: ignore
|
if (row := cur.one(schema.Whitelist)) is None:
|
||||||
|
raise RuntimeError(f'Failed to insert whitelisted domain: {domain}')
|
||||||
|
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
def del_domain_whitelist(self, domain: str) -> bool:
|
def del_domain_whitelist(self, domain: str) -> bool:
|
||||||
|
|
|
@ -1,61 +1,88 @@
|
||||||
from bsql import Column, Table, Tables
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from bsql import Column, Row, Tables
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from .config import ConfigData
|
from .config import ConfigData
|
||||||
from .connection import Connection
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from .connection import Connection
|
||||||
|
|
||||||
|
|
||||||
VERSIONS: dict[int, Callable[[Connection], None]] = {}
|
VERSIONS: dict[int, Callable[[Connection], None]] = {}
|
||||||
TABLES: Tables = Tables(
|
TABLES = Tables()
|
||||||
Table(
|
|
||||||
'config',
|
|
||||||
Column('key', 'text', primary_key = True, unique = True, nullable = False),
|
@TABLES.add_row
|
||||||
Column('value', 'text'),
|
class Config(Row):
|
||||||
Column('type', 'text', default = 'str')
|
key: Column[str] = Column('key', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
),
|
value: Column[str] = Column('value', 'text')
|
||||||
Table(
|
type: Column[str] = Column('type', 'text', default = 'str')
|
||||||
'inboxes',
|
|
||||||
Column('domain', 'text', primary_key = True, unique = True, nullable = False),
|
|
||||||
Column('actor', 'text', unique = True),
|
@TABLES.add_row
|
||||||
Column('inbox', 'text', unique = True, nullable = False),
|
class Instance(Row):
|
||||||
Column('followid', 'text'),
|
table_name: str = 'inboxes'
|
||||||
Column('software', 'text'),
|
|
||||||
Column('accepted', 'boolean'),
|
domain: Column[str] = Column(
|
||||||
Column('created', 'timestamp', nullable = False)
|
'domain', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
),
|
actor: Column[str] = Column('actor', 'text', unique = True)
|
||||||
Table(
|
inbox: Column[str] = Column('inbox', 'text', unique = True, nullable = False)
|
||||||
'whitelist',
|
followid: Column[str] = Column('followid', 'text')
|
||||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
software: Column[str] = Column('software', 'text')
|
||||||
Column('created', 'timestamp')
|
accepted: Column[datetime] = Column('accepted', 'boolean')
|
||||||
),
|
created: Column[datetime] = Column('created', 'timestamp', nullable = False)
|
||||||
Table(
|
|
||||||
'domain_bans',
|
|
||||||
Column('domain', 'text', primary_key = True, unique = True, nullable = True),
|
@TABLES.add_row
|
||||||
Column('reason', 'text'),
|
class Whitelist(Row):
|
||||||
Column('note', 'text'),
|
domain: Column[str] = Column(
|
||||||
Column('created', 'timestamp', nullable = False)
|
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||||
),
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
Table(
|
|
||||||
'software_bans',
|
|
||||||
Column('name', 'text', primary_key = True, unique = True, nullable = True),
|
@TABLES.add_row
|
||||||
Column('reason', 'text'),
|
class DomainBan(Row):
|
||||||
Column('note', 'text'),
|
table_name: str = 'domain_bans'
|
||||||
Column('created', 'timestamp', nullable = False)
|
|
||||||
),
|
domain: Column[str] = Column(
|
||||||
Table(
|
'domain', 'text', primary_key = True, unique = True, nullable = True)
|
||||||
'users',
|
reason: Column[str] = Column('reason', 'text')
|
||||||
Column('username', 'text', primary_key = True, unique = True, nullable = False),
|
note: Column[str] = Column('note', 'text')
|
||||||
Column('hash', 'text', nullable = False),
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
Column('handle', 'text'),
|
|
||||||
Column('created', 'timestamp', nullable = False)
|
|
||||||
),
|
@TABLES.add_row
|
||||||
Table(
|
class SoftwareBan(Row):
|
||||||
'tokens',
|
table_name: str = 'software_bans'
|
||||||
Column('code', 'text', primary_key = True, unique = True, nullable = False),
|
|
||||||
Column('user', 'text', nullable = False),
|
name: Column[str] = Column('name', 'text', primary_key = True, unique = True, nullable = True)
|
||||||
Column('created', 'timestmap', nullable = False)
|
reason: Column[str] = Column('reason', 'text')
|
||||||
)
|
note: Column[str] = Column('note', 'text')
|
||||||
)
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class User(Row):
|
||||||
|
table_name: str = 'users'
|
||||||
|
|
||||||
|
username: Column[str] = Column(
|
||||||
|
'username', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
|
hash: Column[str] = Column('hash', 'text', nullable = False)
|
||||||
|
handle: Column[str] = Column('handle', 'text')
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
|
@TABLES.add_row
|
||||||
|
class Token(Row):
|
||||||
|
table_name: str = 'tokens'
|
||||||
|
|
||||||
|
code: Column[str] = Column('code', 'text', primary_key = True, unique = True, nullable = False)
|
||||||
|
user: Column[str] = Column('user', 'text', nullable = False)
|
||||||
|
created: Column[datetime] = Column('created', 'timestamp')
|
||||||
|
|
||||||
|
|
||||||
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
|
def migration(func: Callable[[Connection], None]) -> Callable[[Connection], None]:
|
||||||
|
|
|
@ -5,11 +5,11 @@ import json
|
||||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||||
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
|
from aputils import AlgorithmType, Nodeinfo, ObjectType, Signer, WellKnownNodeinfo
|
||||||
from blib import JsonBase
|
from blib import JsonBase
|
||||||
from bsql import Row
|
|
||||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||||
|
|
||||||
from . import __version__, logger as logging
|
from . import __version__, logger as logging
|
||||||
from .cache import Cache
|
from .cache import Cache
|
||||||
|
from .database.schema import Instance
|
||||||
from .misc import MIMETYPES, Message, get_app
|
from .misc import MIMETYPES, Message, get_app
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -184,12 +184,12 @@ class HttpClient:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def post(self, url: str, data: Message | bytes, instance: Row | None = None) -> None:
|
async def post(self, url: str, data: Message | bytes, instance: Instance | None = None) -> None:
|
||||||
if not self._session:
|
if not self._session:
|
||||||
raise RuntimeError('Client not open')
|
raise RuntimeError('Client not open')
|
||||||
|
|
||||||
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
|
# akkoma and pleroma do not support HS2019 and other software still needs to be tested
|
||||||
if instance and instance['software'] in SUPPORTS_HS2019:
|
if instance is not None and instance.software in SUPPORTS_HS2019:
|
||||||
algorithm = AlgorithmType.HS2019
|
algorithm = AlgorithmType.HS2019
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
131
relay/manage.py
131
relay/manage.py
|
@ -6,7 +6,6 @@ import click
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from bsql import Row
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
@ -17,7 +16,7 @@ from . import http_client as http
|
||||||
from . import logger as logging
|
from . import logger as logging
|
||||||
from .application import Application
|
from .application import Application
|
||||||
from .compat import RelayConfig, RelayDatabase
|
from .compat import RelayConfig, RelayDatabase
|
||||||
from .database import RELAY_SOFTWARE, get_database
|
from .database import RELAY_SOFTWARE, get_database, schema
|
||||||
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
|
from .misc import ACTOR_FORMATS, SOFTWARE, IS_DOCKER, Message
|
||||||
|
|
||||||
|
|
||||||
|
@ -367,8 +366,8 @@ def cli_user_list(ctx: click.Context) -> None:
|
||||||
click.echo('Users:')
|
click.echo('Users:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for user in conn.execute('SELECT * FROM users'):
|
for row in conn.get_users():
|
||||||
click.echo(f'- {user["username"]}')
|
click.echo(f'- {row.username}')
|
||||||
|
|
||||||
|
|
||||||
@cli_user.command('create')
|
@cli_user.command('create')
|
||||||
|
@ -379,7 +378,7 @@ def cli_user_create(ctx: click.Context, username: str, handle: str) -> None:
|
||||||
'Create a new local user'
|
'Create a new local user'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_user(username):
|
if conn.get_user(username) is not None:
|
||||||
click.echo(f'User already exists: {username}')
|
click.echo(f'User already exists: {username}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -406,7 +405,7 @@ def cli_user_delete(ctx: click.Context, username: str) -> None:
|
||||||
'Delete a local user'
|
'Delete a local user'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not conn.get_user(username):
|
if conn.get_user(username) is None:
|
||||||
click.echo(f'User does not exist: {username}')
|
click.echo(f'User does not exist: {username}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -424,8 +423,8 @@ def cli_user_list_tokens(ctx: click.Context, username: str) -> None:
|
||||||
click.echo(f'Tokens for "{username}":')
|
click.echo(f'Tokens for "{username}":')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for token in conn.execute('SELECT * FROM tokens WHERE user = :user', {'user': username}):
|
for row in conn.get_tokens(username):
|
||||||
click.echo(f'- {token["code"]}')
|
click.echo(f'- {row.code}')
|
||||||
|
|
||||||
|
|
||||||
@cli_user.command('create-token')
|
@cli_user.command('create-token')
|
||||||
|
@ -435,13 +434,13 @@ def cli_user_create_token(ctx: click.Context, username: str) -> None:
|
||||||
'Create a new API token for a user'
|
'Create a new API token for a user'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not (user := conn.get_user(username)):
|
if (user := conn.get_user(username)) is None:
|
||||||
click.echo(f'User does not exist: {username}')
|
click.echo(f'User does not exist: {username}')
|
||||||
return
|
return
|
||||||
|
|
||||||
token = conn.put_token(user['username'])
|
token = conn.put_token(user.username)
|
||||||
|
|
||||||
click.echo(f'New token for "{username}": {token["code"]}')
|
click.echo(f'New token for "{username}": {token.code}')
|
||||||
|
|
||||||
|
|
||||||
@cli_user.command('delete-token')
|
@cli_user.command('delete-token')
|
||||||
|
@ -451,7 +450,7 @@ def cli_user_delete_token(ctx: click.Context, code: str) -> None:
|
||||||
'Delete an API token'
|
'Delete an API token'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not conn.get_token(code):
|
if conn.get_token(code) is None:
|
||||||
click.echo('Token does not exist')
|
click.echo('Token does not exist')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -473,8 +472,8 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
||||||
click.echo('Connected to the following instances or relays:')
|
click.echo('Connected to the following instances or relays:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for inbox in conn.get_inboxes():
|
for row in conn.get_inboxes():
|
||||||
click.echo(f'- {inbox["inbox"]}')
|
click.echo(f'- {row.inbox}')
|
||||||
|
|
||||||
|
|
||||||
@cli_inbox.command('follow')
|
@cli_inbox.command('follow')
|
||||||
|
@ -483,19 +482,21 @@ def cli_inbox_list(ctx: click.Context) -> None:
|
||||||
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||||
'Follow an actor (Relay must be running)'
|
'Follow an actor (Relay must be running)'
|
||||||
|
|
||||||
|
instance: schema.Instance | None = None
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_domain_ban(actor):
|
if conn.get_domain_ban(actor):
|
||||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||||
return
|
return
|
||||||
|
|
||||||
if (inbox_data := conn.get_inbox(actor)):
|
if (instance := conn.get_inbox(actor)) is not None:
|
||||||
inbox = inbox_data['inbox']
|
inbox = instance.inbox
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if not actor.startswith('http'):
|
if not actor.startswith('http'):
|
||||||
actor = f'https://{actor}/actor'
|
actor = f'https://{actor}/actor'
|
||||||
|
|
||||||
if not (actor_data := asyncio.run(http.get(actor, sign_headers = True))):
|
if (actor_data := asyncio.run(http.get(actor, sign_headers = True))) is None:
|
||||||
click.echo(f'Failed to fetch actor: {actor}')
|
click.echo(f'Failed to fetch actor: {actor}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -506,7 +507,7 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||||
actor = actor
|
actor = actor
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(inbox, message, inbox_data))
|
asyncio.run(http.post(inbox, message, instance))
|
||||||
click.echo(f'Sent follow message to actor: {actor}')
|
click.echo(f'Sent follow message to actor: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
@ -516,19 +517,19 @@ def cli_inbox_follow(ctx: click.Context, actor: str) -> None:
|
||||||
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
||||||
'Unfollow an actor (Relay must be running)'
|
'Unfollow an actor (Relay must be running)'
|
||||||
|
|
||||||
inbox_data: Row | None = None
|
instance: schema.Instance | None = None
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_domain_ban(actor):
|
if conn.get_domain_ban(actor):
|
||||||
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
click.echo(f'Error: Refusing to follow banned actor: {actor}')
|
||||||
return
|
return
|
||||||
|
|
||||||
if (inbox_data := conn.get_inbox(actor)):
|
if (instance := conn.get_inbox(actor)):
|
||||||
inbox = inbox_data['inbox']
|
inbox = instance.inbox
|
||||||
message = Message.new_unfollow(
|
message = Message.new_unfollow(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = actor,
|
actor = actor,
|
||||||
follow = inbox_data['followid']
|
follow = instance.followid
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -552,7 +553,7 @@ def cli_inbox_unfollow(ctx: click.Context, actor: str) -> None:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(inbox, message, inbox_data))
|
asyncio.run(http.post(inbox, message, instance))
|
||||||
click.echo(f'Sent unfollow message to: {actor}')
|
click.echo(f'Sent unfollow message to: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
@ -632,9 +633,9 @@ def cli_request_list(ctx: click.Context) -> None:
|
||||||
click.echo('Follow requests:')
|
click.echo('Follow requests:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for instance in conn.get_requests():
|
for row in conn.get_requests():
|
||||||
date = instance['created'].strftime('%Y-%m-%d')
|
date = row.created.strftime('%Y-%m-%d')
|
||||||
click.echo(f'- [{date}] {instance["domain"]}')
|
click.echo(f'- [{date}] {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli_request.command('accept')
|
@cli_request.command('accept')
|
||||||
|
@ -653,20 +654,20 @@ def cli_request_accept(ctx: click.Context, domain: str) -> None:
|
||||||
|
|
||||||
message = Message.new_response(
|
message = Message.new_response(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = instance['actor'],
|
actor = instance.actor,
|
||||||
followid = instance['followid'],
|
followid = instance.followid,
|
||||||
accept = True
|
accept = True
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
asyncio.run(http.post(instance.inbox, message, instance))
|
||||||
|
|
||||||
if instance['software'] != 'mastodon':
|
if instance.software != 'mastodon':
|
||||||
message = Message.new_follow(
|
message = Message.new_follow(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = instance['actor']
|
actor = instance.actor
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(instance['inbox'], message, instance))
|
asyncio.run(http.post(instance.inbox, message, instance))
|
||||||
|
|
||||||
|
|
||||||
@cli_request.command('deny')
|
@cli_request.command('deny')
|
||||||
|
@ -685,12 +686,12 @@ def cli_request_deny(ctx: click.Context, domain: str) -> None:
|
||||||
|
|
||||||
response = Message.new_response(
|
response = Message.new_response(
|
||||||
host = ctx.obj.config.domain,
|
host = ctx.obj.config.domain,
|
||||||
actor = instance['actor'],
|
actor = instance.actor,
|
||||||
followid = instance['followid'],
|
followid = instance.followid,
|
||||||
accept = False
|
accept = False
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(http.post(instance['inbox'], response, instance))
|
asyncio.run(http.post(instance.inbox, response, instance))
|
||||||
|
|
||||||
|
|
||||||
@cli.group('instance')
|
@cli.group('instance')
|
||||||
|
@ -706,12 +707,12 @@ def cli_instance_list(ctx: click.Context) -> None:
|
||||||
click.echo('Banned domains:')
|
click.echo('Banned domains:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for instance in conn.execute('SELECT * FROM domain_bans'):
|
for row in conn.get_domain_bans():
|
||||||
if instance['reason']:
|
if row.reason is not None:
|
||||||
click.echo(f'- {instance["domain"]} ({instance["reason"]})')
|
click.echo(f'- {row.domain} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {instance["domain"]}')
|
click.echo(f'- {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli_instance.command('ban')
|
@cli_instance.command('ban')
|
||||||
|
@ -723,7 +724,7 @@ def cli_instance_ban(ctx: click.Context, domain: str, reason: str, note: str) ->
|
||||||
'Ban an instance and remove the associated inbox if it exists'
|
'Ban an instance and remove the associated inbox if it exists'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if conn.get_domain_ban(domain):
|
if conn.get_domain_ban(domain) is not None:
|
||||||
click.echo(f'Domain already banned: {domain}')
|
click.echo(f'Domain already banned: {domain}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -739,7 +740,7 @@ def cli_instance_unban(ctx: click.Context, domain: str) -> None:
|
||||||
'Unban an instance'
|
'Unban an instance'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if not conn.del_domain_ban(domain):
|
if conn.del_domain_ban(domain) is None:
|
||||||
click.echo(f'Instance wasn\'t banned: {domain}')
|
click.echo(f'Instance wasn\'t banned: {domain}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -764,11 +765,11 @@ def cli_instance_update(ctx: click.Context, domain: str, reason: str, note: str)
|
||||||
|
|
||||||
click.echo(f'Updated domain ban: {domain}')
|
click.echo(f'Updated domain ban: {domain}')
|
||||||
|
|
||||||
if row['reason']:
|
if row.reason:
|
||||||
click.echo(f'- {row["domain"]} ({row["reason"]})')
|
click.echo(f'- {row.domain} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {row["domain"]}')
|
click.echo(f'- {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli.group('software')
|
@cli.group('software')
|
||||||
|
@ -784,12 +785,12 @@ def cli_software_list(ctx: click.Context) -> None:
|
||||||
click.echo('Banned software:')
|
click.echo('Banned software:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for software in conn.execute('SELECT * FROM software_bans'):
|
for row in conn.get_software_bans():
|
||||||
if software['reason']:
|
if row.reason:
|
||||||
click.echo(f'- {software["name"]} ({software["reason"]})')
|
click.echo(f'- {row.name} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {software["name"]}')
|
click.echo(f'- {row.name}')
|
||||||
|
|
||||||
|
|
||||||
@cli_software.command('ban')
|
@cli_software.command('ban')
|
||||||
|
@ -811,12 +812,12 @@ def cli_software_ban(ctx: click.Context,
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
if name == 'RELAYS':
|
if name == 'RELAYS':
|
||||||
for software in RELAY_SOFTWARE:
|
for item in RELAY_SOFTWARE:
|
||||||
if conn.get_software_ban(software):
|
if conn.get_software_ban(item):
|
||||||
click.echo(f'Relay already banned: {software}')
|
click.echo(f'Relay already banned: {item}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
conn.put_software_ban(software, reason or 'relay', note)
|
conn.put_software_ban(item, reason or 'relay', note)
|
||||||
|
|
||||||
click.echo('Banned all relay software')
|
click.echo('Banned all relay software')
|
||||||
return
|
return
|
||||||
|
@ -893,11 +894,11 @@ def cli_software_update(ctx: click.Context, name: str, reason: str, note: str) -
|
||||||
|
|
||||||
click.echo(f'Updated software ban: {name}')
|
click.echo(f'Updated software ban: {name}')
|
||||||
|
|
||||||
if row['reason']:
|
if row.reason:
|
||||||
click.echo(f'- {row["name"]} ({row["reason"]})')
|
click.echo(f'- {row.name} ({row.reason})')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
click.echo(f'- {row["name"]}')
|
click.echo(f'- {row.name}')
|
||||||
|
|
||||||
|
|
||||||
@cli.group('whitelist')
|
@cli.group('whitelist')
|
||||||
|
@ -913,8 +914,8 @@ def cli_whitelist_list(ctx: click.Context) -> None:
|
||||||
click.echo('Current whitelisted domains:')
|
click.echo('Current whitelisted domains:')
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for domain in conn.execute('SELECT * FROM whitelist'):
|
for row in conn.get_domain_whitelist():
|
||||||
click.echo(f'- {domain["domain"]}')
|
click.echo(f'- {row.domain}')
|
||||||
|
|
||||||
|
|
||||||
@cli_whitelist.command('add')
|
@cli_whitelist.command('add')
|
||||||
|
@ -953,23 +954,19 @@ def cli_whitelist_remove(ctx: click.Context, domain: str) -> None:
|
||||||
@cli_whitelist.command('import')
|
@cli_whitelist.command('import')
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_whitelist_import(ctx: click.Context) -> None:
|
def cli_whitelist_import(ctx: click.Context) -> None:
|
||||||
'Add all current inboxes to the whitelist'
|
'Add all current instances to the whitelist'
|
||||||
|
|
||||||
with ctx.obj.database.session() as conn:
|
with ctx.obj.database.session() as conn:
|
||||||
for inbox in conn.execute('SELECT * FROM inboxes').all():
|
for row in conn.get_inboxes():
|
||||||
if conn.get_domain_whitelist(inbox['domain']):
|
if conn.get_domain_whitelist(row.domain) is not None:
|
||||||
click.echo(f'Domain already in whitelist: {inbox["domain"]}')
|
click.echo(f'Domain already in whitelist: {row.domain}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
conn.put_domain_whitelist(inbox['domain'])
|
conn.put_domain_whitelist(row.domain)
|
||||||
|
|
||||||
click.echo('Imported whitelist from inboxes')
|
click.echo('Imported whitelist from inboxes')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
cli(prog_name='relay')
|
cli(prog_name='activityrelay')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
click.echo('Running relay.manage is depreciated. Run `activityrelay [command]` instead.')
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ async def handle_relay(view: ActorView, conn: Connection) -> None:
|
||||||
logging.debug('>> relay: %s', message)
|
logging.debug('>> relay: %s', message)
|
||||||
|
|
||||||
for instance in conn.distill_inboxes(view.message):
|
for instance in conn.distill_inboxes(view.message):
|
||||||
view.app.push_message(instance["inbox"], message, instance)
|
view.app.push_message(instance.inbox, message, instance)
|
||||||
|
|
||||||
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
view.cache.set('handle-relay', view.message.object_id, message.id, 'str')
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ async def handle_forward(view: ActorView, conn: Connection) -> None:
|
||||||
logging.debug('>> forward: %s', message)
|
logging.debug('>> forward: %s', message)
|
||||||
|
|
||||||
for instance in conn.distill_inboxes(view.message):
|
for instance in conn.distill_inboxes(view.message):
|
||||||
view.app.push_message(instance["inbox"], view.message, instance)
|
view.app.push_message(instance.inbox, view.message, instance)
|
||||||
|
|
||||||
view.cache.set('handle-relay', view.message.id, message.id, 'str')
|
view.cache.set('handle-relay', view.message.id, message.id, 'str')
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ async def handle_undo(view: ActorView, conn: Connection) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# prevent past unfollows from removing an instance
|
# prevent past unfollows from removing an instance
|
||||||
if view.instance['followid'] and view.instance['followid'] != view.message.object_id:
|
if view.instance.followid and view.instance.followid != view.message.object_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
with conn.transaction():
|
with conn.transaction():
|
||||||
|
@ -221,18 +221,18 @@ async def run_processor(view: ActorView) -> None:
|
||||||
|
|
||||||
with view.database.session() as conn:
|
with view.database.session() as conn:
|
||||||
if view.instance:
|
if view.instance:
|
||||||
if not view.instance['software']:
|
if not view.instance.software:
|
||||||
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance['domain'])):
|
if (nodeinfo := await view.client.fetch_nodeinfo(view.instance.domain)):
|
||||||
with conn.transaction():
|
with conn.transaction():
|
||||||
view.instance = conn.put_inbox(
|
view.instance = conn.put_inbox(
|
||||||
domain = view.instance['domain'],
|
domain = view.instance.domain,
|
||||||
software = nodeinfo.sw_name
|
software = nodeinfo.sw_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if not view.instance['actor']:
|
if not view.instance.actor:
|
||||||
with conn.transaction():
|
with conn.transaction():
|
||||||
view.instance = conn.put_inbox(
|
view.instance = conn.put_inbox(
|
||||||
domain = view.instance['domain'],
|
domain = view.instance.domain,
|
||||||
actor = view.actor.id
|
actor = view.actor.id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,26 +1,22 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import aputils
|
import aputils
|
||||||
import traceback
|
import traceback
|
||||||
import typing
|
|
||||||
|
from aiohttp.web import Request
|
||||||
|
|
||||||
from .base import View, register_route
|
from .base import View, register_route
|
||||||
|
|
||||||
from .. import logger as logging
|
from .. import logger as logging
|
||||||
|
from ..database import schema
|
||||||
from ..misc import Message, Response
|
from ..misc import Message, Response
|
||||||
from ..processors import run_processor
|
from ..processors import run_processor
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from aiohttp.web import Request
|
|
||||||
from bsql import Row
|
|
||||||
|
|
||||||
|
|
||||||
@register_route('/actor', '/inbox')
|
@register_route('/actor', '/inbox')
|
||||||
class ActorView(View):
|
class ActorView(View):
|
||||||
signature: aputils.Signature
|
signature: aputils.Signature
|
||||||
message: Message
|
message: Message
|
||||||
actor: Message
|
actor: Message
|
||||||
instancce: Row
|
instance: schema.Instance
|
||||||
signer: aputils.Signer
|
signer: aputils.Signer
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +43,7 @@ class ActorView(View):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
self.instance = conn.get_inbox(self.actor.shared_inbox)
|
self.instance = conn.get_inbox(self.actor.shared_inbox) # type: ignore[assignment]
|
||||||
|
|
||||||
# reject if actor is banned
|
# reject if actor is banned
|
||||||
if conn.get_domain_ban(self.actor.domain):
|
if conn.get_domain_ban(self.actor.domain):
|
||||||
|
|
|
@ -90,10 +90,10 @@ class Login(View):
|
||||||
|
|
||||||
token = conn.put_token(data['username'])
|
token = conn.put_token(data['username'])
|
||||||
|
|
||||||
resp = Response.new({'token': token['code']}, ctype = 'json')
|
resp = Response.new({'token': token.code}, ctype = 'json')
|
||||||
resp.set_cookie(
|
resp.set_cookie(
|
||||||
'user-token',
|
'user-token',
|
||||||
token['code'],
|
token.code,
|
||||||
max_age = 60 * 60 * 24 * 365,
|
max_age = 60 * 60 * 24 * 365,
|
||||||
domain = self.config.domain,
|
domain = self.config.domain,
|
||||||
path = '/',
|
path = '/',
|
||||||
|
@ -117,7 +117,7 @@ class RelayInfo(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
config = conn.get_config_all()
|
config = conn.get_config_all()
|
||||||
inboxes = [row['domain'] for row in conn.get_inboxes()]
|
inboxes = [row.domain for row in conn.get_inboxes()]
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'domain': self.config.domain,
|
'domain': self.config.domain,
|
||||||
|
@ -188,7 +188,7 @@ class Config(View):
|
||||||
class Inbox(View):
|
class Inbox(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
data = conn.get_inboxes()
|
data = tuple(conn.get_inboxes())
|
||||||
|
|
||||||
return Response.new(data, ctype = 'json')
|
return Response.new(data, ctype = 'json')
|
||||||
|
|
||||||
|
@ -202,7 +202,7 @@ class Inbox(View):
|
||||||
data['domain'] = urlparse(data["actor"]).netloc
|
data['domain'] = urlparse(data["actor"]).netloc
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_inbox(data['domain']):
|
if conn.get_inbox(data['domain']) is not None:
|
||||||
return Response.new_error(404, 'Instance already in database', 'json')
|
return Response.new_error(404, 'Instance already in database', 'json')
|
||||||
|
|
||||||
data['domain'] = data['domain'].encode('idna').decode()
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
@ -225,7 +225,12 @@ class Inbox(View):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
row = conn.put_inbox(**data) # type: ignore[arg-type]
|
row = conn.put_inbox(
|
||||||
|
data['domain'],
|
||||||
|
actor = data.get('actor'),
|
||||||
|
software = data.get('software'),
|
||||||
|
followid = data.get('followid')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(row, ctype = 'json')
|
return Response.new(row, ctype = 'json')
|
||||||
|
|
||||||
|
@ -239,10 +244,15 @@ class Inbox(View):
|
||||||
|
|
||||||
data['domain'] = data['domain'].encode('idna').decode()
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
if not (instance := conn.get_inbox(data['domain'])):
|
if (instance := conn.get_inbox(data['domain'])) is None:
|
||||||
return Response.new_error(404, 'Instance with domain not found', 'json')
|
return Response.new_error(404, 'Instance with domain not found', 'json')
|
||||||
|
|
||||||
instance = conn.put_inbox(instance['domain'], **data) # type: ignore[arg-type]
|
instance = conn.put_inbox(
|
||||||
|
instance.domain,
|
||||||
|
actor = data.get('actor'),
|
||||||
|
software = data.get('software'),
|
||||||
|
followid = data.get('followid')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(instance, ctype = 'json')
|
return Response.new(instance, ctype = 'json')
|
||||||
|
|
||||||
|
@ -268,7 +278,7 @@ class Inbox(View):
|
||||||
class RequestView(View):
|
class RequestView(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
instances = conn.get_requests()
|
instances = tuple(conn.get_requests())
|
||||||
|
|
||||||
return Response.new(instances, ctype = 'json')
|
return Response.new(instances, ctype = 'json')
|
||||||
|
|
||||||
|
@ -291,20 +301,20 @@ class RequestView(View):
|
||||||
|
|
||||||
message = Message.new_response(
|
message = Message.new_response(
|
||||||
host = self.config.domain,
|
host = self.config.domain,
|
||||||
actor = instance['actor'],
|
actor = instance.actor,
|
||||||
followid = instance['followid'],
|
followid = instance.followid,
|
||||||
accept = data['accept']
|
accept = data['accept']
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.push_message(instance['inbox'], message, instance)
|
self.app.push_message(instance.inbox, message, instance)
|
||||||
|
|
||||||
if data['accept'] and instance['software'] != 'mastodon':
|
if data['accept'] and instance.software != 'mastodon':
|
||||||
message = Message.new_follow(
|
message = Message.new_follow(
|
||||||
host = self.config.domain,
|
host = self.config.domain,
|
||||||
actor = instance['actor']
|
actor = instance.actor
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.push_message(instance['inbox'], message, instance)
|
self.app.push_message(instance.inbox, message, instance)
|
||||||
|
|
||||||
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
|
resp_message = {'message': 'Request accepted' if data['accept'] else 'Request denied'}
|
||||||
return Response.new(resp_message, ctype = 'json')
|
return Response.new(resp_message, ctype = 'json')
|
||||||
|
@ -314,7 +324,7 @@ class RequestView(View):
|
||||||
class DomainBan(View):
|
class DomainBan(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
bans = tuple(conn.execute('SELECT * FROM domain_bans').all())
|
bans = tuple(conn.get_domain_bans())
|
||||||
|
|
||||||
return Response.new(bans, ctype = 'json')
|
return Response.new(bans, ctype = 'json')
|
||||||
|
|
||||||
|
@ -328,10 +338,14 @@ class DomainBan(View):
|
||||||
data['domain'] = data['domain'].encode('idna').decode()
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_domain_ban(data['domain']):
|
if conn.get_domain_ban(data['domain']) is not None:
|
||||||
return Response.new_error(400, 'Domain already banned', 'json')
|
return Response.new_error(400, 'Domain already banned', 'json')
|
||||||
|
|
||||||
ban = conn.put_domain_ban(**data)
|
ban = conn.put_domain_ban(
|
||||||
|
domain = data['domain'],
|
||||||
|
reason = data.get('reason'),
|
||||||
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -343,15 +357,19 @@ class DomainBan(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
data['domain'] = data['domain'].encode('idna').decode()
|
|
||||||
|
|
||||||
if not conn.get_domain_ban(data['domain']):
|
|
||||||
return Response.new_error(404, 'Domain not banned', 'json')
|
|
||||||
|
|
||||||
if not any([data.get('note'), data.get('reason')]):
|
if not any([data.get('note'), data.get('reason')]):
|
||||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||||
|
|
||||||
ban = conn.update_domain_ban(**data)
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
|
if conn.get_domain_ban(data['domain']) is None:
|
||||||
|
return Response.new_error(404, 'Domain not banned', 'json')
|
||||||
|
|
||||||
|
ban = conn.update_domain_ban(
|
||||||
|
domain = data['domain'],
|
||||||
|
reason = data.get('reason'),
|
||||||
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -365,7 +383,7 @@ class DomainBan(View):
|
||||||
|
|
||||||
data['domain'] = data['domain'].encode('idna').decode()
|
data['domain'] = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
if not conn.get_domain_ban(data['domain']):
|
if conn.get_domain_ban(data['domain']) is None:
|
||||||
return Response.new_error(404, 'Domain not banned', 'json')
|
return Response.new_error(404, 'Domain not banned', 'json')
|
||||||
|
|
||||||
conn.del_domain_ban(data['domain'])
|
conn.del_domain_ban(data['domain'])
|
||||||
|
@ -377,7 +395,7 @@ class DomainBan(View):
|
||||||
class SoftwareBan(View):
|
class SoftwareBan(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
bans = tuple(conn.execute('SELECT * FROM software_bans').all())
|
bans = tuple(conn.get_software_bans())
|
||||||
|
|
||||||
return Response.new(bans, ctype = 'json')
|
return Response.new(bans, ctype = 'json')
|
||||||
|
|
||||||
|
@ -389,10 +407,14 @@ class SoftwareBan(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_software_ban(data['name']):
|
if conn.get_software_ban(data['name']) is not None:
|
||||||
return Response.new_error(400, 'Domain already banned', 'json')
|
return Response.new_error(400, 'Domain already banned', 'json')
|
||||||
|
|
||||||
ban = conn.put_software_ban(**data)
|
ban = conn.put_software_ban(
|
||||||
|
name = data['name'],
|
||||||
|
reason = data.get('reason'),
|
||||||
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -403,14 +425,18 @@ class SoftwareBan(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
if not any([data.get('note'), data.get('reason')]):
|
||||||
|
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if not conn.get_software_ban(data['name']):
|
if conn.get_software_ban(data['name']) is None:
|
||||||
return Response.new_error(404, 'Software not banned', 'json')
|
return Response.new_error(404, 'Software not banned', 'json')
|
||||||
|
|
||||||
if not any([data.get('note'), data.get('reason')]):
|
ban = conn.update_software_ban(
|
||||||
return Response.new_error(400, 'Must include note and/or reason parameters', 'json')
|
name = data['name'],
|
||||||
|
reason = data.get('reason'),
|
||||||
ban = conn.update_software_ban(**data)
|
note = data.get('note')
|
||||||
|
)
|
||||||
|
|
||||||
return Response.new(ban, ctype = 'json')
|
return Response.new(ban, ctype = 'json')
|
||||||
|
|
||||||
|
@ -422,7 +448,7 @@ class SoftwareBan(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if not conn.get_software_ban(data['name']):
|
if conn.get_software_ban(data['name']) is None:
|
||||||
return Response.new_error(404, 'Software not banned', 'json')
|
return Response.new_error(404, 'Software not banned', 'json')
|
||||||
|
|
||||||
conn.del_software_ban(data['name'])
|
conn.del_software_ban(data['name'])
|
||||||
|
@ -436,7 +462,7 @@ class User(View):
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
items = []
|
items = []
|
||||||
|
|
||||||
for row in conn.execute('SELECT * FROM users'):
|
for row in conn.get_users():
|
||||||
del row['hash']
|
del row['hash']
|
||||||
items.append(row)
|
items.append(row)
|
||||||
|
|
||||||
|
@ -450,12 +476,16 @@ class User(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_user(data['username']):
|
if conn.get_user(data['username']) is not None:
|
||||||
return Response.new_error(404, 'User already exists', 'json')
|
return Response.new_error(404, 'User already exists', 'json')
|
||||||
|
|
||||||
user = conn.put_user(**data)
|
user = conn.put_user(
|
||||||
del user['hash']
|
username = data['username'],
|
||||||
|
password = data['password'],
|
||||||
|
handle = data.get('handle')
|
||||||
|
)
|
||||||
|
|
||||||
|
del user['hash']
|
||||||
return Response.new(user, ctype = 'json')
|
return Response.new(user, ctype = 'json')
|
||||||
|
|
||||||
|
|
||||||
|
@ -466,9 +496,13 @@ class User(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session(True) as conn:
|
with self.database.session(True) as conn:
|
||||||
user = conn.put_user(**data)
|
user = conn.put_user(
|
||||||
del user['hash']
|
username = data['username'],
|
||||||
|
password = data['password'],
|
||||||
|
handle = data.get('handle')
|
||||||
|
)
|
||||||
|
|
||||||
|
del user['hash']
|
||||||
return Response.new(user, ctype = 'json')
|
return Response.new(user, ctype = 'json')
|
||||||
|
|
||||||
|
|
||||||
|
@ -479,7 +513,7 @@ class User(View):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
with self.database.session(True) as conn:
|
with self.database.session(True) as conn:
|
||||||
if not conn.get_user(data['username']):
|
if conn.get_user(data['username']) is None:
|
||||||
return Response.new_error(404, 'User does not exist', 'json')
|
return Response.new_error(404, 'User does not exist', 'json')
|
||||||
|
|
||||||
conn.del_user(data['username'])
|
conn.del_user(data['username'])
|
||||||
|
@ -491,7 +525,7 @@ class User(View):
|
||||||
class Whitelist(View):
|
class Whitelist(View):
|
||||||
async def get(self, request: Request) -> Response:
|
async def get(self, request: Request) -> Response:
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
items = tuple(conn.execute('SELECT * FROM whitelist').all())
|
items = tuple(conn.get_domains_whitelist())
|
||||||
|
|
||||||
return Response.new(items, ctype = 'json')
|
return Response.new(items, ctype = 'json')
|
||||||
|
|
||||||
|
@ -502,13 +536,13 @@ class Whitelist(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
data['domain'] = data['domain'].encode('idna').decode()
|
domain = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if conn.get_domain_whitelist(data['domain']):
|
if conn.get_domain_whitelist(domain) is not None:
|
||||||
return Response.new_error(400, 'Domain already added to whitelist', 'json')
|
return Response.new_error(400, 'Domain already added to whitelist', 'json')
|
||||||
|
|
||||||
item = conn.put_domain_whitelist(**data)
|
item = conn.put_domain_whitelist(domain)
|
||||||
|
|
||||||
return Response.new(item, ctype = 'json')
|
return Response.new(item, ctype = 'json')
|
||||||
|
|
||||||
|
@ -519,12 +553,12 @@ class Whitelist(View):
|
||||||
if isinstance(data, Response):
|
if isinstance(data, Response):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
data['domain'] = data['domain'].encode('idna').decode()
|
domain = data['domain'].encode('idna').decode()
|
||||||
|
|
||||||
with self.database.session() as conn:
|
with self.database.session() as conn:
|
||||||
if not conn.get_domain_whitelist(data['domain']):
|
if conn.get_domain_whitelist(domain) is None:
|
||||||
return Response.new_error(404, 'Domain not in whitelist', 'json')
|
return Response.new_error(404, 'Domain not in whitelist', 'json')
|
||||||
|
|
||||||
conn.del_domain_whitelist(data['domain'])
|
conn.del_domain_whitelist(domain)
|
||||||
|
|
||||||
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')
|
return Response.new({'message': 'Removed domain from whitelist'}, ctype = 'json')
|
||||||
|
|
|
@ -202,7 +202,7 @@ class AdminConfig(View):
|
||||||
'message': message,
|
'message': message,
|
||||||
'desc': {
|
'desc': {
|
||||||
"name": "Name of the relay to be displayed in the header of the pages and in " +
|
"name": "Name of the relay to be displayed in the header of the pages and in " +
|
||||||
"the actor endpoint.",
|
"the actor endpoint.", # noqa: E131
|
||||||
"note": "Description of the relay to be displayed on the front page and as the " +
|
"note": "Description of the relay to be displayed on the front page and as the " +
|
||||||
"bio in the actor endpoint.",
|
"bio in the actor endpoint.",
|
||||||
"theme": "Color theme to use on the web pages.",
|
"theme": "Color theme to use on the web pages.",
|
||||||
|
|
|
@ -4,7 +4,6 @@ import typing
|
||||||
|
|
||||||
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
from aiohttp.client_exceptions import ClientConnectionError, ClientSSLError
|
||||||
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
from asyncio.exceptions import TimeoutError as AsyncTimeoutError
|
||||||
from bsql import Row
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import Event, Process, Queue, Value
|
from multiprocessing import Event, Process, Queue, Value
|
||||||
from multiprocessing.synchronize import Event as EventType
|
from multiprocessing.synchronize import Event as EventType
|
||||||
|
@ -13,6 +12,7 @@ from queue import Empty, Queue as QueueType
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from . import application, logger as logging
|
from . import application, logger as logging
|
||||||
|
from .database.schema import Instance
|
||||||
from .http_client import HttpClient
|
from .http_client import HttpClient
|
||||||
from .misc import IS_WINDOWS, Message, get_app
|
from .misc import IS_WINDOWS, Message, get_app
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class QueueItem:
|
||||||
class PostItem(QueueItem):
|
class PostItem(QueueItem):
|
||||||
inbox: str
|
inbox: str
|
||||||
message: Message
|
message: Message
|
||||||
instance: Row | None
|
instance: Instance | None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def domain(self) -> str:
|
def domain(self) -> str:
|
||||||
|
@ -122,7 +122,7 @@ class PushWorkers(list[PushWorker]):
|
||||||
self.queue.put(item)
|
self.queue.put(item)
|
||||||
|
|
||||||
|
|
||||||
def push_message(self, inbox: str, message: Message, instance: Row) -> None:
|
def push_message(self, inbox: str, message: Message, instance: Instance) -> None:
|
||||||
self.queue.put(PostItem(inbox, message, instance))
|
self.queue.put(PostItem(inbox, message, instance))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue