Skip to content

Commit

Permalink
Fix case & lint
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy committed Apr 15, 2024
1 parent 05ab35d commit 890842c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15821,7 +15821,7 @@ def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [
Torch_BoolType:$requires_grad
);
let results = (outs
AnyTorchTensorType:$result
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4722,7 +4722,7 @@ class DecomposePrimsIotaOp : public OpRewritePattern<PrimsIotaOp> {
return success();
}
};
} // namespace
} // namespace

namespace {
// Decompose constant tensor full like ops.
Expand Down
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@
"FloatImplicitModule_basic",
"IntImplicitModule_basic",

# Unsupported: missing default value for argument 0 in schema for prims.iota.default
"PrimsIotaModule_basic",

# Others
"GridSamplerBasic1_basic",
"GridSamplerBasic2_basic",
Expand Down Expand Up @@ -2438,6 +2441,9 @@
# Failure - torch.aten.squeeze lower
"BucketizeTensorOutInt32RightModule_basic", # unsupported by backend contract: tensor with unknown rank

# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",

# Failure - unknown
"BucketizeTensorFloatModule_basic",
"BucketizeTensorModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def __init__(self):
None,
])
def forward(self):
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device=torch.device('cpu'),
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
requires_grad=False)

@register_test_case(module_factory=lambda: PrimsIotaModule())
Expand Down

0 comments on commit 890842c

Please sign in to comment.