Skip to content

Commit

Permalink
Async functions return typing.Coroutine
Browse files Browse the repository at this point in the history
Summary:
Came across this because we were wrapping functions returning `AsyncGenerator` inside an `Awaitable`, which was throwing false positive errors later on when we tried to call `__aiter__` on it, for instance.

Turns out functions prefixed with `async` wrap the return type in a `Coroutine`, not an `Awaitable` (see: python/mypy#3569), and functions that are actually generators (contain a yield) just take the return annotation of `AsyncGenerator` at face value - otherwise, the function signature is understood as asynchronously returning a generator object just like any other async function (see: python/mypy#5070)

Reviewed By: dkgi

Differential Revision: D13864544

fbshipit-source-id: 0d201735252b77688a5491428cfb5818d000754b
  • Loading branch information
shannonzhu authored and facebook-github-bot committed Jan 30, 2019
1 parent f89aee0 commit 8e6a03b
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 35 deletions.
25 changes: 23 additions & 2 deletions analysis/annotatedCallable.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,36 @@ open Ast
open Statement


let is_generator { Define.body; _ } =
let module YieldVisit = Visit.Make(struct
type t = bool

let expression result expression =
match result, expression with
| true, _ -> true
| false, { Node.value = Expression.Yield _; _ } -> true
| false, _ -> false

let statement result statement =
match result, statement with
| true, _ -> true
| false, { Node.value = Statement.Yield _; _ } -> true
| false, { Node.value = Statement.YieldFrom _; _ } -> true
| false, _ -> false
end)
in
YieldVisit.visit false (Source.create body)


let return_annotation ~define:({ Define.return_annotation; async; _ } as define) ~resolution =
let annotation =
Option.value_map
return_annotation
~f:(Resolution.parse_annotation resolution)
~default:Type.Top
in
if async then
Type.awaitable annotation
if async && not (is_generator define) then
Type.coroutine [Type.Object; Type.Object; annotation]
else if Define.is_coroutine define then
begin
match annotation with
Expand Down
2 changes: 2 additions & 0 deletions analysis/annotatedCallable.mli
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ open Ast
open Statement


val is_generator: Define.t -> bool

val return_annotation: define: Define.t -> resolution: Resolution.t -> Type.t

val apply_decorators: define: Define.t -> resolution: Resolution.t -> Define.t
Expand Down
21 changes: 0 additions & 21 deletions analysis/annotatedDefine.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,6 @@ let define annotated =
annotated


let is_generator { Define.body; _ } =
let module YieldVisit = Visit.Make(struct
type t = bool

let expression result expression =
match result, expression with
| true, _ -> true
| false, { Node.value = Expression.Yield _; _ } -> true
| false, _ -> false

let statement result statement =
match result, statement with
| true, _ -> true
| false, { Node.value = Statement.Yield _; _ } -> true
| false, { Node.value = Statement.YieldFrom _; _ } -> true
| false, _ -> false
end)
in
YieldVisit.visit false (Source.create body)


let parameter_annotations { Define.parameters; _ } ~resolution =
let element index { Node.value = { Parameter.annotation; _ }; _ } =
let annotation =
Expand Down
2 changes: 0 additions & 2 deletions analysis/annotatedDefine.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ val create: Define.t -> t

val define: t -> Define.t

val is_generator: t -> bool

val parameter_annotations
: t
-> resolution: Resolution.t
Expand Down
9 changes: 6 additions & 3 deletions analysis/test/annotatedCallableTest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@ let test_return_annotation _ =
|> (fun define ->
Callable.return_annotation ~define ~resolution:(TypeCheck.resolution environment ()))
in
assert_equal ~cmp:Type.equal expected return_annotation
assert_equal ~printer:Type.show ~cmp:Type.equal expected return_annotation
in
assert_return_annotation (Some (Type.expression Type.integer)) false Type.integer;
assert_return_annotation (Some (Type.expression Type.integer)) true (Type.awaitable Type.integer)
assert_return_annotation
(Some (Type.expression Type.integer))
true
(Type.coroutine [Type.Object; Type.Object; Type.integer])


let test_apply_decorators _ =
Expand Down Expand Up @@ -149,7 +152,7 @@ let test_create _ =
assert_callable "def foo() -> int: ..." ~expected:"typing.Callable('foo')[[], int]";
assert_callable
"async def foo() -> int: ..."
~expected:"typing.Callable('foo')[[], typing.Awaitable[int]]";
~expected:"typing.Callable('foo')[[], typing.Coroutine[typing.Any, typing.Any, int]]";

assert_callable
"def foo(a, b) -> str: ..."
Expand Down
35 changes: 34 additions & 1 deletion analysis/test/integration/asyncTest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,40 @@ let test_check_async _ =
async def read(file: typing.AsyncIterable[str]) -> typing.List[str]:
return [data async for data in file]
|}
[]
[];

assert_type_errors
{|
async def foo() -> typing.AsyncGenerator[bool, None]:
# not a generator; this gets wrapped in a coroutine
...

reveal_type(foo())
def bar() -> None:
async for x in foo():
pass
|}
[
"Revealed type [-1]: Revealed type for `foo.(...)` is " ^
"`typing.Coroutine[typing.Any, typing.Any, typing.AsyncGenerator[bool, None]]`.";
"Incompatible awaitable type [12]: Expected an awaitable but got `unknown`.";
"Undefined attribute [16]: `typing.Coroutine[typing.Any, typing.Any, typing.Any]` " ^
"has no attribute `__aiter__`.";
];

assert_type_errors
{|
async def foo() -> typing.AsyncGenerator[bool, None]:
yield

reveal_type(foo())
def bar() -> None:
async for x in foo():
pass
|}
[
"Revealed type [-1]: Revealed type for `foo.(...)` is `typing.AsyncGenerator[bool, None]`.";
]


let () =
Expand Down
14 changes: 14 additions & 0 deletions analysis/type.ml
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,13 @@ let awaitable parameter =
}


let coroutine parameters =
Parametric {
name = "typing.Coroutine";
parameters;
}


let bool =
Primitive "bool"

Expand Down Expand Up @@ -1730,6 +1737,13 @@ let awaitable_value = function
Top


let coroutine_value = function
| Parametric { name = "typing.Coroutine"; parameters = [_; _; parameter] } ->
parameter
| _ ->
Top


let parameters = function
| Parametric { parameters; _ } -> parameters
| _ -> []
Expand Down
2 changes: 2 additions & 0 deletions analysis/type.mli
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ val parametric: string -> t list -> t
val variable: ?constraints: constraints -> ?variance: variance -> string -> t

val awaitable: t -> t
val coroutine: t list -> t
val bool: t
val bytes: t
val callable
Expand Down Expand Up @@ -258,6 +259,7 @@ val mismatch_with_any: t -> t -> bool
val optional_value: t -> t
val async_generator_value: t -> t
val awaitable_value: t -> t
val coroutine_value: t -> t

val parameters: t -> t list
val single_parameter: t -> t
Expand Down
5 changes: 2 additions & 3 deletions analysis/typeCheck.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2368,7 +2368,7 @@ module State = struct
Annotated.Callable.return_annotation ~define:define_without_location ~resolution
in
if async then
Type.awaitable_value annotation
Type.coroutine_value annotation
else
annotation
in
Expand All @@ -2384,8 +2384,7 @@ module State = struct
not (Define.is_abstract_method define_without_location) &&
not (Define.is_overloaded_method define_without_location) &&
not (Type.is_none actual &&
(Annotated.Define.create define_without_location
|> Annotated.Define.is_generator)) &&
(Annotated.Callable.is_generator define_without_location)) &&
not (Type.is_none actual && Type.is_noreturn return_annotation) then
let error =
Error.create
Expand Down
6 changes: 3 additions & 3 deletions test/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,7 @@ let typeshed_stubs ?(include_helper_builtins = true) () =
class AsyncIterable(Protocol[_T_co]):
def __aiter__(self) -> AsyncIterator[_T_co]: ...
class AsyncIterator(AsyncIterable[_T_co],
Protocol[_T_co]):
class AsyncIterator(AsyncIterable[_T_co], Protocol[_T_co]):
def __anext__(self) -> Awaitable[_T_co]: ...
def __aiter__(self) -> AsyncIterator[_T_co]: ...
Expand Down Expand Up @@ -671,7 +670,8 @@ let typeshed_stubs ?(include_helper_builtins = true) () =
@abstractmethod
def __delitem__(self, v: _KT) -> None: ...
class Awaitable(Protocol[_T_co]): pass
class Awaitable(Protocol[_T_co]):
def __await__(self) -> Generator[Any, None, _T_co]: ...
class Coroutine(Awaitable[_V_co], Generic[_T_co, _T_contra, _V_co]): pass
class AsyncGenerator(AsyncIterator[_T_co], Generic[_T_co, _T_contra]):
Expand Down

0 comments on commit 8e6a03b

Please sign in to comment.