Skip to content

Commit

Permalink
Use shapes to validate bodies (#6694)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Hawkins <matthew.hawkins@apollographql.com>
  • Loading branch information
dylan-apollo and pubmodmatt authored Jan 30, 2025
1 parent 2b9f25f commit b3ee306
Show file tree
Hide file tree
Showing 14 changed files with 264 additions and 202 deletions.
218 changes: 153 additions & 65 deletions apollo-federation/src/sources/connect/validation/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,72 +95,96 @@ impl<'schema> Context<'schema> {
}
}

pub(crate) fn scalars() -> Shape {
Shape::one(
vec![
Shape::int([]),
Shape::float([]),
Shape::bool(None),
Shape::string(None),
Shape::null([]),
Shape::none(),
],
[],
)
}

/// Take a single expression and check that it's valid for the given context. This checks that
/// the expression can be executed given the known args and that the output shape is as expected.
///
/// TODO: this is only useful for URIs and headers right now, because it assumes objects/arrays are invalid.
pub(crate) fn validate(expression: &Expression, context: &Context) -> Result<(), Message> {
let Expression {
expression,
location,
} = expression;
let shape = expression.shape();

validate_shape(&shape, context, location.start)
pub(crate) fn validate(
expression: &Expression,
context: &Context,
expected_shape: &Shape,
) -> Result<(), Message> {
let shape = expression.expression.shape();

let actual_shape = resolve_shape(&shape, context, expression)?;
if let Some(mismatch) = expected_shape
.validate(&actual_shape)
.into_iter()
// Unknown satisfies nothing, but we have to allow it for things like `$config`
.find(|mismatch| !mismatch.received.is_unknown())
{
Err(Message {
code: context.code,
message: format!(
"{} values aren't valid here",
shape_name(&mismatch.received)
),
locations: transform_locations(&mismatch.received.locations, context, expression),
})
} else {
Ok(())
}
}

/// Validate that the shape is an acceptable output shape for an Expression.
///
/// TODO: Some day, whether objects or arrays are allowed will be dependent on &self (i.e., is the * modifier used)
fn validate_shape(
fn resolve_shape(
shape: &Shape,
context: &Context,
expression_offset: usize,
) -> Result<(), Message> {
expression: &Expression,
) -> Result<Shape, Message> {
match shape.case() {
ShapeCase::Array { .. } => Err(Message {
code: context.code,
message: "array values aren't valid here".to_string(),
locations: transform_locations(&shape.locations, context, expression_offset),
}),
ShapeCase::Object { .. } => Err(Message {
code: context.code,
message: "object values aren't valid here".to_string(),
locations: transform_locations(&shape.locations, context, expression_offset),
}),
ShapeCase::One(shapes) => {
let mut inners = Vec::new();
for inner in shapes {
validate_shape(
inners.push(resolve_shape(
&inner.with_locations(shape.locations.clone()),
context,
expression_offset,
)?;
expression,
)?);
}
Ok(())
Ok(Shape::one(inners, []))
}
ShapeCase::All(shapes) => {
let mut inners = Vec::new();
for inner in shapes {
validate_shape(
inners.push(resolve_shape(
&inner.with_locations(shape.locations.clone()),
context,
expression_offset,
)?;
expression,
)?);
}
Ok(())
Ok(Shape::all(inners, []))
}
ShapeCase::Name(name, key) => {
let mut resolved = if name.value == "$root" {
let mut key_str = key.iter().map(|key| key.to_string()).join(".");
if !key_str.is_empty() {
key_str = format!("`{key_str}` ");
}
return Err(Message {
code: context.code,
message: format!(
"`{key}` must start with one of {namespaces}",
key = key.iter().map(|key| key.to_string()).join("."),
"{key_str}must start with one of {namespaces}",
namespaces = context.var_lookup.keys().map(|ns| ns.as_str()).join(", "),
),
locations: transform_locations(
key.first().iter().flat_map(|key| &key.locations),
key.first()
.map(|key| &key.locations)
.unwrap_or(&shape.locations),
context,
expression_offset,
expression,
),
});
} else if name.value.starts_with('$') {
Expand All @@ -170,7 +194,7 @@ fn validate_shape(
"unknown variable `{name}`, must be one of {namespaces}",
namespaces = context.var_lookup.keys().map(|ns| ns.as_str()).join(", ")
),
locations: transform_locations(&shape.locations, context, expression_offset),
locations: transform_locations(&shape.locations, context, expression),
})?;
context
.var_lookup
Expand All @@ -181,11 +205,7 @@ fn validate_shape(
"{namespace} is not valid here, must be one of {namespaces}",
namespaces = context.var_lookup.keys().map(|ns| ns.as_str()).join(", "),
),
locations: transform_locations(
&shape.locations,
context,
expression_offset,
),
locations: transform_locations(&shape.locations, context, expression),
})?
.clone()
} else {
Expand All @@ -197,7 +217,7 @@ fn validate_shape(
.ok_or_else(|| Message {
code: context.code,
message: format!("unknown type `{name}`"),
locations: transform_locations(&name.locations, context, expression_offset),
locations: transform_locations(&name.locations, context, expression),
})?
};
resolved.locations.extend(shape.locations.iter().cloned());
Expand All @@ -217,35 +237,55 @@ fn validate_shape(
return Err(Message {
code: context.code,
message,
locations: transform_locations(&key.locations, context, expression_offset),
locations: transform_locations(&key.locations, context, expression),
});
}
resolved = child;
path = format!("{path}.{key}");
}
validate_shape(&resolved, context, expression_offset)
resolve_shape(&resolved, context, expression)
}
ShapeCase::Error(shape::Error { message, .. }) => Err(Message {
code: context.code,
message: message.clone(),
locations: transform_locations(&shape.locations, context, expression_offset),
locations: transform_locations(&shape.locations, context, expression),
}),
ShapeCase::Array { prefix, tail } => {
let prefix = prefix
.iter()
.map(|shape| resolve_shape(shape, context, expression))
.collect::<Result<Vec<_>, _>>()?;
let tail = resolve_shape(tail, context, expression)?;
Ok(Shape::array(prefix, tail, shape.locations.clone()))
}
ShapeCase::Object { fields, rest } => {
let mut resolved_fields = Shape::empty_map();
for (key, value) in fields {
resolved_fields.insert(key.clone(), resolve_shape(value, context, expression)?);
}
let resolved_rest = resolve_shape(rest, context, expression)?;
Ok(Shape::object(
resolved_fields,
resolved_rest,
shape.locations.clone(),
))
}
ShapeCase::None
| ShapeCase::Bool(_)
| ShapeCase::String(_)
| ShapeCase::Int(_)
| ShapeCase::Float
| ShapeCase::Null
| ShapeCase::Unknown => Ok(()),
| ShapeCase::Unknown => Ok(shape.clone()),
}
}

fn transform_locations<'a>(
locations: impl IntoIterator<Item = &'a Location>,
context: &Context,
expression_offset: usize,
expression: &Expression,
) -> Vec<Range<LineColumn>> {
locations
let mut locations: Vec<_> = locations
.into_iter()
.filter_map(|location| match &location.source_id {
SourceId::GraphQL(file_id) => context
Expand All @@ -256,12 +296,40 @@ fn transform_locations<'a>(
SourceId::Other(_) => {
// Right now, this always refers to the JSONSelection location
context.source.line_col_for_subslice(
location.span.start + expression_offset..location.span.end + expression_offset,
location.span.start + expression.location.start
..location.span.end + expression.location.start,
context.schema,
)
}
})
.collect()
.collect();
if locations.is_empty() {
// Highlight the whole expression
locations.extend(context.source.line_col_for_subslice(
expression.location.start..expression.location.end,
context.schema,
))
}
locations
}

/// A simplified shape name for error messages
fn shape_name(shape: &Shape) -> &'static str {
match shape.case() {
ShapeCase::Bool(_) => "boolean",
ShapeCase::String(_) => "string",
ShapeCase::Int(_) => "number",
ShapeCase::Float => "number",
ShapeCase::Null => "null",
ShapeCase::Array { .. } => "array",
ShapeCase::Object { .. } => "object",
ShapeCase::One(_) => "union",
ShapeCase::All(_) => "intersection",
ShapeCase::Name(_, _) => "unknown",
ShapeCase::Unknown => "unknown",
ShapeCase::None => "none",
ShapeCase::Error(_) => "error",
}
}

#[cfg(test)]
Expand Down Expand Up @@ -289,7 +357,7 @@ mod tests {
import: ["@connect", "@source"]
)
@source(name: "v2", http: { baseURL: "http://127.0.0.1" })
type Query {
aField(
int: Int
Expand All @@ -301,21 +369,21 @@ mod tests {
): AnObject @connect(source: "v2", http: {GET: """{EXPRESSION}"""})
something: String
}
scalar CustomScalar
input InputObject {
bool: Boolean
}
type AnObject {
bool: Boolean
}
input MultiLevelInput {
inner: MultiLevel
}
type MultiLevel {
nested: String
}
Expand All @@ -325,7 +393,7 @@ mod tests {
SCHEMA.replace("EXPRESSION", selection)
}

fn validate_with_context(selection: &str) -> Result<(), Message> {
fn validate_with_context(selection: &str, expected: Shape) -> Result<(), Message> {
let schema_str = schema_for(selection);
let schema = Schema::parse(&schema_str, "schema").unwrap();
let object = schema.get_object("Query").unwrap();
Expand All @@ -351,7 +419,7 @@ mod tests {
};
let context =
Context::for_connect_request(&schema_info, coordinate, &expr_string, Code::InvalidUrl);
validate(&expression(selection), &context)
validate(&expression(selection), &context, &expected)
}

/// Given a full expression replaced in `{EXPRESSION}` above, find the line/col of a substring.
Expand Down Expand Up @@ -407,7 +475,7 @@ mod tests {
#[case::last("$args.array->last.bool")]
#[case::multi_level_input("$args.multiLevel.inner.nested")]
fn valid_expressions(#[case] selection: &str) {
validate_with_context(selection).unwrap();
validate_with_context(selection, scalars()).unwrap();
}

#[rstest]
Expand All @@ -433,7 +501,7 @@ mod tests {
#[case::this_on_query("$this.something")]
#[case::bare_field_no_var("something")]
fn invalid_expressions(#[case] selection: &str) {
let err = validate_with_context(selection);
let err = validate_with_context(selection, scalars());
assert!(err.is_err());
assert!(
!err.err().unwrap().locations.is_empty(),
Expand All @@ -444,7 +512,8 @@ mod tests {
#[test]
fn bare_field_with_path() {
let selection = "something.blah";
let err = validate_with_context(selection).expect_err("missing property is unknown");
let err =
validate_with_context(selection, scalars()).expect_err("missing property is unknown");
let expected_location = location_of_expression("something", selection);
assert!(
err.message.contains("`something.blah`"),
Expand All @@ -467,7 +536,7 @@ mod tests {
#[test]
fn object_in_url() {
let selection = "$args.object";
let err = validate_with_context(selection).expect_err("objects are not allowed");
let err = validate_with_context(selection, scalars()).expect_err("objects are not allowed");
let expected_location = location_of_expression("object", selection);
assert!(
err.locations.contains(&expected_location),
Expand All @@ -480,7 +549,8 @@ mod tests {
#[test]
fn nested_unknown_property() {
let selection = "$args.multiLevel.inner.unknown";
let err = validate_with_context(selection).expect_err("missing property is unknown");
let err =
validate_with_context(selection, scalars()).expect_err("missing property is unknown");
assert!(
err.message.contains("`MultiLevel`"),
"{} didn't reference type",
Expand All @@ -498,4 +568,22 @@ mod tests {
err.locations
);
}

#[test]
fn unknown_var_in_scalar() {
let selection = r#"$({"something": $blahblahblah})"#;
let err = validate_with_context(selection, Shape::unknown([]))
.expect_err("unknown variable is unknown");
assert!(
err.message.contains("`$blahblahblah`"),
"{} didn't reference variable",
err.message
);
assert!(
err.locations
.contains(&location_of_expression("$blahblahblah", selection)),
"The relevant piece of the expression wasn't included in {:?}",
err.locations
);
}
}
Loading

0 comments on commit b3ee306

Please sign in to comment.