From 7e7317b3f6120f79a1d3cf2f3957658cdf94b676 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Fri, 12 Feb 2021 14:51:39 -0500 Subject: [PATCH] expand use of egal for testing type equality (#39604) fixes #39565 --- src/builtins.c | 56 ++++++++++++++++++++++++-------------------- src/julia_internal.h | 1 + src/subtype.c | 16 ++++--------- test/subtype.jl | 7 ++++++ 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/src/builtins.c b/src/builtins.c index 9844545c8a24db..43c9d02dd1e6c8 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -126,13 +126,27 @@ static int NOINLINE compare_fields(jl_value_t *a, jl_value_t *b, jl_datatype_t * return 1; } -static int egal_types(jl_value_t *a, jl_value_t *b, jl_typeenv_t *env) JL_NOTSAFEPOINT +static int egal_types(jl_value_t *a, jl_value_t *b, jl_typeenv_t *env, int tvar_names) JL_NOTSAFEPOINT { if (a == b) return 1; jl_datatype_t *dt = (jl_datatype_t*)jl_typeof(a); if (dt != (jl_datatype_t*)jl_typeof(b)) return 0; + if (dt == jl_datatype_type) { + jl_datatype_t *dta = (jl_datatype_t*)a; + jl_datatype_t *dtb = (jl_datatype_t*)b; + if (dta->name != dtb->name) + return 0; + size_t i, l = jl_nparams(dta); + if (jl_nparams(dtb) != l) + return 0; + for (i = 0; i < l; i++) { + if (!egal_types(jl_tparam(dta, i), jl_tparam(dtb, i), env, tvar_names)) + return 0; + } + return 1; + } if (dt == jl_tvar_type) { jl_typeenv_t *pe = env; while (pe != NULL) { @@ -142,49 +156,39 @@ static int egal_types(jl_value_t *a, jl_value_t *b, jl_typeenv_t *env) JL_NOTSAF } return 0; } - if (dt == jl_uniontype_type) { - return egal_types(((jl_uniontype_t*)a)->a, ((jl_uniontype_t*)b)->a, env) && - egal_types(((jl_uniontype_t*)a)->b, ((jl_uniontype_t*)b)->b, env); - } if (dt == jl_unionall_type) { jl_unionall_t *ua = (jl_unionall_t*)a; jl_unionall_t *ub = (jl_unionall_t*)b; - if (ua->var->name != ub->var->name) + if (tvar_names && ua->var->name != ub->var->name) return 0; - if (!(egal_types(ua->var->lb, ub->var->lb, env) && egal_types(ua->var->ub, ub->var->ub, env))) + if (!(egal_types(ua->var->lb, ub->var->lb, env, tvar_names) && egal_types(ua->var->ub, ub->var->ub, env, tvar_names))) return 0; jl_typeenv_t e = { ua->var, (jl_value_t*)ub->var, env }; - return egal_types(ua->body, ub->body, &e); + return egal_types(ua->body, ub->body, &e, tvar_names); } - if (dt == jl_datatype_type) { - jl_datatype_t *dta = (jl_datatype_t*)a; - jl_datatype_t *dtb = (jl_datatype_t*)b; - if (dta->name != dtb->name) - return 0; - size_t i, l = jl_nparams(dta); - if (jl_nparams(dtb) != l) - return 0; - for (i = 0; i < l; i++) { - if (!egal_types(jl_tparam(dta, i), jl_tparam(dtb, i), env)) - return 0; - } - return 1; + if (dt == jl_uniontype_type) { + return egal_types(((jl_uniontype_t*)a)->a, ((jl_uniontype_t*)b)->a, env, tvar_names) && + egal_types(((jl_uniontype_t*)a)->b, ((jl_uniontype_t*)b)->b, env, tvar_names); } - if (dt == jl_vararg_type) - { + if (dt == jl_vararg_type) { jl_vararg_t *vma = (jl_vararg_t*)a; jl_vararg_t *vmb = (jl_vararg_t*)b; jl_value_t *vmaT = vma->T ? vma->T : (jl_value_t*)jl_any_type; jl_value_t *vmbT = vmb->T ? vmb->T : (jl_value_t*)jl_any_type; - if (!egal_types(vmaT, vmbT, env)) + if (!egal_types(vmaT, vmbT, env, tvar_names)) return 0; if (vma->N && vmb->N) - return egal_types(vma->N, vmb->N, env); + return egal_types(vma->N, vmb->N, env, tvar_names); return !vma->N && !vmb->N; } return jl_egal(a, b); } +JL_DLLEXPORT int jl_types_egal(jl_value_t *a, jl_value_t *b) +{ + return egal_types(a, b, NULL, 0); +} + JL_DLLEXPORT int jl_egal(jl_value_t *a JL_MAYBE_UNROOTED, jl_value_t *b JL_MAYBE_UNROOTED) JL_NOTSAFEPOINT { // warning: a,b may NOT have been gc-rooted by the caller @@ -219,7 +223,7 @@ JL_DLLEXPORT int jl_egal(jl_value_t *a JL_MAYBE_UNROOTED, jl_value_t *b JL_MAYBE if (nf == 0 || !dt->layout->haspadding) return bits_equal(a, b, sz); if (dt == jl_unionall_type) - return egal_types(a, b, NULL); + return egal_types(a, b, NULL, 1); return compare_fields(a, b, dt); } diff --git a/src/julia_internal.h b/src/julia_internal.h index 8c3106557fb996..487cdbfbceb74d 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -473,6 +473,7 @@ jl_svec_t *jl_outer_unionall_vars(jl_value_t *u); jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t **penv, int *issubty); jl_value_t *jl_type_intersection_env(jl_value_t *a, jl_value_t *b, jl_svec_t **penv); int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv); +JL_DLLEXPORT int jl_types_egal(jl_value_t *a, jl_value_t *b); // specificity comparison assuming !(a <: b) and !(b <: a) JL_DLLEXPORT int jl_type_morespecific_no_subtype(jl_value_t *a, jl_value_t *b); jl_value_t *jl_instantiate_type_with(jl_value_t *t, jl_value_t **env, size_t n); diff --git a/src/subtype.c b/src/subtype.c index f98ecd3737458f..928727a5cd59c5 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -1811,7 +1811,7 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env, if (x == y || (jl_typeof(x) == jl_typeof(y) && (jl_is_unionall(y) || jl_is_uniontype(y)) && - jl_egal(x, y))) { + jl_types_egal(x, y))) { if (envsz != 0) { // quickly copy env from x jl_unionall_t *ua = (jl_unionall_t*)x; int i; @@ -1877,7 +1877,9 @@ JL_DLLEXPORT int jl_subtype(jl_value_t *x, jl_value_t *y) JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b) { - if (obviously_egal(a, b)) + if (a == b) + return 1; + if (jl_typeof(a) == jl_typeof(b) && jl_types_egal(a, b)) return 1; if (obviously_unequal(a, b)) return 0; @@ -1896,11 +1898,6 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b) if (b == (jl_value_t*)jl_any_type || a == jl_bottom_type) { subtype_ab = 1; } - else if (jl_typeof(a) == jl_typeof(b) && - (jl_is_unionall(b) || jl_is_uniontype(b)) && - jl_egal(a, b)) { - subtype_ab = 1; - } else if (jl_obvious_subtype(a, b, &subtype_ab)) { #ifdef NDEBUG if (subtype_ab == 0) @@ -1915,11 +1912,6 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b) if (a == (jl_value_t*)jl_any_type || b == jl_bottom_type) { subtype_ba = 1; } - else if (jl_typeof(b) == jl_typeof(a) && - (jl_is_unionall(a) || jl_is_uniontype(a)) && - jl_egal(b, a)) { - subtype_ba = 1; - } else if (jl_obvious_subtype(b, a, &subtype_ba)) { #ifdef NDEBUG if (subtype_ba == 0) diff --git a/test/subtype.jl b/test/subtype.jl index efd2f79082f7b3..530145da393875 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -1872,3 +1872,10 @@ g39218(a, b) = (@nospecialize; if a isa AB39218 && b isa AB39218; f39218(a, b); # issue #39521 @test Tuple{Type{Tuple{A}} where A, DataType, DataType} <: Tuple{Vararg{B}} where B @test Tuple{DataType, Type{Tuple{A}} where A, DataType} <: Tuple{Vararg{B}} where B + +let A = Tuple{Type{<:Union{Number, T}}, Ref{T}} where T, + B = Tuple{Type{<:Union{Number, T}}, Ref{T}} where T + # TODO: these are caught by the egal check, but the core algorithm gets them wrong + @test A == B + @test A <: B +end