diff --git a/Grammar/python.gram b/Grammar/python.gram index 1169ddbb7b5122..6ab09de4fc149a 100644 --- a/Grammar/python.gram +++ b/Grammar/python.gram @@ -634,7 +634,8 @@ keyword_pattern[KeyPatternPair*]: type_alias[stmt_ty]: | "type" n=NAME t=[type_params] '=' b=expression { - CHECK_VERSION(stmt_ty, 12, "Type statement is", _PyAST_TypeAlias(n->v.Name.id, t, b, EXTRA)) } + CHECK_VERSION(stmt_ty, 12, "Type statement is", + _PyAST_TypeAlias(CHECK(expr_ty, _PyPegen_set_expr_context(p, n, Store)), t, b, EXTRA)) } # Type parameter declaration # -------------------------- diff --git a/Include/internal/pycore_ast.h b/Include/internal/pycore_ast.h index caba84650e1eb9..9f1cef0541508c 100644 --- a/Include/internal/pycore_ast.h +++ b/Include/internal/pycore_ast.h @@ -239,7 +239,7 @@ struct _stmt { } Assign; struct { - identifier name; + expr_ty name; asdl_typeparam_seq *typeparams; expr_ty value; } TypeAlias; @@ -704,9 +704,9 @@ stmt_ty _PyAST_Delete(asdl_expr_seq * targets, int lineno, int col_offset, int stmt_ty _PyAST_Assign(asdl_expr_seq * targets, expr_ty value, string type_comment, int lineno, int col_offset, int end_lineno, int end_col_offset, PyArena *arena); -stmt_ty _PyAST_TypeAlias(identifier name, asdl_typeparam_seq * typeparams, - expr_ty value, int lineno, int col_offset, int - end_lineno, int end_col_offset, PyArena *arena); +stmt_ty _PyAST_TypeAlias(expr_ty name, asdl_typeparam_seq * typeparams, expr_ty + value, int lineno, int col_offset, int end_lineno, int + end_col_offset, PyArena *arena); stmt_ty _PyAST_AugAssign(expr_ty target, operator_ty op, expr_ty value, int lineno, int col_offset, int end_lineno, int end_col_offset, PyArena *arena); diff --git a/Lib/ast.py b/Lib/ast.py index 34122f04267a2c..93afa7d8035de8 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -1070,7 +1070,8 @@ def visit_ParamSpec(self, node): self.write("**" + node.name) def visit_TypeAlias(self, node): - self.fill("type " + node.name) + self.fill("type ") + self.traverse(node.name) self._typeparams_helper(node.typeparams) self.write(" = ") self.traverse(node.value) diff --git a/Parser/Python.asdl b/Parser/Python.asdl index db0f12a5f46ebe..cfc41ef45b568b 100644 --- a/Parser/Python.asdl +++ b/Parser/Python.asdl @@ -25,7 +25,7 @@ module Python | Delete(expr* targets) | Assign(expr* targets, expr value, string? type_comment) - | TypeAlias(identifier name, typeparam* typeparams, expr value) + | TypeAlias(expr name, typeparam* typeparams, expr value) | AugAssign(expr target, operator op, expr value) -- 'simple' indicates that we annotate simple name without parens | AnnAssign(expr target, expr annotation, expr? value, int simple) diff --git a/Parser/parser.c b/Parser/parser.c index 091b12c11c094e..cf0a12f3eb6fb0 100644 --- a/Parser/parser.c +++ b/Parser/parser.c @@ -10563,7 +10563,7 @@ type_alias_rule(Parser *p) UNUSED(_end_lineno); // Only used by EXTRA macro int _end_col_offset = _token->end_col_offset; UNUSED(_end_col_offset); // Only used by EXTRA macro - _res = CHECK_VERSION ( stmt_ty , 12 , "Type statement is" , _PyAST_TypeAlias ( n -> v . Name . id , t , b , EXTRA ) ); + _res = CHECK_VERSION ( stmt_ty , 12 , "Type statement is" , _PyAST_TypeAlias ( CHECK ( expr_ty , _PyPegen_set_expr_context ( p , n , Store ) ) , t , b , EXTRA ) ); if (_res == NULL && PyErr_Occurred()) { p->error_indicator = 1; p->level--; diff --git a/Python/Python-ast.c b/Python/Python-ast.c index c30f3f851e37cb..7124d6c0c25eb8 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -1175,7 +1175,7 @@ init_types(struct ast_state *state) " | Return(expr? value)\n" " | Delete(expr* targets)\n" " | Assign(expr* targets, expr value, string? type_comment)\n" - " | TypeAlias(identifier name, typeparam* typeparams, expr value)\n" + " | TypeAlias(expr name, typeparam* typeparams, expr value)\n" " | AugAssign(expr target, operator op, expr value)\n" " | AnnAssign(expr target, expr annotation, expr? value, int simple)\n" " | For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)\n" @@ -1248,7 +1248,7 @@ init_types(struct ast_state *state) return 0; state->TypeAlias_type = make_type(state, "TypeAlias", state->stmt_type, TypeAlias_fields, 3, - "TypeAlias(identifier name, typeparam* typeparams, expr value)"); + "TypeAlias(expr name, typeparam* typeparams, expr value)"); if (!state->TypeAlias_type) return 0; state->AugAssign_type = make_type(state, "AugAssign", state->stmt_type, AugAssign_fields, 3, @@ -2192,8 +2192,8 @@ _PyAST_Assign(asdl_expr_seq * targets, expr_ty value, string type_comment, int } stmt_ty -_PyAST_TypeAlias(identifier name, asdl_typeparam_seq * typeparams, expr_ty - value, int lineno, int col_offset, int end_lineno, int +_PyAST_TypeAlias(expr_ty name, asdl_typeparam_seq * typeparams, expr_ty value, + int lineno, int col_offset, int end_lineno, int end_col_offset, PyArena *arena) { stmt_ty p; @@ -4047,7 +4047,7 @@ ast2obj_stmt(struct ast_state *state, void* _o) tp = (PyTypeObject *)state->TypeAlias_type; result = PyType_GenericNew(tp, NULL, NULL); if (!result) goto failed; - value = ast2obj_identifier(state, o->v.TypeAlias.name); + value = ast2obj_expr(state, o->v.TypeAlias.name); if (!value) goto failed; if (PyObject_SetAttr(result, state->name, value) == -1) goto failed; @@ -6846,7 +6846,7 @@ obj2ast_stmt(struct ast_state *state, PyObject* obj, stmt_ty* out, PyArena* return 1; } if (isinstance) { - identifier name; + expr_ty name; asdl_typeparam_seq* typeparams; expr_ty value; @@ -6862,7 +6862,7 @@ obj2ast_stmt(struct ast_state *state, PyObject* obj, stmt_ty* out, PyArena* if (_Py_EnterRecursiveCall(" while traversing 'TypeAlias' node")) { goto failed; } - res = obj2ast_identifier(state, tmp, &name, arena); + res = obj2ast_expr(state, tmp, &name, arena); _Py_LeaveRecursiveCall(); if (res != 0) goto failed; Py_CLEAR(tmp); diff --git a/Python/ast.c b/Python/ast.c index c3e202047f8705..cb694ab2e77d76 100644 --- a/Python/ast.c +++ b/Python/ast.c @@ -17,10 +17,12 @@ struct validator { static int validate_stmts(struct validator *, asdl_stmt_seq *); static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int); static int validate_patterns(struct validator *, asdl_pattern_seq *, int); +static int validate_typeparams(struct validator *, asdl_typeparam_seq *); static int _validate_nonempty_seq(asdl_seq *, const char *, const char *); static int validate_stmt(struct validator *, stmt_ty); static int validate_expr(struct validator *, expr_ty, expr_context_ty); static int validate_pattern(struct validator *, pattern_ty, int); +static int validate_typeparam(struct validator *, typeparam_ty); #define VALIDATE_POSITIONS(node) \ if (node->lineno > node->end_lineno) { \ @@ -672,6 +674,27 @@ validate_pattern(struct validator *state, pattern_ty p, int star_ok) return ret; } +static int +validate_typeparam(struct validator *state, typeparam_ty tp) +{ + VALIDATE_POSITIONS(tp); + int ret = -1; + switch (tp->kind) { + case TypeVar_kind: + ret = validate_name(tp->v.TypeVar.name) && + (!tp->v.TypeVar.bound || + validate_expr(state, tp->v.TypeVar.bound, Load)); + break; + case ParamSpec_kind: + ret = validate_name(tp->v.ParamSpec.name); + break; + case TypeVarTuple_kind: + ret = validate_name(tp->v.TypeVarTuple.name); + break; + } + return ret; +} + static int _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner) { @@ -709,6 +732,7 @@ validate_stmt(struct validator *state, stmt_ty stmt) switch (stmt->kind) { case FunctionDef_kind: ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") && + validate_typeparams(state, stmt->v.FunctionDef.typeparams) && validate_arguments(state, stmt->v.FunctionDef.args) && validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) && (!stmt->v.FunctionDef.returns || @@ -716,6 +740,7 @@ validate_stmt(struct validator *state, stmt_ty stmt) break; case ClassDef_kind: ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") && + validate_typeparams(state, stmt->v.ClassDef.typeparams) && validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) && validate_keywords(state, stmt->v.ClassDef.keywords) && validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0); @@ -747,7 +772,9 @@ validate_stmt(struct validator *state, stmt_ty stmt) validate_expr(state, stmt->v.AnnAssign.annotation, Load); break; case TypeAlias_kind: - ret = validate_expr(state, stmt->v.TypeAlias.value, Load); + ret = validate_expr(state, stmt->v.TypeAlias.name, Store) && + validate_typeparams(state, stmt->v.TypeAlias.typeparams) && + validate_expr(state, stmt->v.TypeAlias.value, Load); break; case For_kind: ret = validate_expr(state, stmt->v.For.target, Store) && @@ -896,6 +923,7 @@ validate_stmt(struct validator *state, stmt_ty stmt) break; case AsyncFunctionDef_kind: ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") && + validate_typeparams(state, stmt->v.AsyncFunctionDef.typeparams) && validate_arguments(state, stmt->v.AsyncFunctionDef.args) && validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) && (!stmt->v.AsyncFunctionDef.returns || @@ -968,6 +996,20 @@ validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ return 1; } +static int +validate_typeparams(struct validator *state, asdl_typeparam_seq *tps) +{ + Py_ssize_t i; + for (i = 0; i < asdl_seq_LEN(tps); i++) { + typeparam_ty tp = asdl_seq_GET(tps, i); + if (tp) { + if (!validate_typeparam(state, tp)) + return 0; + } + } + return 1; +} + /* See comments in symtable.c. */ #define COMPILER_STACK_FRAME_SCALE 3 diff --git a/Python/ast_opt.c b/Python/ast_opt.c index 53984be7655c14..a296f86cf867fe 100644 --- a/Python/ast_opt.c +++ b/Python/ast_opt.c @@ -643,6 +643,7 @@ static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeStat static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); +static int astfold_typeparam(typeparam_ty node_, PyArena *ctx_, _PyASTOptimizeState *state); #define CALL(FUNC, TYPE, ARG) \ if (!FUNC((ARG), ctx_, state)) \ @@ -881,6 +882,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) } switch (node_->kind) { case FunctionDef_kind: + CALL_SEQ(astfold_typeparam, typeparam, node_->v.FunctionDef.typeparams); CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args); CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body); CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list); @@ -889,6 +891,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) } break; case AsyncFunctionDef_kind: + CALL_SEQ(astfold_typeparam, typeparam, node_->v.AsyncFunctionDef.typeparams); CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args); CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body); CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list); @@ -897,6 +900,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) } break; case ClassDef_kind: + CALL_SEQ(astfold_typeparam, typeparam, node_->v.ClassDef.typeparams); CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases); CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords); CALL(astfold_body, asdl_seq, node_->v.ClassDef.body); @@ -924,6 +928,8 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value); break; case TypeAlias_kind: + CALL(astfold_expr, expr_ty, node_->v.TypeAlias.name); + CALL_SEQ(astfold_typeparam, typeparam, node_->v.TypeAlias.typeparams); CALL(astfold_expr, expr_ty, node_->v.TypeAlias.value); break; case For_kind: @@ -1078,6 +1084,21 @@ astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat return 1; } +static int +astfold_typeparam(typeparam_ty node_, PyArena *ctx_, _PyASTOptimizeState *state) +{ + switch (node_->kind) { + case TypeVar_kind: + CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound); + break; + case ParamSpec_kind: + break; + case TypeVarTuple_kind: + break; + } + return 1; +} + #undef CALL #undef CALL_OPT #undef CALL_SEQ