from __future__ import (
annotations,
) # make all type hints be strings and skip evaluating them
import asyncio
import logging
import uuid
from datetime import datetime
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type, Any
import sys
from aio_pika import IncomingMessage, Message, connect_robust, ExchangeType
from aiormq import ChannelLockedResource
from async_timeout import timeout
from handler import Registry, SystemHandler, RmqMessageTypes
from messages import (
RpcMessage,
RpcError,
TraceStoreMessage,
PingControl,
ServiceStatus,
CoreStatus,
)
from mode import Service, ServiceT
from mode.utils.logging import CompositeLogger
from mode.utils.times import want_seconds
from mode.utils.types.trees import NodeT
from settings import (
RMQ_URL,
BINDING_KEY_FANOUT,
BINDING_KEY_TOPIC,
TIMEOUT,
)
from trace import TraceStore
from utils import setup_logging, JSONType
# TODO remove: raises the interpreter recursion limit for the whole process
# (CPython's default is 1000).
sys.setrecursionlimit(1500)

# avoid circular imports: imports needed only for type annotations go here.
if TYPE_CHECKING:
    pass
class MyService(Service):
    """Base class for agent and behaviours.

    Defines the async service framework on top of ``mode.Service``. Messages
    logged through ``self.log`` are rewritten by ``_format_log`` so that every
    line carries the owning identity and the service's beacon depth.
    """

    def __init__(
        self, identity, *, beacon: NodeT = None, loop: asyncio.AbstractEventLoop = None
    ) -> None:
        super().__init__(beacon=beacon, loop=loop)
        self.identity = identity
        self.log = CompositeLogger(self.logger, formatter=self._format_log)

    def _format_log(self, severity: int, msg: str, *args: Any, **kwargs: Any) -> str:
        """Prefix *msg* with ``[identity:^<depth dashes><shortlabel>]``."""
        depth_marker = "-" * (self.beacon.depth - 1)
        return f"[{self.identity}:^{depth_marker}{self.shortlabel}]: {msg}"
class Core(MyService):
    """Agent core: RabbitMQ connection, message routing and child behaviours.

    Every queue is automatically bound to the default exchange with a routing
    key which is the same as the queue name.
    All async tasks must only be started in ``on_start`` because only there the
    event loop is configured.
    """

    def __init__(
        self,
        *,
        identity=None,
        config=None,
        clock=None,
        channel_number: int = None,
        beacon: NodeT = None,
        loop: asyncio.AbstractEventLoop = None,
    ) -> None:
        """Create a core.

        Args:
            identity: Agent identity; also used as the direct queue name and
                consumer tag. A random UUID4 string is generated when omitted.
            config: Optional configuration dict. A non-dict value is logged as
                an error and replaced by ``{}``.
            clock: Stored as ``self.clock``; not read anywhere in this class.
            channel_number: Explicit AMQP channel number; needs to be managed
                when several cores run in one process.
            beacon: mode beacon node, forwarded to ``MyService``.
            loop: Event loop, forwarded to ``MyService``.
        """
        identity = identity or str(uuid.uuid4())
        super().__init__(identity=identity, beacon=beacon, loop=loop)
        self.config = {}
        if config is not None:
            if not isinstance(config, dict):
                self.log.error(
                    f"Configuration must be valid dictionary, got {config}. Resetting to {{}}."
                )
            else:
                self.log.info(f"Configuration: {config}.")
                self.config = config
        # AMQP resources: created in on_start(), closed in on_stop().
        self.connection = None
        self.channel = None
        self.channel_number = channel_number
        self.direct_queue = None
        self.fanout_queue = None  # declared in _configure_agent_queues()
        self.topic_exchange = None
        self.fanout_exchange = None
        # Behaviours are this service's mode children (same list object).
        self.behaviours = self._children
        self.traces = TraceStore(size=1000)
        self.peers = TraceStore(size=100)
        self.futures = dict()  # correlation_id -> pending RPC future (see call())
        self.handlers: Registry = Registry()
        self.clock = clock
        self.web = None  # set by class AsgiAgent
        self.ws = None

    async def __aenter__(self) -> Core:
        await super().__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: Type[BaseException] = None,
        exc_val: BaseException = None,
        exc_tb: TracebackType = None,
    ) -> Optional[bool]:
        # Forward the exception info so the base class sees the real exit cause.
        await super().__aexit__(exc_type, exc_val, exc_tb)
        return None

    async def on_first_start(self):
        """mode lifecycle hook; intentionally a no-op here."""
        ...

    async def on_start(self):
        """Connect to RabbitMQ, declare queues, start consuming, then run ``setup()``.

        If ``config["UPDATE_PEER_INTERVAL"]`` is set, a periodic peer update is
        also scheduled as a service future.

        Raises:
            ConnectionError: if the broker connection fails (logged, then re-raised).
            ChannelLockedResource: if the exclusive direct queue is already held,
                likely because another agent uses the same identity.
        """
        self.log.info("Starting agent.")
        try:
            self.connection = await connect_robust(url=RMQ_URL)
        except ConnectionError as e:
            self.log.exception(f"Check RabbitMQ: {e}")  # TODO: implement RMQ precheck
            # Nothing below can work without a connection; propagate instead of
            # failing later with AttributeError on ``None.channel``.
            raise
        # needs to be managed when several cores running in one process (else always 1)
        if self.channel_number:
            self.channel = await self.connection.channel(
                channel_number=self.channel_number
            )
        else:
            self.channel = await self.connection.channel()
        await self.channel.set_qos(prefetch_count=1)
        # NOTE(review): configure_exchanges() is not defined in this file. It is
        # expected to set self.fanout_exchange (used by _configure_agent_queues
        # and fanout_send) and self.topic_exchange (used by publish).
        await self.configure_exchanges()
        try:
            await self._configure_agent_queues()
        except ChannelLockedResource:
            self.log.error(f"Potential identity conflict: {self.identity}.")
            raise
        self.log.info(f"Start consuming: {self.direct_queue}, {self.fanout_queue}")
        await self.direct_queue.consume(
            consumer_tag=self.identity, callback=self.on_message
        )
        await self.fanout_queue.consume(callback=self.on_message)
        await self._update_peers()
        # TODO: refactor for better understanding and configuration
        interval = self.config.get("UPDATE_PEER_INTERVAL")
        if interval is not None:
            self.log.debug(f"Starting peer update with interval: {interval}")
            # noinspection PyAsyncCall
            self.add_future(self.periodic_update_peers(interval))  # service awaits future
        await self.setup()

    async def setup(self):
        """To be overwritten by user; awaited at the end of ``on_start``."""
        pass

    async def _configure_agent_queues(self):
        """Declare this agent's queues and bind the fanout queue.

        The direct queue is named after ``self.identity`` and is exclusive, so
        another connection cannot declare it. The fanout queue gets a
        broker-generated name (``name=""``) and is bound to
        ``self.fanout_exchange`` with ``BINDING_KEY_FANOUT``.
        """
        queue_name = self.identity
        self.direct_queue = await self.channel.declare_queue(
            name=queue_name, auto_delete=False, durable=False, exclusive=True
        )
        self.fanout_queue = await self.channel.declare_queue(
            name="", auto_delete=False, durable=False, exclusive=True
        )
        self.log.info(f"Queues declared: {self.direct_queue}, {self.fanout_queue}")
        await self.fanout_queue.bind(
            self.fanout_exchange, routing_key=BINDING_KEY_FANOUT
        )
        self.log.info(
            f"Binding: {self.fanout_queue} to {self.fanout_exchange}: BindingKey: {BINDING_KEY_FANOUT}"
        )

    async def on_started(self):
        """mode lifecycle hook; intentionally a no-op here."""
        ...

    async def dummy(self):
        # NOTE(review): not called anywhere in this file; presumably a placeholder.
        return True

    async def stop(self) -> None:
        """Stop the service.

        Mirrors ``mode.Service.stop`` except that child behaviours are stopped
        *before* ``on_stop()``, i.e. while the AMQP channel and connection that
        ``on_stop()`` closes are still open.
        """
        if not self._stopped.is_set():
            self.log.info(f"Stopping agent and behaviours: {self.list_behaviour()}...")
            self._stopped.set()
            await self._stop_children()  # tw: order reversed with regards to service.stop()
            await self.on_stop()
            self.log.debug("Shutting down...")
            if self.wait_for_shutdown:
                self.log.debug("Waiting for shutdown")
                await asyncio.wait_for(self._shutdown.wait(), self.shutdown_timeout)
            self.log.debug("Shutting down now")
            await self._stop_futures()
            await self._stop_exit_stacks()
            await self.on_shutdown()
            self.log.debug("-Stopped!")

    async def on_stop(self):
        """Run the user ``teardown()`` hook, then close the AMQP channel and connection.

        Behaviours are not stopped here; ``stop()`` does that before calling
        this method.
        """
        await self.teardown()
        # Close the channel before the connection that owns it. Either may be
        # None if on_start() failed before creating it.
        if self.channel is not None:
            await self.channel.close()
        if self.connection is not None:
            await self.connection.close()
        self.log.info(f"Agent stopped: {self.state}")

    async def teardown(self):
        """To be overwritten by user; awaited by ``on_stop`` before AMQP resources close."""
        pass

    async def on_shutdown(self):
        """Call ``set_shutdown()`` and log the final state."""
        self.set_shutdown()
        self.log.info(f"Agent shutdown: {self.state}")

    def has_behaviour(self, behaviour):
        """Return True if *behaviour* is one of this core's behaviours."""
        return behaviour in self.behaviours

    def list_behaviour(self):
        """Return the ``str()`` of every behaviour."""
        return [str(behav) for behav in self.behaviours]

    def get_behaviour(self, name: str) -> Optional[ServiceT]:
        """Return the behaviour whose ``str()`` ends with *name*, or None.

        If several match, a warning is logged and the first match is returned.
        """
        matches = [behav for behav in self.behaviours if str(behav).endswith(name)]
        if not matches:
            return None
        if len(matches) > 1:
            self.log.warning(
                f"{len(matches)} behaviours found for {name}. Name not unique!"
            )
        return matches[0]

    async def call(self, msg: str, target: str = None) -> str:
        """Send an RPC request and wait up to ``TIMEOUT`` seconds for the reply.

        Args:
            msg: JSON-serialized RPC message (re-parsed with
                ``RpcMessage.from_json`` only to build the timeout log).
            target: Routing key of the receiving agent; defaults to this agent
                (loopback).

        Returns:
            The result set on the pending future by whoever handles the reply,
            or an ``RpcError`` if no reply arrives within ``TIMEOUT``.
        """
        if target is None:
            target = self.identity  # loopback send
        correlation_id = str(uuid.uuid4())
        # Awaitable future that is resolved in the background while we wait here.
        future = self.loop.create_future()
        self.futures[correlation_id] = future
        try:
            await self.direct_send(msg, RmqMessageTypes.RPC.name, target, correlation_id)
            try:
                async with timeout(delay=TIMEOUT):
                    return await future
            except asyncio.TimeoutError:
                rpc_message = RpcMessage.from_json(msg)
                err_msg = f"{self}: TimeoutError after {TIMEOUT}s while waiting for RPC request: {rpc_message.c_type}: {correlation_id}"
                future.cancel()
                self.log.error(err_msg)
                return RpcError(error=err_msg)
        finally:
            # Never leave a stale entry behind (send failure, timeout or success);
            # the reply handler may already have removed it, hence the default.
            self.futures.pop(correlation_id, None)

    async def direct_send(
        self,
        msg: str,
        msg_type: RmqMessageTypes.name,
        target: str = None,
        correlation_id: str = None,
        headers: dict = None,
    ) -> None:
        """Send *msg* via the default exchange to the queue named *target*.

        *target* defaults to this agent's own identity (loopback).
        """
        if target is None:
            target = self.identity  # loopback send to itself
        await self.channel.default_exchange.publish(
            message=self._create_message(msg, msg_type, correlation_id, headers),
            routing_key=target,
            timeout=None,
        )
        self._add_trace_outgoing(correlation_id, headers, msg, msg_type, target, target)
        self.log.debug(f"Sent message: {msg}, routing_key: {target}, type: {msg_type}")

    async def fanout_send(
        self,
        msg: str,
        msg_type: RmqMessageTypes.name,
        correlation_id: str = None,
        headers: dict = None,
    ) -> None:
        """Send *msg* to the fanout exchange with ``BINDING_KEY_FANOUT``."""
        await self.fanout_exchange.publish(
            message=self._create_message(msg, msg_type, correlation_id, headers),
            routing_key=BINDING_KEY_FANOUT,
            timeout=None,
        )
        self._add_trace_outgoing(
            correlation_id, headers, msg, msg_type, "fanout", BINDING_KEY_FANOUT
        )
        self.log.debug(f"Sent fanout message: {msg}, routing_key: {BINDING_KEY_FANOUT}")

    async def publish(self, msg: str, routing_key: str, headers: dict = None) -> None:
        """Publish *msg* as a PUBSUB message to the topic exchange under *routing_key*."""
        await self.topic_exchange.publish(
            message=self._create_message(
                msg,
                msg_type=RmqMessageTypes.PUBSUB.name,
                correlation_id=None,
                headers=headers,
            ),
            routing_key=routing_key,
            timeout=None,
        )
        self._add_trace_outgoing(
            None, headers, msg, RmqMessageTypes.PUBSUB.name, "publish", routing_key
        )
        self.log.debug(f"Sent: {msg}, routing_key: {routing_key}")

    def _create_message(
        self,
        msg: str,
        msg_type: RmqMessageTypes.name,
        correlation_id: str = None,
        headers: dict = None,
    ) -> Message:
        """Build a JSON ``aio_pika.Message`` with ``app_id`` set to this agent's identity."""
        return Message(
            content_type="application/json",
            body=msg.encode(),
            timestamp=datetime.now(),
            type=msg_type,
            app_id=self.identity,
            # NOTE(review): RabbitMQ rejects publishes whose user_id differs from
            # the connection's authenticated user; "guest" only works when RMQ_URL
            # logs in as guest — confirm.
            user_id="guest",
            headers=headers,
            correlation_id=correlation_id,
        )

    def _add_trace_outgoing(
        self, correlation_id, headers, msg, msg_type, target, routing_key
    ):
        """Record an outgoing message in ``self.traces`` under category "outgoing"."""
        self.traces.append(
            TraceStoreMessage(
                body=msg,
                headers=headers,
                correlation_id=correlation_id,
                type=msg_type,
                target=target,
                routing_key=routing_key,
            ),
            category="outgoing",
        )

    async def on_message(self, message: IncomingMessage):
        """Handle incoming messages.

        Well defined types (CONTROL, RPC) are sent to their registered system
        handler; all others are enqueued to every behaviour's mailbox for user
        handling.
        """
        # If an exception escapes this block, aio_pika rejects the message; with
        # process()'s default requeue=False it is NOT returned to the queue.
        async with message.process():
            self.log.debug("Received (info/body):")
            self.log.debug(f" {message.info()}")
            self.log.debug(f" {message.body.decode()}")
            self.traces.append(TraceStoreMessage.from_msg(message), category="incoming")
            if message.type in (RmqMessageTypes.CONTROL.name, RmqMessageTypes.RPC.name):
                handler = self.handlers.get(handler=message.type)
                # A handler is either a SystemHandler subclass (instantiated per
                # message) or a callable awaited as handler(core, message).
                # issubclass() raises TypeError for non-classes, so check first.
                if isinstance(handler, type) and issubclass(handler, SystemHandler):
                    return await handler(core=self).handle(message)
                return await handler(self, message)
            for behaviour in self.behaviours:
                await behaviour.enqueue(message)
                self.log.debug(f"Message enqueued to: {behaviour} --> {message.body}")
                self.traces.append(
                    TraceStoreMessage.from_msg(message), category=str(behaviour)
                )

    async def _update_peers(self) -> None:
        """Broadcast a PingControl CONTROL message with a fresh correlation id."""
        msg = PingControl().serialize()
        correlation_id = str(uuid.uuid4())
        await self.fanout_send(
            msg=msg,
            msg_type=RmqMessageTypes.CONTROL.name,
            correlation_id=correlation_id,
        )

    async def periodic_update_peers(self, interval):
        """Send a periodic keepalive (PING) to all peers and publish the peer list.

        Runs when UPDATE_PEER_INTERVAL is set; after each ping, the latest peer
        responses are sent to the websocket via ``_publish_ws``.
        """
        _interval = want_seconds(interval)
        async for _ in self.itertimer(_interval):
            await self._update_peers()
            peers = await self.list_peers()
            msg = {"from": self.identity, "peers": peers}
            await self._publish_ws(msg)

    async def list_peers(self) -> list:  # TODO: make property out of method
        """List all peers which have responded to the latest PING.

        Returns:
            The matching ``CoreStatus`` entries sorted by name and serialized
            via ``CoreStatus.schema().dump(..., many=True)``; an empty list if
            no peer response has been recorded yet.
        """
        latest = self.peers.latest()
        # NOTE(review): assumes TraceStore.latest() returns a falsy value (e.g.
        # None) when empty, otherwise a tuple whose index 2 is the correlation
        # id, as the unpacking below implies — confirm against TraceStore.
        if not latest:
            return []
        corr_id = latest[2]
        peers = sorted(
            (status for _ts, status, _cid in self.peers.filter(category=corr_id)),
            key=lambda status: status.name,
        )
        return CoreStatus.schema().dump(peers, many=True)

    async def _publish_ws(self, msg: JSONType):
        """Send *msg* as JSON over ``self.web.ws`` if present; RuntimeError is logged."""
        if self.web and self.web.ws:
            self.log.debug(f"Publishing ws message: {msg}")
            try:
                await self.web.ws.send_json(msg)
            except RuntimeError as e:
                self.log.exception(e)

    def __repr__(self):
        return "{}".format(self.__class__.__name__)

    @property
    def status(self):
        """Snapshot of this core's state plus a ServiceStatus per behaviour."""
        behav_stati = [
            ServiceStatus(name=str(behav), state=behav.state) for behav in self.behaviours
        ]
        return CoreStatus(name=self.identity, state=self.state, behaviours=behav_stati)
if __name__ == "__main__":
    # Limit aio_pika and asyncio loggers to INFO, then enable DEBUG logging.
    for noisy_logger_name in ("aio_pika", "asyncio"):
        logging.getLogger(noisy_logger_name).setLevel(logging.INFO)
    setup_logging(logging.DEBUG)

    from mode import Worker

    core_config = {"UPDATE_PEER_INTERVAL": 1.0}
    app = Core(identity="core", config=core_config)
    Worker(app, loglevel="info").execute_from_commandline()