Skip to content

Extensions

Extensions are a mechanism to extend container and context behavior, similar to a plugin system.

Container Extensions

Lifespan

The Lifespan extension can be used to execute code when the container is entered and exited:

import asyncio
import contextlib
from collections.abc import AsyncIterator

from aioinject import Container
from aioinject.extensions import LifespanExtension


class MyLifespanExtension(LifespanExtension):
    """Lifespan extension that prints a message on container enter and exit."""

    @contextlib.asynccontextmanager
    async def lifespan(
        self,
        container: Container,  # noqa: ARG002
    ) -> AsyncIterator[None]:
        """Wrap the container's lifetime.

        Code before ``yield`` runs when the container is entered; the
        ``finally`` clause runs when it is exited. ``container`` is unused
        in this example.
        """
        print("Enter")
        try:
            yield None
        finally:
            # An exception raised inside the ``async with`` body is re-raised
            # at the ``yield``; ``finally`` guarantees the exit code still runs.
            print("Exit")


async def main() -> None:
    """Enter and leave a container that has ``MyLifespanExtension`` installed."""
    extensions = [MyLifespanExtension()]
    async with Container(extensions=extensions):
        # print("Enter") is executed.
        ...
        # print("Exit") is executed.


# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

OnInit

The OnInit extension is executed when the container's __init__ is called. It can be used, for example, to register dependencies in the container:

from datetime import datetime
from typing import NewType

from aioinject import Container, SyncContainer, Transient
from aioinject.extensions import OnInitExtension


# Distinct type over ``datetime``; the example registers and resolves its
# provider under this type.
Now = NewType("Now", datetime)


class TimeExtension(OnInitExtension):
    """On-init extension that registers a transient ``Now`` provider."""

    def on_init(self, container: Container | SyncContainer) -> None:
        """Register ``datetime.now`` as the transient provider for ``Now``."""
        now_provider = Transient(datetime.now, Now)
        container.register(now_provider)


# Passing the extension at construction is what triggers its ``on_init`` hook
# (it runs when the container's __init__ is called).
container = SyncContainer(extensions=[TimeExtension()])
with container.context() as ctx:
    # Prints the value resolved for ``Now``.
    print(ctx.resolve(Now))

OnResolve / OnResolveSync

The OnResolve extension is called whenever an individual dependency is provided within a context:

import logging
from typing import TypeVar

from aioinject import Context
from aioinject.context import ProviderRecord
from aioinject.extensions import OnResolveExtension


# Type variable tying a provider record's type parameter to the type of the
# instance it provided.
T = TypeVar("T")


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class MyExtension(OnResolveExtension):
    """On-resolve extension that logs the type of each provided dependency."""

    async def on_resolve(
        self,
        context: Context,  # noqa: ARG002
        provider: ProviderRecord[T],
        instance: T,  # noqa: ARG002
    ) -> None:
        """Log ``provider.info.type_`` at INFO level.

        ``context`` and ``instance`` are unused in this example.
        """
        logger.info("%s type was provided!", provider.info.type_)