Skip to content

Commit

Permalink
Fix return span
Browse files Browse the repository at this point in the history
  • Loading branch information
mustafaquraish committed Oct 30, 2024
1 parent b05066a commit 58d8bd7
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 33 deletions.
55 changes: 34 additions & 21 deletions bootstrap/stage0.c
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ typedef struct compiler_ast_nodes_MatchCase compiler_ast_nodes_MatchCase;
typedef struct compiler_ast_nodes_Match compiler_ast_nodes_Match;
typedef struct compiler_ast_nodes_Specialization compiler_ast_nodes_Specialization;
typedef struct compiler_ast_nodes_ArrayLiteral compiler_ast_nodes_ArrayLiteral;
typedef struct compiler_ast_nodes_Return compiler_ast_nodes_Return;
typedef union compiler_ast_nodes_ASTUnion compiler_ast_nodes_ASTUnion;
typedef struct compiler_ast_nodes_AST compiler_ast_nodes_AST;
typedef struct compiler_attributes_Attribute compiler_attributes_Attribute;
Expand Down Expand Up @@ -1263,6 +1264,11 @@ struct compiler_ast_nodes_ArrayLiteral {
std_vector_Vector__13 *elements;
};

struct compiler_ast_nodes_Return {
compiler_ast_nodes_AST *expr;
std_span_Span return_span;
};

union compiler_ast_nodes_ASTUnion {
compiler_ast_nodes_Assertion assertion;
compiler_ast_nodes_Binary binary;
Expand All @@ -1288,6 +1294,7 @@ union compiler_ast_nodes_ASTUnion {
compiler_ast_nodes_Specialization spec;
compiler_ast_nodes_ArrayLiteral array_literal;
compiler_ast_nodes_AST *child;
compiler_ast_nodes_Return ret;
};

struct compiler_ast_nodes_AST {
Expand Down Expand Up @@ -3730,11 +3737,13 @@ void compiler_passes_mark_dead_code_MarkDeadCode_mark(compiler_passes_mark_dead_
case compiler_ast_nodes_ASTType_CharLiteral:
case compiler_ast_nodes_ASTType_Null: {
} break;
case compiler_ast_nodes_ASTType_Return:
case compiler_ast_nodes_ASTType_Yield:
case compiler_ast_nodes_ASTType_Defer: {
compiler_passes_mark_dead_code_MarkDeadCode_mark(this, node->u.child);
} break;
case compiler_ast_nodes_ASTType_Return: {
compiler_passes_mark_dead_code_MarkDeadCode_mark(this, node->u.ret.expr);
} break;
case compiler_ast_nodes_ASTType_UnaryOp: {
compiler_passes_mark_dead_code_MarkDeadCode_mark(this, node->u.unary.expr);
} break;
Expand Down Expand Up @@ -5710,7 +5719,7 @@ void compiler_passes_typechecker_TypeChecker_check_match(compiler_passes_typeche
compiler_ast_nodes_AST *expr = match_stmt->expr;
compiler_types_Type *expr_type = compiler_passes_typechecker_TypeChecker_check_expression(this, expr, NULL);
if (!((bool)expr_type)) {
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(node->span, "Match statement must have a valid expression"));
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(match_stmt->match_span, "Match statement must have a valid expression"));
return;
}
switch (expr_type->base) {
Expand Down Expand Up @@ -5870,21 +5879,22 @@ void compiler_passes_typechecker_TypeChecker_check_statement(compiler_passes_typ
}
compiler_types_Type *expected = cur_func->return_type;
compiler_types_Type *res = NULL;
compiler_ast_nodes_AST *child = node->u.child;
compiler_ast_nodes_AST *child = node->u.ret.expr;
std_span_Span ret_span = node->u.ret.return_span;
if (((bool)child)) {
res=compiler_passes_typechecker_TypeChecker_check_expression(this, child, expected);
}
if (((bool)child) && child->returns) {
} else if (expected->base==compiler_types_BaseType_Void) {
if (((bool)node->u.child)) {
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(node->span, "Cannot return a value from a void function"));
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(ret_span, "Cannot return a value from a void function"));
}
} else if (((bool)child)) {
if (((bool)res) && !compiler_types_Type_eq(res, expected, false)) {
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(node->span, std_format("Return type %s does not match function return type %s", compiler_types_Type_str(res), compiler_types_Type_str(expected))));
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(ret_span, std_format("Return type %s does not match function return type %s", compiler_types_Type_str(res), compiler_types_Type_str(expected))));
}
} else {
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(node->span, "Expected a return value for non-void function"));
compiler_passes_typechecker_TypeChecker_error(this, compiler_errors_Error_new(ret_span, "Expected a return value for non-void function"));
}
node->returns=true;
} break;
Expand Down Expand Up @@ -6561,8 +6571,8 @@ void compiler_passes_typechecker_TypeChecker_try_resolve_typedefs_in_namespace(c
continue;
}
compiler_ast_scopes_Symbol *sym = compiler_ast_scopes_Scope_lookup_recursive(compiler_passes_generic_pass_GenericPass_scope(this->o), it->key);
if(!(((bool)sym))) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/typechecker.oc:2419:16: Assertion failed: `sym?`", "Should have added the symbol into scope already"); }
if(!(sym->type==compiler_ast_scopes_SymbolType_TypeDef)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/typechecker.oc:2423:16: Assertion failed: `sym.type == TypeDef`", NULL); }
if(!(((bool)sym))) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/typechecker.oc:2420:16: Assertion failed: `sym?`", "Should have added the symbol into scope already"); }
if(!(sym->type==compiler_ast_scopes_SymbolType_TypeDef)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/typechecker.oc:2424:16: Assertion failed: `sym.type == TypeDef`", NULL); }
compiler_types_Type *res = compiler_passes_typechecker_TypeChecker_resolve_type(this, it->value, false, !pre_import, true);
if (!((bool)res)) {
continue;
Expand Down Expand Up @@ -7575,11 +7585,12 @@ void compiler_passes_code_generator_CodeGenerator_gen_statement(compiler_passes_
}
compiler_passes_code_generator_CodeGenerator_gen_defers_upto(this, upto);
compiler_passes_code_generator_CodeGenerator_gen_indent(this);
if (((bool)node->u.child)) {
if (!node->u.child->returns) {
compiler_ast_nodes_AST *expr = node->u.ret.expr;
if (((bool)expr)) {
if (!expr->returns) {
std_buffer_Buffer_write_str(&this->out, "return ");
}
compiler_passes_code_generator_CodeGenerator_gen_expression(this, node->u.child, true);
compiler_passes_code_generator_CodeGenerator_gen_expression(this, expr, true);
std_buffer_Buffer_write_str(&this->out, ";\n");
} else {
std_buffer_Buffer_write_str(&this->out, "return;\n");
Expand Down Expand Up @@ -7833,7 +7844,7 @@ char *compiler_passes_code_generator_CodeGenerator_helper_gen_type(compiler_pass
}

char *compiler_passes_code_generator_CodeGenerator_get_type_name_string(compiler_passes_code_generator_CodeGenerator *this, compiler_types_Type *type, char *name, bool is_func_def) {
if(!(type != NULL)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/code_generator.oc:1081:12: Assertion failed: `type != null`", NULL); }
if(!(type != NULL)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/code_generator.oc:1082:12: Assertion failed: `type != null`", NULL); }
char *final = compiler_passes_code_generator_CodeGenerator_helper_gen_type(this, type, type, strdup(name), is_func_def);
str_strip_trailing_whitespace(final);
return final;
Expand Down Expand Up @@ -7887,7 +7898,7 @@ void compiler_passes_code_generator_CodeGenerator_gen_functions(compiler_passes_
compiler_ast_scopes_TemplateInstance *instance = std_vector_Iterator__3_cur(&__iter);
{
compiler_ast_scopes_Symbol *sym = instance->resolved;
if(!(sym->type==compiler_ast_scopes_SymbolType_Function)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/code_generator.oc:1127:24: Assertion failed: `sym.type == Function`", NULL); }
if(!(sym->type==compiler_ast_scopes_SymbolType_Function)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/code_generator.oc:1128:24: Assertion failed: `sym.type == Function`", NULL); }
compiler_ast_nodes_Function *func = sym->u.func;
compiler_passes_code_generator_CodeGenerator_gen_function(this, func);
}
Expand Down Expand Up @@ -7924,7 +7935,7 @@ void compiler_passes_code_generator_CodeGenerator_gen_function_decls(compiler_pa
compiler_ast_scopes_TemplateInstance *instance = std_vector_Iterator__3_cur(&__iter);
{
compiler_ast_scopes_Symbol *sym = instance->resolved;
if(!(sym->type==compiler_ast_scopes_SymbolType_Function)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/code_generator.oc:1154:24: Assertion failed: `sym.type == Function`", NULL); }
if(!(sym->type==compiler_ast_scopes_SymbolType_Function)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/passes/code_generator.oc:1155:24: Assertion failed: `sym.type == Function`", NULL); }
compiler_ast_nodes_Function *func = sym->u.func;
if (func->is_dead) {
continue;
Expand Down Expand Up @@ -9429,13 +9440,14 @@ compiler_ast_nodes_AST *compiler_parser_Parser_parse_statement(compiler_parser_P
node=compiler_parser_Parser_parse_block(this);
} break;
case compiler_tokens_TokenType_Return: {
compiler_parser_Parser_consume(this, compiler_tokens_TokenType_Return);
compiler_tokens_Token *tok = compiler_parser_Parser_consume(this, compiler_tokens_TokenType_Return);
compiler_ast_nodes_AST *expr = ((compiler_ast_nodes_AST *)NULL);
if (!compiler_parser_Parser_is_end_of_statement(this)) {
expr=compiler_parser_Parser_parse_expression(this, compiler_tokens_TokenType_Newline);
}
node=compiler_ast_nodes_AST_new(compiler_ast_nodes_ASTType_Return, std_span_Span_join(start_span, compiler_parser_Parser_token(this)->span));
node->u.child=expr;
node->u.ret.expr=expr;
node->u.ret.return_span=tok->span;
compiler_parser_Parser_consume_end_of_statement(this);
} break;
case compiler_tokens_TokenType_Yield: {
Expand Down Expand Up @@ -9721,7 +9733,8 @@ compiler_ast_nodes_Function *compiler_parser_Parser_parse_function(compiler_pars
compiler_ast_nodes_AST *stmt = compiler_parser_Parser_parse_expression(this, compiler_tokens_TokenType_Newline);
if (returns) {
compiler_ast_nodes_AST *ret = compiler_ast_nodes_AST_new(compiler_ast_nodes_ASTType_Return, stmt->span);
ret->u.child=stmt;
ret->u.ret.expr=stmt;
ret->u.ret.return_span=arrow->span;
stmt=ret;
}
compiler_ast_nodes_AST *body = compiler_ast_nodes_AST_new(compiler_ast_nodes_ASTType_Block, stmt->span);
Expand Down Expand Up @@ -9753,7 +9766,7 @@ void compiler_parser_Parser_parse_extern_into_symbol(compiler_parser_Parser *thi
}

void compiler_parser_Parser_get_extern_from_attr(compiler_parser_Parser *this, compiler_ast_scopes_Symbol *sym, compiler_attributes_Attribute *attr) {
if(!(attr->type==compiler_attributes_AttributeType_Extern)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/parser.oc:1685:12: Assertion failed: `attr.type == Extern`", NULL); }
if(!(attr->type==compiler_attributes_AttributeType_Extern)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/parser.oc:1687:12: Assertion failed: `attr.type == Extern`", NULL); }
sym->is_extern=true;
if (attr->args->size > 0) {
sym->extern_name=std_vector_Vector__7_at(attr->args, 0);
Expand Down Expand Up @@ -10361,8 +10374,8 @@ bool compiler_parser_Parser_load_import_path(compiler_parser_Parser *this, compi
switch (path->type) {
case compiler_ast_nodes_ImportType_GlobalNamespace: {
std_vector_Vector__5 *parts = path->parts;
if(!(parts->size > 0)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/parser.oc:2328:20: Assertion failed: `parts.size > 0`", "Expected at least one part in import path"); }
if(!(std_vector_Vector__5_at(parts, 0)->type==compiler_ast_nodes_ImportPartType_Single)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/parser.oc:2329:20: Assertion failed: `parts.at(0).type == Single`", "Expected first part to be a single import"); }
if(!(parts->size > 0)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/parser.oc:2330:20: Assertion failed: `parts.size > 0`", "Expected at least one part in import path"); }
if(!(std_vector_Vector__5_at(parts, 0)->type==compiler_ast_nodes_ImportPartType_Single)) { ae_assert_fail("/Users/mustafa/ocen-lang/ocen/compiler/parser.oc:2331:20: Assertion failed: `parts.at(0).type == Single`", "Expected first part to be a single import"); }
compiler_ast_nodes_ImportPartSingle first_part = std_vector_Vector__5_at(parts, 0)->u.single;
char *lib_name = first_part.name;
if (!std_map_Map__3_contains(this->program->global->namespaces, lib_name)) {
Expand Down Expand Up @@ -13361,7 +13374,7 @@ bool compiler_lsp_finder_Finder_find_in_statement(compiler_lsp_finder_Finder *th
return compiler_lsp_finder_Finder_find_in_block(this, node);
} break;
case compiler_ast_nodes_ASTType_Return: {
return ((bool)node->u.child) && compiler_lsp_finder_Finder_find_in_expression(this, node->u.child);
return ((bool)node->u.ret.expr) && compiler_lsp_finder_Finder_find_in_expression(this, node->u.ret.expr);
} break;
case compiler_ast_nodes_ASTType_Import: {
compiler_ast_nodes_Import path = node->u.import_path;
Expand Down
13 changes: 13 additions & 0 deletions compiler/ast/nodes.oc
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,11 @@ struct ArrayLiteral {
elements: &Vector<&AST>
}

struct Return {
expr: &AST
return_span: Span
}

union ASTUnion {
assertion: Assertion
binary: Binary
Expand All @@ -402,6 +407,7 @@ union ASTUnion {
spec: Specialization
array_literal: ArrayLiteral
child: &AST
ret: Return
}

struct AST {
Expand Down Expand Up @@ -440,6 +446,13 @@ def AST::new_binop(op: Operator, lhs: &AST, rhs: &AST, op_span: Span): &AST {
return ast
}

//! Returns the span we want to display in the LSP for this node
def AST::display_span(&this): Span => match .type {
Match => .u.match_stmt.match_span
If => .u.if_stmt.if_span
else => .span
}

def AST::is_identifier(&this): bool => match .type {
Identifier => true
NSLookup => true
Expand Down
2 changes: 1 addition & 1 deletion compiler/lsp/finder.oc
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def Finder::find_in_statement(&this, node: &AST): bool {
if decl.init? and .find_in_expression(decl.init) return true
}
Block => return .find_in_block(node)
Return => return node.u.child? and .find_in_expression(node.u.child)
Return => return node.u.ret.expr? and .find_in_expression(node.u.ret.expr)
Import => {
let path = node.u.import_path
let prev = path.root_sym
Expand Down
8 changes: 5 additions & 3 deletions compiler/parser.oc
Original file line number Diff line number Diff line change
Expand Up @@ -1332,13 +1332,14 @@ def Parser::parse_statement(&this): &AST {
match .token().type {
TokenType::OpenCurly => node = .parse_block()
TokenType::Return => {
.consume(TokenType::Return)
let tok = .consume(TokenType::Return)
let expr = null as &AST
if not .is_end_of_statement() {
expr = .parse_expression(end_type: TokenType::Newline)
}
node = AST::new(Return, start_span.join(.token().span))
node.u.child = expr
node.u.ret.expr = expr
node.u.ret.return_span = tok.span
.consume_end_of_statement()
}
TokenType::Yield => {
Expand Down Expand Up @@ -1647,7 +1648,8 @@ def Parser::parse_function(&this): &Function {
let stmt = .parse_expression(end_type: TokenType::Newline)
if returns {
let ret = AST::new(Return, stmt.span)
ret.u.child = stmt
ret.u.ret.expr = stmt
ret.u.ret.return_span = arrow.span
stmt = ret
}

Expand Down
7 changes: 4 additions & 3 deletions compiler/passes/code_generator.oc
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,12 @@ def CodeGenerator::gen_statement(&this, node: &AST) {
.gen_defers_upto(upto)
.gen_indent()

if node.u.child? {
if not node.u.child.returns {
let expr = node.u.ret.expr
if expr? {
if not expr.returns {
.out += "return "
}
.gen_expression(node.u.child, is_top_level: true)
.gen_expression(expr, is_top_level: true)
.out += ";\n"
} else {
.out += "return;\n"
Expand Down
3 changes: 2 additions & 1 deletion compiler/passes/mark_dead_code.oc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def MarkDeadCode::mark(&this, node: &AST) {
Import | Break | Continue | IntLiteral | FloatLiteral |
BoolLiteral | StringLiteral | CharLiteral | Null => {}

Return | Yield | Defer => .mark(node.u.child)
Yield | Defer => .mark(node.u.child)
Return => .mark(node.u.ret.expr)
UnaryOp => .mark(node.u.unary.expr)

SizeOf => .mark_type(node.u.size_of_type)
Expand Down
9 changes: 5 additions & 4 deletions compiler/passes/typechecker.oc
Original file line number Diff line number Diff line change
Expand Up @@ -1733,7 +1733,8 @@ def TypeChecker::check_statement(&this, node: &AST) {
let expected = cur_func.return_type

let res: &Type = null
let child = node.u.child
let child = node.u.ret.expr
let ret_span = node.u.ret.return_span
if child? then res = .check_expression(child, hint: expected)

if child? and child.returns {
Expand All @@ -1742,14 +1743,14 @@ def TypeChecker::check_statement(&this, node: &AST) {
} else if expected.base == BaseType::Void {
// We allow using arrow returns in void functions, they just don't return anything.
if node.u.child? {
.error(Error::new(node.span, "Cannot return a value from a void function"))
.error(Error::new(ret_span, "Cannot return a value from a void function"))
}
} else if child? {
if res? and not res.eq(expected) {
.error(Error::new(node.span, `Return type {res.str()} does not match function return type {expected.str()}`))
.error(Error::new(ret_span, `Return type {res.str()} does not match function return type {expected.str()}`))
}
} else {
.error(Error::new(node.span, "Expected a return value for non-void function"))
.error(Error::new(ret_span, "Expected a return value for non-void function"))
}
node.returns = true
}
Expand Down

0 comments on commit 58d8bd7

Please sign in to comment.