Skip to content

Commit

Permalink
Make a typing dependency that is not in older PyTorch backwards compa…
Browse files Browse the repository at this point in the history
…tible. (llvm#2948)

This was found in a downstream that is pegged to an older PyTorch
version.
  • Loading branch information
stellaraccident authored Feb 23, 2024
1 parent ec2b80b commit 89e02c1
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@
import re
from dataclasses import dataclass
from types import BuiltinMethodType, BuiltinFunctionType
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
import weakref

import numpy as np
Expand Down Expand Up @@ -45,6 +56,16 @@
Node,
)

try:
from torch.export.graph_signature import InputSpec as TypingInputSpec
except ModuleNotFoundError:
# PyTorch prior to 2.3 is missing certain things we use in typing
# signatures. Just make them be Any.
if not TYPE_CHECKING:
TypingInputSpec = Any
else:
raise

try:
import ml_dtypes
except ModuleNotFoundError:
Expand Down Expand Up @@ -299,7 +320,7 @@ class InputInfo:
"""Provides additional metadata when resolving inputs."""

program: torch.export.ExportedProgram
input_spec: torch.export.graph_signature.InputSpec
input_spec: TypingInputSpec
node: Node
ir_type: IrType
mutable_producer_node_name: Optional[str] = None
Expand Down

0 comments on commit 89e02c1

Please sign in to comment.