11import asyncio
22import uuid
3- from typing import Any , AsyncGenerator , Generator
3+ from typing import Any , AsyncGenerator , Generator , Generic , TypeVar
44
55import pytest
66
@@ -313,7 +313,6 @@ def target(class_val: str = Depends(TeClass("tval"))) -> None:
313313
314314
315315def test_exception_generators () -> None :
316-
317316 errors_found = 0
318317
319318 def my_generator () -> Generator [int , None , None ]:
@@ -335,7 +334,6 @@ def target(_: int = Depends(my_generator)) -> None:
335334
336335@pytest .mark .anyio
337336async def test_async_exception_generators () -> None :
338-
339337 errors_found = 0
340338
341339 async def my_generator () -> AsyncGenerator [int , None ]:
@@ -357,7 +355,6 @@ def target(_: int = Depends(my_generator)) -> None:
357355
358356@pytest .mark .anyio
359357async def test_async_exception_generators_multiple () -> None :
360-
361358 errors_found = 0
362359
363360 async def my_generator () -> AsyncGenerator [int , None ]:
@@ -383,7 +380,6 @@ def target(
383380
384381@pytest .mark .anyio
385382async def test_async_exception_in_teardown () -> None :
386-
387383 errors_found = 0
388384
389385 async def my_generator () -> AsyncGenerator [int , None ]:
@@ -404,7 +400,6 @@ def target(_: int = Depends(my_generator)) -> None:
404400
405401@pytest .mark .anyio
406402async def test_async_propagation_disabled () -> None :
407-
408403 errors_found = 0
409404
410405 async def my_generator () -> AsyncGenerator [int , None ]:
@@ -428,7 +423,6 @@ def target(_: int = Depends(my_generator)) -> None:
428423
429424
430425def test_sync_propagation_disabled () -> None :
431-
432426 errors_found = 0
433427
434428 def my_generator () -> Generator [int , None , None ]:
@@ -447,3 +441,133 @@ def target(_: int = Depends(my_generator)) -> None:
447441 target (** (g .resolve_kwargs ()))
448442
449443 assert errors_found == 0
444+
445+
446+ def test_generic_classes () -> None :
447+ errors_found = 0
448+
449+ _T = TypeVar ("_T" )
450+
451+ class MyClass :
452+ pass
453+
454+ class MainClass (Generic [_T ]):
455+ def __init__ (self , val : _T = Depends ()) -> None :
456+ self .val = val
457+
458+ def test_func (a : MainClass [MyClass ] = Depends ()) -> MyClass :
459+ return a .val
460+
461+ with DependencyGraph (target = test_func ).sync_ctx (exception_propagation = False ) as g :
462+ value = test_func (** (g .resolve_kwargs ()))
463+
464+ assert errors_found == 0
465+ assert isinstance (value , MyClass )
466+
467+
468+ def test_generic_multiple () -> None :
469+ errors_found = 0
470+
471+ _T = TypeVar ("_T" )
472+ _V = TypeVar ("_V" )
473+
474+ class MyClass1 :
475+ pass
476+
477+ class MyClass2 :
478+ pass
479+
480+ class MainClass (Generic [_T , _V ]):
481+ def __init__ (self , t_val : _T = Depends (), v_val : _V = Depends ()) -> None :
482+ self .t_val = t_val
483+ self .v_val = v_val
484+
485+ def test_func (
486+ a : MainClass [MyClass1 , MyClass2 ] = Depends (),
487+ ) -> MainClass [MyClass1 , MyClass2 ]:
488+ return a
489+
490+ with DependencyGraph (target = test_func ).sync_ctx (exception_propagation = False ) as g :
491+ result = test_func (** (g .resolve_kwargs ()))
492+
493+ assert errors_found == 0
494+ assert isinstance (result .t_val , MyClass1 )
495+ assert isinstance (result .v_val , MyClass2 )
496+
497+
498+ def test_generic_unordered () -> None :
499+ errors_found = 0
500+
501+ _T = TypeVar ("_T" )
502+ _V = TypeVar ("_V" )
503+
504+ class MyClass1 :
505+ pass
506+
507+ class MyClass2 :
508+ pass
509+
510+ class MainClass (Generic [_T , _V ]):
511+ def __init__ (self , v_val : _V = Depends (), t_val : _T = Depends ()) -> None :
512+ self .t_val = t_val
513+ self .v_val = v_val
514+
515+ def test_func (
516+ a : MainClass [MyClass1 , MyClass2 ] = Depends (),
517+ ) -> MainClass [MyClass1 , MyClass2 ]:
518+ return a
519+
520+ with DependencyGraph (target = test_func ).sync_ctx (exception_propagation = False ) as g :
521+ result = test_func (** (g .resolve_kwargs ()))
522+
523+ assert errors_found == 0
524+ assert isinstance (result .t_val , MyClass1 )
525+ assert isinstance (result .v_val , MyClass2 )
526+
527+
528+ def test_generic_classes_nesting () -> None :
529+ errors_found = 0
530+
531+ _T = TypeVar ("_T" )
532+ _V = TypeVar ("_V" )
533+
534+ class DummyClass :
535+ pass
536+
537+ class DependantClass (Generic [_V ]):
538+ def __init__ (self , var : _V = Depends ()) -> None :
539+ self .var = var
540+
541+ class MainClass (Generic [_T ]):
542+ def __init__ (self , var : _T = Depends ()) -> None :
543+ self .var = var
544+
545+ def test_func (a : MainClass [DependantClass [DummyClass ]] = Depends ()) -> DummyClass :
546+ return a .var .var
547+
548+ with DependencyGraph (target = test_func ).sync_ctx (exception_propagation = False ) as g :
549+ value = test_func (** (g .resolve_kwargs ()))
550+
551+ assert errors_found == 0
552+ assert isinstance (value , DummyClass )
553+
554+
555+ def test_generic_class_based_dependencies () -> None :
556+ """Tests that if ParamInfo is used on the target, no error is raised."""
557+
558+ _T = TypeVar ("_T" )
559+
560+ class GenericClass (Generic [_T ]):
561+ def __init__ (self , class_val : _T = Depends ()):
562+ self .return_val = class_val
563+
564+ def func_dep () -> GenericClass [int ]:
565+ return GenericClass (123 )
566+
567+ def target (my_dep : GenericClass [int ] = Depends (func_dep )) -> int :
568+ return my_dep .return_val
569+
570+ with DependencyGraph (target = target ).sync_ctx () as g :
571+ result = target (** g .resolve_kwargs ())
572+
573+ assert result == 123
0 commit comments