create HttpClient class to avoid creating a new session every request
This commit is contained in:
parent
32764a1f93
commit
b85b4ab80b
|
@ -6,12 +6,12 @@ import signal
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from cachetools import LRUCache
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from .config import RelayConfig
|
from .config import RelayConfig
|
||||||
from .database import RelayDatabase
|
from .database import RelayDatabase
|
||||||
from .misc import DotDict, check_open_port, request, set_app
|
from .http_client import HttpClient
|
||||||
|
from .misc import DotDict, check_open_port, set_app
|
||||||
from .views import routes
|
from .views import routes
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,8 +27,6 @@ class Application(web.Application):
|
||||||
if not self['config'].load():
|
if not self['config'].load():
|
||||||
self['config'].save()
|
self['config'].save()
|
||||||
|
|
||||||
self['cache'] = DotDict({key: Cache(maxsize=self['config'][key]) for key in self['config'].cachekeys})
|
|
||||||
self['semaphore'] = asyncio.Semaphore(self['config'].push_limit)
|
|
||||||
self['workers'] = []
|
self['workers'] = []
|
||||||
self['last_worker'] = 0
|
self['last_worker'] = 0
|
||||||
|
|
||||||
|
@ -37,12 +35,18 @@ class Application(web.Application):
|
||||||
self['database'] = RelayDatabase(self['config'])
|
self['database'] = RelayDatabase(self['config'])
|
||||||
self['database'].load()
|
self['database'].load()
|
||||||
|
|
||||||
|
self['client'] = HttpClient(
|
||||||
|
limit = self.config.push_limit,
|
||||||
|
timeout = self.config.timeout,
|
||||||
|
cache_size = self.config.json_cache
|
||||||
|
)
|
||||||
|
|
||||||
self.set_signal_handler()
|
self.set_signal_handler()
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache(self):
|
def client(self):
|
||||||
return self['cache']
|
return self['client']
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -76,6 +80,9 @@ class Application(web.Application):
|
||||||
|
|
||||||
|
|
||||||
def push_message(self, inbox, message):
|
def push_message(self, inbox, message):
|
||||||
|
if self.config.workers <= 0:
|
||||||
|
return asyncio.ensure_future(self.client.post(inbox, message))
|
||||||
|
|
||||||
worker = self['workers'][self['last_worker']]
|
worker = self['workers'][self['last_worker']]
|
||||||
worker.queue.put((inbox, message))
|
worker.queue.put((inbox, message))
|
||||||
|
|
||||||
|
@ -145,11 +152,6 @@ class Application(web.Application):
|
||||||
self['workers'].clear()
|
self['workers'].clear()
|
||||||
|
|
||||||
|
|
||||||
class Cache(LRUCache):
|
|
||||||
def set_maxsize(self, value):
|
|
||||||
self.__maxsize = int(value)
|
|
||||||
|
|
||||||
|
|
||||||
class PushWorker(threading.Thread):
|
class PushWorker(threading.Thread):
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
|
@ -158,6 +160,12 @@ class PushWorker(threading.Thread):
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
self.client = HttpClient(
|
||||||
|
limit = self.app.config.push_limit,
|
||||||
|
timeout = self.app.config.timeout,
|
||||||
|
cache_size = self.app.config.json_cache
|
||||||
|
)
|
||||||
|
|
||||||
asyncio.run(self.handle_queue())
|
asyncio.run(self.handle_queue())
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,13 +174,14 @@ class PushWorker(threading.Thread):
|
||||||
try:
|
try:
|
||||||
inbox, message = self.queue.get(block=True, timeout=0.25)
|
inbox, message = self.queue.get(block=True, timeout=0.25)
|
||||||
self.queue.task_done()
|
self.queue.task_done()
|
||||||
await request(inbox, message)
|
|
||||||
|
|
||||||
logging.verbose(f'New push from Thread-{threading.get_ident()}')
|
logging.verbose(f'New push from Thread-{threading.get_ident()}')
|
||||||
|
await self.client.post(inbox, message)
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
await self.client.close()
|
||||||
|
|
||||||
|
|
||||||
## Can't sub-class web.Request, so let's just add some properties
|
## Can't sub-class web.Request, so let's just add some properties
|
||||||
def request_actor(self):
|
def request_actor(self):
|
||||||
|
@ -203,7 +212,6 @@ setattr(web.Request, 'instance', property(request_instance))
|
||||||
setattr(web.Request, 'message', property(request_message))
|
setattr(web.Request, 'message', property(request_message))
|
||||||
setattr(web.Request, 'signature', property(request_signature))
|
setattr(web.Request, 'signature', property(request_signature))
|
||||||
|
|
||||||
setattr(web.Request, 'cache', property(lambda self: self.app.cache))
|
|
||||||
setattr(web.Request, 'config', property(lambda self: self.app.config))
|
setattr(web.Request, 'config', property(lambda self: self.app.config))
|
||||||
setattr(web.Request, 'database', property(lambda self: self.app.database))
|
setattr(web.Request, 'database', property(lambda self: self.app.database))
|
||||||
setattr(web.Request, 'semaphore', property(lambda self: self.app.semaphore))
|
setattr(web.Request, 'semaphore', property(lambda self: self.app.semaphore))
|
||||||
|
|
|
@ -24,12 +24,6 @@ class RelayConfig(DotDict):
|
||||||
'whitelist'
|
'whitelist'
|
||||||
}
|
}
|
||||||
|
|
||||||
cachekeys = {
|
|
||||||
'json',
|
|
||||||
'objects',
|
|
||||||
'digests'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, path, is_docker):
|
def __init__(self, path, is_docker):
|
||||||
DotDict.__init__(self, {})
|
DotDict.__init__(self, {})
|
||||||
|
@ -50,7 +44,7 @@ class RelayConfig(DotDict):
|
||||||
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
|
if key in ['blocked_instances', 'blocked_software', 'whitelist']:
|
||||||
assert isinstance(value, (list, set, tuple))
|
assert isinstance(value, (list, set, tuple))
|
||||||
|
|
||||||
elif key in ['port', 'workers', 'json', 'objects', 'digests']:
|
elif key in ['port', 'workers', 'json_cache', 'timeout']:
|
||||||
if not isinstance(value, int):
|
if not isinstance(value, int):
|
||||||
value = int(value)
|
value = int(value)
|
||||||
|
|
||||||
|
@ -94,15 +88,14 @@ class RelayConfig(DotDict):
|
||||||
'port': 8080,
|
'port': 8080,
|
||||||
'note': 'Make a note about your instance here.',
|
'note': 'Make a note about your instance here.',
|
||||||
'push_limit': 512,
|
'push_limit': 512,
|
||||||
|
'json_cache': 1024,
|
||||||
|
'timeout': 10,
|
||||||
'workers': 0,
|
'workers': 0,
|
||||||
'host': 'relay.example.com',
|
'host': 'relay.example.com',
|
||||||
|
'whitelist_enabled': False,
|
||||||
'blocked_software': [],
|
'blocked_software': [],
|
||||||
'blocked_instances': [],
|
'blocked_instances': [],
|
||||||
'whitelist': [],
|
'whitelist': []
|
||||||
'whitelist_enabled': False,
|
|
||||||
'json': 1024,
|
|
||||||
'objects': 1024,
|
|
||||||
'digests': 1024
|
|
||||||
})
|
})
|
||||||
|
|
||||||
def ban_instance(self, instance):
|
def ban_instance(self, instance):
|
||||||
|
@ -211,7 +204,7 @@ class RelayConfig(DotDict):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
if key in ['ap', 'cache']:
|
if key in ['ap']:
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
if k not in self:
|
if k not in self:
|
||||||
continue
|
continue
|
||||||
|
@ -239,8 +232,9 @@ class RelayConfig(DotDict):
|
||||||
'note': self.note,
|
'note': self.note,
|
||||||
'push_limit': self.push_limit,
|
'push_limit': self.push_limit,
|
||||||
'workers': self.workers,
|
'workers': self.workers,
|
||||||
'ap': {key: self[key] for key in self.apkeys},
|
'json_cache': self.json_cache,
|
||||||
'cache': {key: self[key] for key in self.cachekeys}
|
'timeout': self.timeout,
|
||||||
|
'ap': {key: self[key] for key in self.apkeys}
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(self._path, 'w') as fd:
|
with open(self._path, 'w') as fd:
|
||||||
|
|
|
@ -6,8 +6,6 @@ import traceback
|
||||||
from Crypto.PublicKey import RSA
|
from Crypto.PublicKey import RSA
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from .misc import fetch_nodeinfo
|
|
||||||
|
|
||||||
|
|
||||||
class RelayDatabase(dict):
|
class RelayDatabase(dict):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
|
203
relay/http_client.py
Normal file
203
relay/http_client.py
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||||
|
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
|
||||||
|
from datetime import datetime
|
||||||
|
from cachetools import LRUCache
|
||||||
|
from json.decoder import JSONDecodeError
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from . import __version__
|
||||||
|
from .misc import (
|
||||||
|
MIMETYPES,
|
||||||
|
DotDict,
|
||||||
|
Message,
|
||||||
|
create_signature_header,
|
||||||
|
generate_body_digest
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
HEADERS = {
|
||||||
|
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
|
||||||
|
'User-Agent': f'ActivityRelay/{__version__}'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Cache(LRUCache):
|
||||||
|
def set_maxsize(self, value):
|
||||||
|
self.__maxsize = int(value)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpClient:
|
||||||
|
def __init__(self, limit=100, timeout=10, cache_size=1024):
|
||||||
|
self.cache = Cache(cache_size)
|
||||||
|
self.cfg = {'limit': limit, 'timeout': timeout}
|
||||||
|
self._conn = None
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def limit(self):
|
||||||
|
return self.cfg['limit']
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timeout(self):
|
||||||
|
return self.cfg['timeout']
|
||||||
|
|
||||||
|
|
||||||
|
def sign_headers(self, method, url, message=None):
|
||||||
|
parsed = urlparse(url)
|
||||||
|
headers = {
|
||||||
|
'(request-target)': f'{method.lower()} {parsed.path}',
|
||||||
|
'Date': datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'),
|
||||||
|
'Host': parsed.netloc
|
||||||
|
}
|
||||||
|
|
||||||
|
if message:
|
||||||
|
data = message.to_json()
|
||||||
|
headers.update({
|
||||||
|
'Digest': f'SHA-256={generate_body_digest(data)}',
|
||||||
|
'Content-Length': str(len(data.encode('utf-8')))
|
||||||
|
})
|
||||||
|
|
||||||
|
headers['Signature'] = create_signature_header(headers)
|
||||||
|
|
||||||
|
del headers['(request-target)']
|
||||||
|
del headers['Host']
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
async def open(self):
|
||||||
|
if self._session:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._conn = TCPConnector(
|
||||||
|
limit = self.limit,
|
||||||
|
ttl_dns_cache = 300,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._session = ClientSession(
|
||||||
|
connector = self._conn,
|
||||||
|
headers = HEADERS,
|
||||||
|
connector_owner = True,
|
||||||
|
timeout = ClientTimeout(total=self.timeout)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
if not self._session:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._session.close()
|
||||||
|
await self._conn.close()
|
||||||
|
|
||||||
|
self._conn = None
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get(self, url, sign_headers=False, loads=None, force=False):
|
||||||
|
await self.open()
|
||||||
|
|
||||||
|
try: url, _ = url.split('#', 1)
|
||||||
|
except: pass
|
||||||
|
|
||||||
|
if not force and url in self.cache:
|
||||||
|
return self.cache[url]
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
if sign_headers:
|
||||||
|
headers.update(self.sign_headers('GET', url))
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.verbose(f'Fetching resource: {url}')
|
||||||
|
|
||||||
|
async with self._session.get(url, headers=headers) as resp:
|
||||||
|
## Not expecting a response with 202s, so just return
|
||||||
|
if resp.status == 202:
|
||||||
|
return
|
||||||
|
|
||||||
|
elif resp.status != 200:
|
||||||
|
logging.verbose(f'Received error when requesting {url}: {resp.status}')
|
||||||
|
logging.verbose(await resp.read()) # change this to debug
|
||||||
|
return
|
||||||
|
|
||||||
|
if loads:
|
||||||
|
if issubclass(loads, DotDict):
|
||||||
|
message = await resp.json(loads=loads.new_from_json)
|
||||||
|
|
||||||
|
else:
|
||||||
|
message = await resp.json(loads=loads)
|
||||||
|
|
||||||
|
elif resp.content_type == MIMETYPES['activity']:
|
||||||
|
message = await resp.json(loads=Message.new_from_json)
|
||||||
|
|
||||||
|
elif resp.content_type == MIMETYPES['json']:
|
||||||
|
message = await resp.json(loads=DotDict.new_from_json)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# todo: raise TypeError or something
|
||||||
|
logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}')
|
||||||
|
return logging.debug(f'Response: {resp.read()}')
|
||||||
|
|
||||||
|
logging.debug(f'{url} >> resp {message.to_json(4)}')
|
||||||
|
|
||||||
|
self.cache[url] = message
|
||||||
|
return message
|
||||||
|
|
||||||
|
except JSONDecodeError:
|
||||||
|
logging.verbose(f'Failed to parse JSON')
|
||||||
|
|
||||||
|
except (ClientConnectorError, ServerTimeoutError):
|
||||||
|
logging.verbose(f'Failed to connect to {urlparse(url).netloc}')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
async def post(self, url, message):
|
||||||
|
await self.open()
|
||||||
|
|
||||||
|
headers = {'Content-Type': 'application/activity+json'}
|
||||||
|
headers.update(self.sign_headers('POST', url, message))
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.verbose(f'Sending "{message.type}" to {url}')
|
||||||
|
|
||||||
|
async with self._session.post(url, headers=headers, data=message.to_json()) as resp:
|
||||||
|
## Not expecting a response, so just return
|
||||||
|
if resp.status in {200, 202}:
|
||||||
|
return logging.verbose(f'Successfully sent "{message.type}" to {url}')
|
||||||
|
|
||||||
|
logging.verbose(f'Received error when pushing to {url}: {resp.status}')
|
||||||
|
return logging.verbose(await resp.read()) # change this to debug
|
||||||
|
|
||||||
|
except (ClientConnectorError, ServerTimeoutError):
|
||||||
|
logging.verbose(f'Failed to connect to {url.netloc}')
|
||||||
|
|
||||||
|
## prevent workers from being brought down
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
## Additional methods ##
|
||||||
|
async def fetch_nodeinfo(domain):
|
||||||
|
nodeinfo_url = None
|
||||||
|
wk_nodeinfo = await self.get(f'https://{domain}/.well-known/nodeinfo', loads=WKNodeinfo)
|
||||||
|
|
||||||
|
for version in ['20', '21']:
|
||||||
|
try:
|
||||||
|
nodeinfo_url = wk_nodeinfo.get_url(version)
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not nodeinfo_url:
|
||||||
|
logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
return await request(nodeinfo_url, loads=Nodeinfo) or False
|
|
@ -145,7 +145,7 @@ def cli_inbox_follow(actor):
|
||||||
inbox = inbox_data['inbox']
|
inbox = inbox_data['inbox']
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
actor_data = asyncio.run(misc.request(actor))
|
actor_data = asyncio.run(app.client.get(actor, sign_headers=True))
|
||||||
|
|
||||||
if not actor_data:
|
if not actor_data:
|
||||||
return click.echo(f'Failed to fetch actor: {actor}')
|
return click.echo(f'Failed to fetch actor: {actor}')
|
||||||
|
@ -157,7 +157,7 @@ def cli_inbox_follow(actor):
|
||||||
actor = actor
|
actor = actor
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(misc.request(inbox, message))
|
asyncio.run(app.client.post(inbox, message))
|
||||||
click.echo(f'Sent follow message to actor: {actor}')
|
click.echo(f'Sent follow message to actor: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ def cli_inbox_unfollow(actor):
|
||||||
)
|
)
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
actor_data = asyncio.run(misc.request(actor))
|
actor_data = asyncio.run(app.client.get(actor, sign_headers=True))
|
||||||
inbox = actor_data.shared_inbox
|
inbox = actor_data.shared_inbox
|
||||||
message = misc.Message.new_unfollow(
|
message = misc.Message.new_unfollow(
|
||||||
host = app.config.host,
|
host = app.config.host,
|
||||||
|
@ -195,7 +195,7 @@ def cli_inbox_unfollow(actor):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.run(misc.request(inbox, message))
|
asyncio.run(app.client.post(inbox, message))
|
||||||
click.echo(f'Sent unfollow message to: {actor}')
|
click.echo(f'Sent unfollow message to: {actor}')
|
||||||
|
|
||||||
|
|
||||||
|
@ -319,7 +319,7 @@ def cli_software_ban(name, fetch_nodeinfo):
|
||||||
return click.echo('Banned all relay software')
|
return click.echo('Banned all relay software')
|
||||||
|
|
||||||
if fetch_nodeinfo:
|
if fetch_nodeinfo:
|
||||||
nodeinfo = asyncio.run(misc.fetch_nodeinfo(name))
|
nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name))
|
||||||
|
|
||||||
if not software:
|
if not software:
|
||||||
click.echo(f'Failed to fetch software name from domain: {name}')
|
click.echo(f'Failed to fetch software name from domain: {name}')
|
||||||
|
@ -347,7 +347,7 @@ def cli_software_unban(name, fetch_nodeinfo):
|
||||||
return click.echo('Unbanned all relay software')
|
return click.echo('Unbanned all relay software')
|
||||||
|
|
||||||
if fetch_nodeinfo:
|
if fetch_nodeinfo:
|
||||||
nodeinfo = asyncio.run(misc.fetch_nodeinfo(name))
|
nodeinfo = asyncio.run(app.client.fetch_nodeinfo(name))
|
||||||
|
|
||||||
if not nodeinfo:
|
if not nodeinfo:
|
||||||
click.echo(f'Failed to fetch software name from domain: {name}')
|
click.echo(f'Failed to fetch software name from domain: {name}')
|
||||||
|
|
148
relay/misc.py
148
relay/misc.py
|
@ -9,8 +9,6 @@ import uuid
|
||||||
from Crypto.Hash import SHA, SHA256, SHA512
|
from Crypto.Hash import SHA, SHA256, SHA512
|
||||||
from Crypto.PublicKey import RSA
|
from Crypto.PublicKey import RSA
|
||||||
from Crypto.Signature import PKCS1_v1_5
|
from Crypto.Signature import PKCS1_v1_5
|
||||||
from aiohttp import ClientSession, ClientTimeout
|
|
||||||
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
|
|
||||||
from aiohttp.hdrs import METH_ALL as METHODS
|
from aiohttp.hdrs import METH_ALL as METHODS
|
||||||
from aiohttp.web import Response as AiohttpResponse, View as AiohttpView
|
from aiohttp.web import Response as AiohttpResponse, View as AiohttpView
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -117,14 +115,8 @@ def distill_inboxes(actor, object_id):
|
||||||
|
|
||||||
|
|
||||||
def generate_body_digest(body):
|
def generate_body_digest(body):
|
||||||
bodyhash = app.cache.digests.get(body)
|
|
||||||
|
|
||||||
if bodyhash:
|
|
||||||
return bodyhash
|
|
||||||
|
|
||||||
h = SHA256.new(body.encode('utf-8'))
|
h = SHA256.new(body.encode('utf-8'))
|
||||||
bodyhash = base64.b64encode(h.digest()).decode('utf-8')
|
bodyhash = base64.b64encode(h.digest()).decode('utf-8')
|
||||||
app.cache.digests[body] = bodyhash
|
|
||||||
|
|
||||||
return bodyhash
|
return bodyhash
|
||||||
|
|
||||||
|
@ -138,141 +130,6 @@ def sign_signing_string(sigstring, key):
|
||||||
return base64.b64encode(sigdata).decode('utf-8')
|
return base64.b64encode(sigdata).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
async def fetch_actor_key(actor):
|
|
||||||
actor_data = await request(actor)
|
|
||||||
|
|
||||||
if not actor_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return RSA.importKey(actor_data['publicKey']['publicKeyPem'])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.debug(f'Exception occured while fetching actor key: {e}')
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_nodeinfo(domain):
|
|
||||||
nodeinfo_url = None
|
|
||||||
wk_nodeinfo = await request(f'https://{domain}/.well-known/nodeinfo', sign_headers=False, activity=False)
|
|
||||||
|
|
||||||
if not wk_nodeinfo:
|
|
||||||
return
|
|
||||||
|
|
||||||
wk_nodeinfo = WKNodeinfo(wk_nodeinfo)
|
|
||||||
|
|
||||||
for version in ['20', '21']:
|
|
||||||
try:
|
|
||||||
nodeinfo_url = wk_nodeinfo.get_url(version)
|
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not nodeinfo_url:
|
|
||||||
logging.verbose(f'Failed to fetch nodeinfo url for domain: {domain}')
|
|
||||||
return False
|
|
||||||
|
|
||||||
nodeinfo = await request(nodeinfo_url, sign_headers=False, activity=False)
|
|
||||||
|
|
||||||
if not nodeinfo:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return Nodeinfo(nodeinfo)
|
|
||||||
|
|
||||||
|
|
||||||
async def request(uri, data=None, force=False, sign_headers=True, activity=True, timeout=10):
|
|
||||||
## If a get request and not force, try to use the cache first
|
|
||||||
if not data and not force:
|
|
||||||
try:
|
|
||||||
return app.cache.json[uri]
|
|
||||||
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
url = urlparse(uri)
|
|
||||||
method = 'POST' if data else 'GET'
|
|
||||||
action = data.get('type') if data else None
|
|
||||||
headers = {
|
|
||||||
'Accept': f'{MIMETYPES["activity"]}, {MIMETYPES["json"]};q=0.9',
|
|
||||||
'User-Agent': 'ActivityRelay',
|
|
||||||
}
|
|
||||||
|
|
||||||
if data:
|
|
||||||
headers['Content-Type'] = MIMETYPES['activity' if activity else 'json']
|
|
||||||
|
|
||||||
if sign_headers:
|
|
||||||
signing_headers = {
|
|
||||||
'(request-target)': f'{method.lower()} {url.path}',
|
|
||||||
'Date': datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'),
|
|
||||||
'Host': url.netloc
|
|
||||||
}
|
|
||||||
|
|
||||||
if data:
|
|
||||||
assert isinstance(data, dict)
|
|
||||||
|
|
||||||
data = json.dumps(data)
|
|
||||||
signing_headers.update({
|
|
||||||
'Digest': f'SHA-256={generate_body_digest(data)}',
|
|
||||||
'Content-Length': str(len(data.encode('utf-8')))
|
|
||||||
})
|
|
||||||
|
|
||||||
signing_headers['Signature'] = create_signature_header(signing_headers)
|
|
||||||
|
|
||||||
del signing_headers['(request-target)']
|
|
||||||
del signing_headers['Host']
|
|
||||||
|
|
||||||
headers.update(signing_headers)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if data:
|
|
||||||
logging.verbose(f'Sending "{action}" to inbox: {uri}')
|
|
||||||
|
|
||||||
else:
|
|
||||||
logging.verbose(f'Sending GET request to url: {uri}')
|
|
||||||
|
|
||||||
timeout_cfg = ClientTimeout(connect=timeout)
|
|
||||||
async with ClientSession(trace_configs=http_debug(), timeout=timeout_cfg) as session, app.semaphore:
|
|
||||||
async with session.request(method, uri, headers=headers, data=data) as resp:
|
|
||||||
## aiohttp has been known to leak if the response hasn't been read,
|
|
||||||
## so we're just gonna read the request no matter what
|
|
||||||
resp_data = await resp.read()
|
|
||||||
|
|
||||||
## Not expecting a response, so just return
|
|
||||||
if resp.status == 202:
|
|
||||||
return
|
|
||||||
|
|
||||||
elif resp.status != 200:
|
|
||||||
if not resp_data:
|
|
||||||
return logging.verbose(f'Received error when requesting {uri}: {resp.status} {resp_data}')
|
|
||||||
|
|
||||||
return logging.verbose(f'Received error when sending {action} to {uri}: {resp.status} {resp_data}')
|
|
||||||
|
|
||||||
if resp.content_type == MIMETYPES['activity']:
|
|
||||||
resp_data = await resp.json(loads=Message.new_from_json)
|
|
||||||
|
|
||||||
elif resp.content_type == MIMETYPES['json']:
|
|
||||||
resp_data = await resp.json(loads=DotDict.new_from_json)
|
|
||||||
|
|
||||||
else:
|
|
||||||
logging.verbose(f'Invalid Content-Type for "{url}": {resp.content_type}')
|
|
||||||
return logging.debug(f'Response: {resp_data}')
|
|
||||||
|
|
||||||
logging.debug(f'{uri} >> resp {resp_data}')
|
|
||||||
|
|
||||||
app.cache.json[uri] = resp_data
|
|
||||||
return resp_data
|
|
||||||
|
|
||||||
except JSONDecodeError:
|
|
||||||
logging.verbose(f'Failed to parse JSON')
|
|
||||||
return
|
|
||||||
|
|
||||||
except (ClientConnectorError, ServerTimeoutError):
|
|
||||||
logging.verbose(f'Failed to connect to {url.netloc}')
|
|
||||||
return
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_signature(actor, signature, http_request):
|
async def validate_signature(actor, signature, http_request):
|
||||||
headers = {key.lower(): value for key, value in http_request.headers.items()}
|
headers = {key.lower(): value for key, value in http_request.headers.items()}
|
||||||
headers['(request-target)'] = ' '.join([http_request.method.lower(), http_request.path])
|
headers['(request-target)'] = ' '.join([http_request.method.lower(), http_request.path])
|
||||||
|
@ -559,11 +416,6 @@ class View(AiohttpView):
|
||||||
return self._request.app
|
return self._request.app
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache(self):
|
|
||||||
return self.app.cache
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config(self):
|
def config(self):
|
||||||
return self.app.config
|
return self.app.config
|
||||||
|
|
|
@ -1,63 +1,55 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from cachetools import LRUCache
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from . import misc
|
from .misc import Message, distill_inboxes
|
||||||
|
|
||||||
|
|
||||||
|
cache = LRUCache(1024)
|
||||||
|
|
||||||
|
|
||||||
async def handle_relay(request):
|
async def handle_relay(request):
|
||||||
if request.message.objectid in request.cache.objects:
|
if request.message.objectid in cache:
|
||||||
logging.verbose(f'already relayed {request.message.objectid}')
|
logging.verbose(f'already relayed {request.message.objectid}')
|
||||||
return
|
return
|
||||||
|
|
||||||
message = misc.Message.new_announce(
|
message = Message.new_announce(
|
||||||
host = request.config.host,
|
host = request.config.host,
|
||||||
object = request.message.objectid
|
object = request.message.objectid
|
||||||
)
|
)
|
||||||
|
|
||||||
request.cache.objects[request.message.objectid] = message.id
|
cache[request.message.objectid] = message.id
|
||||||
logging.verbose(f'Relaying post from {request.message.actorid}')
|
|
||||||
logging.debug(f'>> relay: {message}')
|
logging.debug(f'>> relay: {message}')
|
||||||
|
|
||||||
inboxes = misc.distill_inboxes(request.actor, request.message.objectid)
|
inboxes = distill_inboxes(request.actor, request.message.objectid)
|
||||||
|
|
||||||
if request.config.workers > 0:
|
for inbox in inboxes:
|
||||||
for inbox in inboxes:
|
request.app.push_message(inbox, message)
|
||||||
request.app.push_message(inbox, message)
|
|
||||||
|
|
||||||
else:
|
|
||||||
futures = [misc.request(inbox, data=message) for inbox in inboxes]
|
|
||||||
asyncio.ensure_future(asyncio.gather(*futures))
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_forward(request):
|
async def handle_forward(request):
|
||||||
if request.message.id in request.cache.objects:
|
if request.message.id in cache:
|
||||||
logging.verbose(f'already forwarded {request.message.id}')
|
logging.verbose(f'already forwarded {request.message.id}')
|
||||||
return
|
return
|
||||||
|
|
||||||
message = misc.Message.new_announce(
|
message = Message.new_announce(
|
||||||
host = request.config.host,
|
host = request.config.host,
|
||||||
object = request.message
|
object = request.message
|
||||||
)
|
)
|
||||||
|
|
||||||
request.cache.objects[request.message.id] = message.id
|
cache[request.message.id] = message.id
|
||||||
logging.verbose(f'Forwarding post from {request.actor.id}')
|
logging.debug(f'>> forward: {message}')
|
||||||
logging.debug(f'>> Relay {request.message}')
|
|
||||||
|
|
||||||
inboxes = misc.distill_inboxes(request.actor, request.message.objectid)
|
inboxes = distill_inboxes(request.actor, request.message.objectid)
|
||||||
|
|
||||||
if request.config.workers > 0:
|
for inbox in inboxes:
|
||||||
for inbox in inboxes:
|
request.app.push_message(inbox, message)
|
||||||
request.app.push_message(inbox, message)
|
|
||||||
|
|
||||||
else:
|
|
||||||
futures = [misc.request(inbox, data=message) for inbox in inboxes]
|
|
||||||
asyncio.ensure_future(asyncio.gather(*futures))
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_follow(request):
|
async def handle_follow(request):
|
||||||
nodeinfo = await misc.fetch_nodeinfo(request.actor.domain)
|
nodeinfo = await request.app.client.fetch_nodeinfo(request.actor.domain)
|
||||||
software = nodeinfo.swname if nodeinfo else None
|
software = nodeinfo.swname if nodeinfo else None
|
||||||
|
|
||||||
## reject if software used by actor is banned
|
## reject if software used by actor is banned
|
||||||
|
@ -67,9 +59,9 @@ async def handle_follow(request):
|
||||||
request.database.add_inbox(request.actor.shared_inbox, request.message.id, software)
|
request.database.add_inbox(request.actor.shared_inbox, request.message.id, software)
|
||||||
request.database.save()
|
request.database.save()
|
||||||
|
|
||||||
await misc.request(
|
await request.app.push_message(
|
||||||
request.actor.shared_inbox,
|
request.actor.shared_inbox,
|
||||||
misc.Message.new_response(
|
Message.new_response(
|
||||||
host = request.config.host,
|
host = request.config.host,
|
||||||
actor = request.actor.id,
|
actor = request.actor.id,
|
||||||
followid = request.message.id,
|
followid = request.message.id,
|
||||||
|
@ -80,9 +72,9 @@ async def handle_follow(request):
|
||||||
# Are Akkoma and Pleroma the only two that expect a follow back?
|
# Are Akkoma and Pleroma the only two that expect a follow back?
|
||||||
# Ignoring only Mastodon for now
|
# Ignoring only Mastodon for now
|
||||||
if software != 'mastodon':
|
if software != 'mastodon':
|
||||||
await misc.request(
|
await request.app.push_message(
|
||||||
request.actor.shared_inbox,
|
request.actor.shared_inbox,
|
||||||
misc.Message.new_follow(
|
Message.new_follow(
|
||||||
host = request.config.host,
|
host = request.config.host,
|
||||||
actor = request.actor.id
|
actor = request.actor.id
|
||||||
)
|
)
|
||||||
|
@ -99,14 +91,15 @@ async def handle_undo(request):
|
||||||
|
|
||||||
request.database.save()
|
request.database.save()
|
||||||
|
|
||||||
message = misc.Message.new_unfollow(
|
await request.app.push_message(
|
||||||
host = request.config.host,
|
request.actor.shared_inbox,
|
||||||
actor = request.actor.id,
|
Message.new_unfollow(
|
||||||
follow = request.message
|
host = request.config.host,
|
||||||
|
actor = request.actor.id,
|
||||||
|
follow = request.message
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await misc.request(request.actor.shared_inbox, message)
|
|
||||||
|
|
||||||
|
|
||||||
processors = {
|
processors = {
|
||||||
'Announce': handle_relay,
|
'Announce': handle_relay,
|
||||||
|
@ -123,7 +116,7 @@ async def run_processor(request):
|
||||||
return
|
return
|
||||||
|
|
||||||
if request.instance and not request.instance.get('software'):
|
if request.instance and not request.instance.get('software'):
|
||||||
nodeinfo = await misc.fetch_nodeinfo(request.instance['domain'])
|
nodeinfo = await request.app.client.fetch_nodeinfo(request.instance['domain'])
|
||||||
|
|
||||||
if nodeinfo:
|
if nodeinfo:
|
||||||
request.instance['software'] = nodeinfo.swname
|
request.instance['software'] = nodeinfo.swname
|
||||||
|
|
|
@ -102,7 +102,7 @@ async def inbox(request):
|
||||||
logging.verbose('Failed to parse inbox message')
|
logging.verbose('Failed to parse inbox message')
|
||||||
return Response.new_error(400, 'failed to parse message', 'json')
|
return Response.new_error(400, 'failed to parse message', 'json')
|
||||||
|
|
||||||
request['actor'] = await misc.request(request.signature.keyid)
|
request['actor'] = await request.app.client.get(request.signature.keyid, sign_headers=True)
|
||||||
|
|
||||||
## reject if actor is empty
|
## reject if actor is empty
|
||||||
if not request.actor:
|
if not request.actor:
|
||||||
|
|
Loading…
Reference in a new issue