Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix recursive constraint violations with paths over list and map shapes #2371

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ class ClientCodegenVisitor(
// Add errors attached at the service level to the models
.let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) }
// Add `Box<T>` to recursive shapes as necessary
.let(RecursiveShapeBoxer::transform)
.let(RecursiveShapeBoxer()::transform)
// Normalize the `message` field on errors when enabled in settings (default: true)
.letIf(settings.codegenConfig.addMessageToErrors, AddErrorMessage::transform)
// NormalizeOperations by ensuring every operation has an input & output shape
.let(OperationNormalizer::transform)
// Drop unsupported event stream operations from the model
.let { RemoveEventStreamOperations.transform(it, settings) }
// - Normalize event stream operations
// Normalize event stream operations
.let(EventStreamNormalizer::transform)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal class ResiliencyConfigCustomizationTest {

@Test
fun `generates a valid config`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val project = TestWorkspace.testProject()
val codegenContext = testCodegenContext(model, settings = project.rustSettings())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import software.amazon.smithy.model.traits.Trait
/**
* Trait indicating that this shape should be represented with `Box<T>` when converted into Rust
*
* This is used to handle recursive shapes. See RecursiveShapeBoxer.
* This is used to handle recursive shapes.
* See [software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer].
*
* This trait is synthetic, applied during code generation, and never used in actual models.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,50 @@ package software.amazon.smithy.rust.codegen.core.smithy.transformers

import software.amazon.smithy.codegen.core.TopologicalIndex
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.SetShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait

object RecursiveShapeBoxer {
class RecursiveShapeBoxer(
/**
* Transform a model which may contain recursive shapes into a model annotated with [RustBoxTrait]
* A predicate that determines when a cycle in the shape graph contains "indirection". If a cycle contains
* indirection, no shape needs to be tagged. What constitutes indirection is up to the caller to decide.
*/
private val containsIndirectionPredicate: (Collection<Shape>) -> Boolean = ::containsIndirection,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to see some docs for these params. I was unfamiliar with the :: syntax before I saw this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to add docs on the params. I've now added them.

/**
* A closure that gets called on one member shape of a cycle that does not contain indirection for "fixing". For
* example, the [RustBoxTrait] trait can be used to tag the member shape.
*/
private val boxShapeFn: (MemberShape) -> MemberShape = ::addRustBoxTrait,
) {
/**
* Transform a model which may contain recursive shapes.
*
* When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Rust. This function will
* iteratively find loops & add the `RustBox` trait in a deterministic way until it reaches a fixed point.
* For example, when recursive shapes do NOT go through a `CollectionShape` or a `MapShape` shape, they must be
* boxed in Rust. This function will iteratively find cycles and call [boxShapeFn] on a member shape in the
* cycle to act on it. This is done in a deterministic way until it reaches a fixed point.
*
* This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so
* This function MUST be deterministic (always choose the same shapes to fix). If it is not, that is a bug. Even so
* this function may cause backward compatibility issues in certain pathological cases where a changes to recursive
* structures cause different members to be boxed. We may need to address these via customizations.
*
* For example, given the following model,
*
* ```smithy
* namespace com.example
*
* structure Recursive {
* recursiveStruct: Recursive
* anotherField: Boolean
* }
* ```
*
* The `com.example#Recursive$recursiveStruct` member shape is part of a cycle, but the
* `com.example#Recursive$anotherField` member shape is not.
*/
fun transform(model: Model): Model {
val next = transformInner(model)
Expand All @@ -37,16 +62,17 @@ object RecursiveShapeBoxer {
}

/**
* If [model] contains a recursive loop that must be boxed, apply one instance of [RustBoxTrait] return the new model.
* If [model] contains no loops, return null.
* If [model] contains a recursive loop that must be boxed, return the transformed model resulting form a call to
* [boxShapeFn].
* If [model] contains no loops, return `null`.
*/
private fun transformInner(model: Model): Model? {
// Execute 1-step of the boxing algorithm in the path to reaching a fixed point
// 1. Find all the shapes that are part of a cycle
// 2. Find all the loops that those shapes are part of
// 3. Filter out the loops that go through a layer of indirection
// 3. Pick _just one_ of the remaining loops to fix
// 4. Select the member shape in that loop with the earliest shape id
// Execute 1 step of the boxing algorithm in the path to reaching a fixed point:
// 1. Find all the shapes that are part of a cycle.
// 2. Find all the loops that those shapes are part of.
// 3. Filter out the loops that go through a layer of indirection.
// 3. Pick _just one_ of the remaining loops to fix.
// 4. Select the member shape in that loop with the earliest shape id.
// 5. Box it.
// (External to this function) Go back to 1.
val index = TopologicalIndex.of(model)
Expand All @@ -58,34 +84,38 @@ object RecursiveShapeBoxer {
// Flatten the connections into shapes.
loops.map { it.shapes }
}
val loopToFix = loops.firstOrNull { !containsIndirection(it) }
val loopToFix = loops.firstOrNull { !containsIndirectionPredicate(it) }

return loopToFix?.let { loop: List<Shape> ->
check(loop.isNotEmpty())
// pick the shape to box in a deterministic way
// Pick the shape to box in a deterministic way.
val shapeToBox = loop.filterIsInstance<MemberShape>().minByOrNull { it.id }!!
ModelTransformer.create().mapShapes(model) { shape ->
if (shape == shapeToBox) {
shape.asMemberShape().get().toBuilder().addTrait(RustBoxTrait()).build()
boxShapeFn(shape.asMemberShape().get())
} else {
shape
}
}
}
}
}

/**
* Check if a List<Shape> contains a shape which will use a pointer when represented in Rust, avoiding the
* need to add more Boxes
*/
private fun containsIndirection(loop: List<Shape>): Boolean {
return loop.find {
when (it) {
is ListShape,
is MapShape,
is SetShape, -> true
else -> it.hasTrait<RustBoxTrait>()
}
} != null
/**
* Check if a `List<Shape>` contains a shape which will use a pointer when represented in Rust, avoiding the
* need to add more `Box`es.
*
* Why `CollectionShape`s and `MapShape`s? Note that `CollectionShape`s get rendered in Rust as `Vec<T>`, and
* `MapShape`s as `HashMap<String, T>`; they're the only Smithy shapes that "organically" introduce indirection
* (via a pointer to the heap) in the recursive path. For other recursive paths, we thus have to introduce the
* indirection artificially ourselves using `Box`.
*
*/
private fun containsIndirection(loop: Collection<Shape>): Boolean = loop.find {
when (it) {
is CollectionShape, is MapShape -> true
else -> it.hasTrait<RustBoxTrait>()
}
}
} != null

private fun addRustBoxTrait(shape: MemberShape): MemberShape = shape.toBuilder().addTrait(RustBoxTrait()).build()
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class InstantiatorTest {
@required
num: Integer
}
""".asSmithyModel().let { RecursiveShapeBoxer.transform(it) }
""".asSmithyModel().let { RecursiveShapeBoxer().transform(it) }

private val codegenContext = testCodegenContext(model)
private val symbolProvider = codegenContext.symbolProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class StructureGeneratorTest {
@Test
fun `it generates accessor methods`() {
val testModel =
RecursiveShapeBoxer.transform(
RecursiveShapeBoxer().transform(
"""
namespace test

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class AwsQueryParserGeneratorTest {

@Test
fun `it modifies operation parsing to include Response and Result tags`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = AwsQueryParserGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Ec2QueryParserGeneratorTest {

@Test
fun `it modifies operation parsing to include Response and Result tags`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = Ec2QueryParserGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class JsonParserGeneratorTest {

@Test
fun `generates valid deserializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
fun builderSymbol(shape: StructureShape): Symbol =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ internal class XmlBindingTraitParserGeneratorTest {

@Test
fun `generates valid parsers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = XmlBindingTraitParserGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class AwsQuerySerializerGeneratorTest {
true -> CodegenTarget.CLIENT
false -> CodegenTarget.SERVER
}
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model, codegenTarget = codegenTarget)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = AwsQuerySerializerGenerator(testCodegenContext(model, codegenTarget = codegenTarget))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Ec2QuerySerializerGeneratorTest {

@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = Ec2QuerySerializerGenerator(codegenContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class JsonSerializerGeneratorTest {

@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserSerializer = JsonSerializerGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ internal class XmlBindingTraitSerializerGeneratorTest {

@Test
fun `generates valid serializers`() {
val model = RecursiveShapeBoxer.transform(OperationNormalizer.transform(baseModel))
val model = RecursiveShapeBoxer().transform(OperationNormalizer.transform(baseModel))
val codegenContext = testCodegenContext(model)
val symbolProvider = codegenContext.symbolProvider
val parserGenerator = XmlBindingTraitSerializerGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ internal class RecursiveShapeBoxerTest {
hello: Hello
}
""".asSmithyModel()
RecursiveShapeBoxer.transform(model) shouldBe model
RecursiveShapeBoxer().transform(model) shouldBe model
}

@Test
Expand All @@ -43,7 +43,7 @@ internal class RecursiveShapeBoxerTest {
anotherField: Boolean
}
""".asSmithyModel()
val transformed = RecursiveShapeBoxer.transform(model)
val transformed = RecursiveShapeBoxer().transform(model)
val member: MemberShape = transformed.lookup("com.example#Recursive\$RecursiveStruct")
member.expectTrait<RustBoxTrait>()
}
Expand All @@ -70,7 +70,7 @@ internal class RecursiveShapeBoxerTest {
third: SecondTree
}
""".asSmithyModel()
val transformed = RecursiveShapeBoxer.transform(model)
val transformed = RecursiveShapeBoxer().transform(model)
val boxed = transformed.shapes().filter { it.hasTrait<RustBoxTrait>() }.toList()
boxed.map { it.id.toString().removePrefix("com.example#") }.toSet() shouldBe setOf(
"Atom\$add",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class RecursiveShapesIntegrationTest {
}
output.message shouldContain "has infinite size"

val fixedProject = check(RecursiveShapeBoxer.transform(model))
val fixedProject = check(RecursiveShapeBoxer().transform(model))
fixedProject.compileAndTest()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.Ser
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerProtocolLoader
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput
import software.amazon.smithy.rust.codegen.server.smithy.transformers.AttachValidationExceptionToConstrainedOperationInputsInAllowList
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RecursiveConstraintViolationBoxer
import software.amazon.smithy.rust.codegen.server.smithy.transformers.RemoveEbsModelValidationException
import software.amazon.smithy.rust.codegen.server.smithy.transformers.ShapesReachableFromOperationInputTagger
import java.util.logging.Logger
Expand Down Expand Up @@ -162,7 +163,9 @@ open class ServerCodegenVisitor(
// Add errors attached at the service level to the models
.let { ModelTransformer.create().copyServiceErrorsToOperations(it, settings.getService(it)) }
// Add `Box<T>` to recursive shapes as necessary
.let(RecursiveShapeBoxer::transform)
.let(RecursiveShapeBoxer()::transform)
// Add `Box<T>` to recursive constraint violations as necessary
.let(RecursiveConstraintViolationBoxer::transform)
// Normalize operations by adding synthetic input and output shapes to every operation
.let(OperationNormalizer::transform)
// Remove the EBS model's own `ValidationException`, which collides with `smithy.framework#ValidationException`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput

class CollectionConstraintViolationGenerator(
Expand All @@ -38,16 +41,22 @@ class CollectionConstraintViolationGenerator(
private val constraintsInfo: List<TraitInfo> = collectionConstraintsInfo.map { it.toTraitInfo() }

fun render() {
val memberShape = model.expectShape(shape.member.target)
val targetShape = model.expectShape(shape.member.target)
val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
val constraintViolationName = constraintViolationSymbol.name
val isMemberConstrained = memberShape.canReachConstrainedShape(model, symbolProvider)
val isMemberConstrained = targetShape.canReachConstrainedShape(model, symbolProvider)
val constraintViolationVisibility = Visibility.publicIf(publicConstrainedTypes, Visibility.PUBCRATE)

modelsModuleWriter.withInlineModule(constraintViolationSymbol.module()) {
val constraintViolationVariants = constraintsInfo.map { it.constraintViolationVariant }.toMutableList()
if (isMemberConstrained) {
constraintViolationVariants += {
val memberConstraintViolationSymbol =
constraintViolationSymbolProvider.toSymbol(targetShape).letIf(
shape.member.hasTrait<ConstraintViolationRustBoxTrait>(),
) {
it.makeRustBoxed()
}
rustTemplate(
"""
/// Constraint violation error when an element doesn't satisfy its own constraints.
Expand All @@ -56,7 +65,7 @@ class CollectionConstraintViolationGenerator(
##[doc(hidden)]
Member(usize, #{MemberConstraintViolationSymbol})
""",
"MemberConstraintViolationSymbol" to constraintViolationSymbolProvider.toSymbol(memberShape),
"MemberConstraintViolationSymbol" to memberConstraintViolationSymbol,
)
}
}
Expand Down
Loading