Skip to content

Commit

Permalink
[FAST002] fix - do not generate invalid args
Browse files Browse the repository at this point in the history
  - Context: this rule swaps out default function argument values
for type annotations
  - Problem: When this swap occurs after a default argument value is
    already present, the function signature is no longer valid python
code (#12982)
  - Fix: Track default arguments and bail if we hit this condition,
    making this fix only sometimes possible.
  • Loading branch information
arkuhn committed Aug 28, 2024
1 parent 81cd438 commit dc94dd4
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 57 deletions.
30 changes: 29 additions & 1 deletion crates/ruff_linter/resources/test/fixtures/fastapi/FAST002.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
router = APIRouter()


# Errors
# Fixable errors

@app.get("/items/")
def get_items(
Expand All @@ -40,6 +40,34 @@ def do_stuff(
# do stuff
pass

@app.get("/users/")
def get_users(
skip: int,
limit: int,
current_user: User = Depends(get_current_user),
):
pass

@app.get("/users/")
def get_users(
current_user: User = Depends(get_current_user),
skip: int = 0,
limit: int = 10,
):
pass



# Non fixable errors

@app.get("/users/")
def get_users(
skip: int = 0,
limit: int = 10,
current_user: User = Depends(get_current_user),
):
pass


# Unchanged

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix};
use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast as ast;
use ruff_python_ast::helpers::map_callable;
Expand Down Expand Up @@ -59,14 +59,16 @@ use crate::settings::types::PythonVersion;
#[violation]
pub struct FastApiNonAnnotatedDependency;

impl AlwaysFixableViolation for FastApiNonAnnotatedDependency {
impl Violation for FastApiNonAnnotatedDependency {
const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes;

#[derive_message_formats]
fn message(&self) -> String {
format!("FastAPI dependency without `Annotated`")
}

fn fix_title(&self) -> String {
"Replace with `Annotated`".to_string()
fn fix_title(&self) -> Option<String> {
Some("Replace with `Annotated`".to_string())
}
}

Expand All @@ -75,64 +77,95 @@ pub(crate) fn fastapi_non_annotated_dependency(
checker: &mut Checker,
function_def: &ast::StmtFunctionDef,
) {
if !checker.semantic().seen_module(Modules::FASTAPI) {
return;
}
if !is_fastapi_route(function_def, checker.semantic()) {
if !checker.semantic().seen_module(Modules::FASTAPI)
|| !is_fastapi_route(function_def, checker.semantic())
{
return;
}

let mut updatable_count = 0;
let mut has_non_updatable_default = false;
let total_params = function_def.parameters.args.len();

for parameter in &function_def.parameters.args {
let needs_update = matches!(
(&parameter.parameter.annotation, &parameter.default),
(Some(_annotation), Some(default)) if is_fastapi_dependency(checker, default)
);

if needs_update {
updatable_count += 1;
// Determine if it's safe to update this parameter:
// - if all parameters are updatable its safe.
// - if we've encountered a non-updatable parameter with a default value, its no longer
// safe. (https://github.com/astral-sh/ruff/issues/12982)
let safe_to_update = updatable_count == total_params || !has_non_updatable_default;
create_diagnostic(checker, parameter, safe_to_update);
} else if parameter.default.is_some() {
has_non_updatable_default = true;
}
}
}

fn is_fastapi_dependency(checker: &Checker, expr: &ast::Expr) -> bool {
checker
.semantic()
.resolve_qualified_name(map_callable(expr))
.is_some_and(|qualified_name| {
matches!(
qualified_name.segments(),
[
"fastapi",
"Query"
| "Path"
| "Body"
| "Cookie"
| "Header"
| "File"
| "Form"
| "Depends"
| "Security"
]
)
})
}

fn create_diagnostic(
checker: &mut Checker,
parameter: &ast::ParameterWithDefault,
safe_to_update: bool,
) {
let mut diagnostic = Diagnostic::new(FastApiNonAnnotatedDependency, parameter.range);

if safe_to_update {
if let (Some(annotation), Some(default)) =
(&parameter.parameter.annotation, &parameter.default)
{
if checker
.semantic()
.resolve_qualified_name(map_callable(default))
.is_some_and(|qualified_name| {
matches!(
qualified_name.segments(),
[
"fastapi",
"Query"
| "Path"
| "Body"
| "Cookie"
| "Header"
| "File"
| "Form"
| "Depends"
| "Security"
]
)
})
{
let mut diagnostic =
Diagnostic::new(FastApiNonAnnotatedDependency, parameter.range);

diagnostic.try_set_fix(|| {
let module = if checker.settings.target_version >= PythonVersion::Py39 {
"typing"
} else {
"typing_extensions"
};
let (import_edit, binding) = checker.importer().get_or_import_symbol(
&ImportRequest::import_from(module, "Annotated"),
function_def.start(),
checker.semantic(),
)?;
let content = format!(
"{}: {}[{}, {}]",
parameter.parameter.name.id,
binding,
checker.locator().slice(annotation.range()),
checker.locator().slice(default.range())
);
let parameter_edit = Edit::range_replacement(content, parameter.range());
Ok(Fix::unsafe_edits(import_edit, [parameter_edit]))
});

checker.diagnostics.push(diagnostic);
}
diagnostic.try_set_fix(|| {
let module = if checker.settings.target_version >= PythonVersion::Py39 {
"typing"
} else {
"typing_extensions"
};
let (import_edit, binding) = checker.importer().get_or_import_symbol(
&ImportRequest::import_from(module, "Annotated"),
parameter.range.start(),
checker.semantic(),
)?;
let content = format!(
"{}: {}[{}, {}]",
parameter.parameter.name.id,
binding,
checker.locator().slice(annotation.range()),
checker.locator().slice(default.range())
);
let parameter_edit = Edit::range_replacement(content, parameter.range);
Ok(Fix::unsafe_edits(import_edit, [parameter_edit]))
});
}
} else {
diagnostic.fix = None;
}

checker.diagnostics.push(diagnostic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,72 @@ FAST002.py:38:5: FAST002 [*] FastAPI dependency without `Annotated`
39 40 | ):
40 41 | # do stuff
41 42 | pass

FAST002.py:47:5: FAST002 [*] FastAPI dependency without `Annotated`
|
45 | skip: int,
46 | limit: int,
47 | current_user: User = Depends(get_current_user),
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ FAST002
48 | ):
49 | pass
|
= help: Replace with `Annotated`

Unsafe fix
12 12 | Security,
13 13 | )
14 14 | from pydantic import BaseModel
15 |+from typing import Annotated
15 16 |
16 17 | app = FastAPI()
17 18 | router = APIRouter()
--------------------------------------------------------------------------------
44 45 | def get_users(
45 46 | skip: int,
46 47 | limit: int,
47 |- current_user: User = Depends(get_current_user),
48 |+ current_user: Annotated[User, Depends(get_current_user)],
48 49 | ):
49 50 | pass
50 51 |

FAST002.py:53:5: FAST002 [*] FastAPI dependency without `Annotated`
|
51 | @app.get("/users/")
52 | def get_users(
53 | current_user: User = Depends(get_current_user),
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ FAST002
54 | skip: int = 0,
55 | limit: int = 10,
|
= help: Replace with `Annotated`

Unsafe fix
12 12 | Security,
13 13 | )
14 14 | from pydantic import BaseModel
15 |+from typing import Annotated
15 16 |
16 17 | app = FastAPI()
17 18 | router = APIRouter()
--------------------------------------------------------------------------------
50 51 |
51 52 | @app.get("/users/")
52 53 | def get_users(
53 |- current_user: User = Depends(get_current_user),
54 |+ current_user: Annotated[User, Depends(get_current_user)],
54 55 | skip: int = 0,
55 56 | limit: int = 10,
56 57 | ):

FAST002.py:67:5: FAST002 FastAPI dependency without `Annotated`
|
65 | skip: int = 0,
66 | limit: int = 10,
67 | current_user: User = Depends(get_current_user),
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ FAST002
68 | ):
69 | pass
|
= help: Replace with `Annotated`

0 comments on commit dc94dd4

Please sign in to comment.