# Source code for aiocontext.task_factory
"""Task factory."""
import asyncio
from functools import wraps
from .__about__ import __title__
from .errors import TaskFactoryError
# Name of the attribute, set on a context-aware task factory, under which the
# dict of registered contexts is stored. The package title is embedded in the
# attribute name.
_TASK_FACTORY_ATTR = '_{}_contexts'.format(__title__)
def get_task_factory_attr(loop):
    """Look up the context registry stored on the task factory of *loop*.

    The registry is the value kept under the ``_TASK_FACTORY_ATTR`` attribute
    of the loop's current task factory.

    Raises :exc:`TaskFactoryError` when the task factory carries no such
    attribute, i.e. the loop is not context-aware.
    """
    # Local sentinel so that any stored value (even a falsy one) is returned
    # as-is and only a truly absent attribute triggers the error.
    missing = object()
    contexts = getattr(loop.get_task_factory(), _TASK_FACTORY_ATTR, missing)
    if contexts is missing:
        raise TaskFactoryError("Task factory is not context-aware")
    return contexts
def _default_task_factory(loop, coro):
return asyncio.Task(coro, loop=loop)
def wrap_task_factory(loop):
    """Wrap the *loop* task factory to make it context-aware.

    Internally, this replaces the loop task factory by a wrapper function that
    manages context sharing between tasks. When a new task is spawned, the
    original task factory is called, then for each attached context, data is
    copied from the parent task to the child one. How copy is performed is
    specified in :meth:`Context.copy_func`.

    If *loop* uses a custom task factory, this function must be called after
    setting it::

        class CustomTask(asyncio.Task):
            pass

        def custom_task_factory(loop, coro):
            return CustomTask(coro, loop=loop)

        loop.set_task_factory(custom_task_factory)
        wrap_task_factory(loop)

    This function has no effect if the task factory is already context-aware.
    """
    task_factory = loop.get_task_factory()
    if hasattr(task_factory, _TASK_FACTORY_ATTR):
        return
    if task_factory is None:
        task_factory = _default_task_factory

    # ``asyncio.Task.current_task`` was deprecated in Python 3.7 and removed
    # in 3.9; prefer the module-level function and only fall back to the
    # classmethod on interpreters that predate it.
    try:
        current_task = asyncio.current_task
    except AttributeError:
        current_task = asyncio.Task.current_task

    @wraps(task_factory)
    def wrapper(loop, coro, **kwargs):
        # Extra keyword arguments some event loops hand to task factories
        # are passed through untouched to the wrapped factory.
        parent_task = current_task(loop=loop)
        child_task = task_factory(loop, coro, **kwargs)
        # Drop the last recorded frame from the task's source traceback
        # (only populated when one was captured).
        if child_task._source_traceback:
            del child_task._source_traceback[-1]
        for context in getattr(wrapper, _TASK_FACTORY_ATTR).values():
            # No parent task, or a parent without data for this context:
            # the child starts from an empty dict.
            parent_data = getattr(parent_task, context._data_attr, None)
            if parent_data is None:
                child_data = {}
            else:
                child_data = context.copy_func(parent_data)
            setattr(child_task, context._data_attr, child_data)
        return child_task

    # The registry of contexts lives on the wrapper itself; its presence is
    # what marks a task factory as context-aware.
    setattr(wrapper, _TASK_FACTORY_ATTR, {})
    loop.set_task_factory(wrapper)
def unwrap_task_factory(loop):
    """Restore the original task factory of *loop*.

    This function cancels the effect of :func:`wrap_task_factory`. After
    calling it, the loop task factory is no longer context-aware. Context
    registration is lost.

    If the loop had no task factory before being wrapped, it is reset to
    ``None`` so that the loop goes back to its built-in task creation.

    This function has no effect if the task factory is not context-aware.
    """
    task_factory = loop.get_task_factory()
    if not hasattr(task_factory, _TASK_FACTORY_ATTR):
        return
    original = task_factory.__wrapped__
    # wrap_task_factory() substitutes _default_task_factory when the loop had
    # no factory at all; in that case the true original is None, not the
    # stand-in function.
    if original is _default_task_factory:
        original = None
    loop.set_task_factory(original)