diff --git a/toolchain/base/BUILD b/toolchain/base/BUILD index cca1cd8204ca5..672fb02bba075 100644 --- a/toolchain/base/BUILD +++ b/toolchain/base/BUILD @@ -56,6 +56,7 @@ cc_library( hdrs = ["value_ids.h"], deps = [ ":index_base", + "//common:check", "//common:ostream", "@llvm-project//llvm:Support", ], @@ -89,10 +90,41 @@ cc_test( ], ) +cc_library( + name = "int_store", + srcs = ["int_store.cpp"], + hdrs = ["int_store.h"], + deps = [ + ":index_base", + ":mem_usage", + ":value_store", + ":yaml", + "//common:check", + "//common:hashtable_key_context", + "//common:ostream", + "//common:set", + "@llvm-project//llvm:Support", + ], +) + +cc_test( + name = "int_store_test", + size = "small", + srcs = ["int_store_test.cpp"], + deps = [ + ":int_store", + "//testing/base:gtest_main", + "//testing/base:test_raw_ostream", + "//toolchain/testing:yaml_test_helpers", + "@googletest//:gtest", + ], +) + cc_library( name = "shared_value_stores", hdrs = ["shared_value_stores.h"], deps = [ + ":int_store", ":mem_usage", ":value_ids", ":value_store", diff --git a/toolchain/base/int_store.cpp b/toolchain/base/int_store.cpp new file mode 100644 index 0000000000000..bd502d3b8a6df --- /dev/null +++ b/toolchain/base/int_store.cpp @@ -0,0 +1,66 @@ +// Part of the Carbon Language project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "toolchain/base/int_store.h" + +namespace Carbon { + +auto IntStore::CanonicalBitWidth(int significant_bits) -> int { + // For larger integers, we store them in as a signed APInt with a canonical + // width that is the smallest multiple of the word type's bits, but no + // smaller than a minimum of 64 bits to avoid spurious resizing of the most + // common cases (<= 64 bits). + static constexpr int WordWidth = llvm::APInt::APINT_BITS_PER_WORD; + + return std::max( + MinAPWidth, ((significant_bits + WordWidth - 1) / WordWidth) * WordWidth); +} + +auto IntStore::CanonicalizeSigned(llvm::APInt value) -> llvm::APInt { + return value.sextOrTrunc(CanonicalBitWidth(value.getSignificantBits())); +} + +auto IntStore::CanonicalizeUnsigned(llvm::APInt value) -> llvm::APInt { + // We need the width to include a zero sign bit as we canonicalize to a + // signed representation. + return value.zextOrTrunc(CanonicalBitWidth(value.getActiveBits() + 1)); +} + +auto IntStore::AddLarge(int64_t value) -> IntId { + auto ap_id = + values_.Add(llvm::APInt(CanonicalBitWidth(64), value, /*isSigned=*/true)); + return MakeIndexOrInvalid(ap_id.index); +} + +auto IntStore::AddSignedLarge(llvm::APInt value) -> IntId { + auto ap_id = values_.Add(CanonicalizeSigned(value)); + return MakeIndexOrInvalid(ap_id.index); +} + +auto IntStore::AddUnsignedLarge(llvm::APInt value) -> IntId { + auto ap_id = values_.Add(CanonicalizeUnsigned(value)); + return MakeIndexOrInvalid(ap_id.index); +} + +auto IntStore::LookupLarge(int64_t value) const -> IntId { + auto ap_id = values_.Lookup( + llvm::APInt(CanonicalBitWidth(64), value, /*isSigned=*/true)); + return MakeIndexOrInvalid(ap_id.index); +} + +auto IntStore::LookupSignedLarge(llvm::APInt value) const -> IntId { + auto ap_id = values_.Lookup(CanonicalizeSigned(value)); + return MakeIndexOrInvalid(ap_id.index); +} + +auto IntStore::OutputYaml() const -> Yaml::OutputMapping { + return values_.OutputYaml(); +} + +auto IntStore::CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const + -> void { + mem_usage.Collect(std::string(label), values_); +} + +} // namespace Carbon diff --git a/toolchain/base/int_store.h b/toolchain/base/int_store.h new file mode 100644 index 0000000000000..28e6d5e51758a --- /dev/null +++ b/toolchain/base/int_store.h @@ -0,0 +1,428 @@ +// Part of the Carbon Language project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef CARBON_TOOLCHAIN_BASE_INT_STORE_H_ +#define CARBON_TOOLCHAIN_BASE_INT_STORE_H_ + +#include "common/check.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" +#include "toolchain/base/index_base.h" +#include "toolchain/base/mem_usage.h" +#include "toolchain/base/value_store.h" +#include "toolchain/base/yaml.h" + +namespace Carbon { + +// Forward declare a testing peer so we can friend it. +namespace Testing { +struct IntStoreTestPeer; +} // namespace Testing + +// Corresponds to a canonicalized integer value. This is used both for integer +// literal tokens, and integer values in SemIR. These always represent the +// abstract mathematical value -- signed and regardless of the needed precision. +// +// Small values are internalized into the ID itself. Large values are +// represented as an index into an array of `APInt`s with a canonicalized bit +// width. The ID itself can be queried for whether it is a value-embedded-ID or +// an index ID. The ID also provides APIs for extracting either the value or an +// index. +// +// ## Details of the encoding scheme ## +// +// We need all the values from a maximum to minimum, as well as a healthy range +// of indices, to fit within the token ID bits. +// +// We represent this as a signed `TokenIdBits`-bit 2s compliment integer. The +// sign extension from TokenIdBits to a register size can be folded into the +// shift used to extract those bits from compressed bitfield storage. +// +// We then divide the smallest 1/4th of that bit width's space to represent +// indices, and the larger 3/4ths to embedded values. For 23-bits total this +// still gives us 2 million unique integers larger than the embedded ones, which +// would be difficult to fill without exceeding the number of tokens we can lex +// (8 million). For non-token based integers, the indices can continue downward +// to the 32-bit signed integer minimum, supporting approximately 1.998 billion +// unique larger integers. +// +// Note that the invalid ID can't be used with a token. This is OK as we +// expect invalid tokens to be *error* tokens and not need to represent an +// invalid integer. +class IntId : public Printable { + public: + using ValueType = llvm::APInt; + + // The encoding of integer IDs ensures that valid IDs associated with tokens + // during lexing can fit into a compressed storage space. We arrange for + // `TokenIdBits` to be the minimum number of bits of storage for token + // associated IDs. The constant is public so the lexer can ensure it reserves + // adequate space. + // + // Note that there may still be IDs either not associated with + // tokens or computed after lexing outside of this range. + static constexpr int TokenIdBits = 23; + + static const IntId Invalid; + + static auto MakeFromTokenPayload(uint32_t payload) -> IntId { + // Token-associated IDs are signed `TokenIdBits` integers, so force sign + // extension from that bit. + return IntId(static_cast(payload << TokenIdBitsShift) >> + TokenIdBitsShift); + } + + // Construct an ID from a raw 32-bit ID value. + static constexpr auto MakeRaw(int32_t raw_id) -> IntId { + return IntId(raw_id); + } + + // Tests whether the ID is a value ID. + // + // Only *valid* IDs can have an embedded value, so when true this also implies + // the ID is valid. + constexpr auto is_value() const -> bool { return id_ > ZeroIndexId; } + + // Tests whether the ID is an index ID. + // + // Note that an invalid ID is represented as an index ID, so this is *not* + // sufficient to test whether an ID is valid. + constexpr auto is_index() const -> bool { return id_ <= ZeroIndexId; } + + // Test whether an ID is valid. + // + // This does not distinguish between value and index IDs, only whether valid. + constexpr auto is_valid() const -> bool { return id_ != InvalidId; } + + // Converts an ID to the embedded value. Requires that `is_value()` is true. + constexpr auto AsValue() const -> int { + CARBON_DCHECK(is_value()); + return id_; + } + + // Converts an ID to an index. Requires that `is_index()` is true. + // + // Note that this does *not* require that the ID is valid. An invalid ID will + // turn into an invalid index. + constexpr auto AsIndex() const -> int { + CARBON_DCHECK(is_index()); + return ZeroIndexId - id_; + } + + // Returns the ID formatted as a lex token payload. + constexpr auto AsTokenPayload() const -> uint32_t { + uint32_t payload = id_; + // Ensure this ID round trips as the token payload. + CARBON_DCHECK(*this == MakeFromTokenPayload(payload)); + return payload; + } + + constexpr auto AsRaw() const -> int32_t { return id_; } + + auto Print(llvm::raw_ostream& out) const -> void { + out << "int ["; + if (is_value()) { + out << "value: " << AsValue() << "]"; + } else if (is_index()) { + out << "index: " << AsIndex() << "]"; + } else { + CARBON_CHECK(!is_valid()); + out << "invalid]"; + } + } + + friend constexpr auto operator==(IntId lhs, IntId rhs) -> bool { + return lhs.id_ == rhs.id_; + } + friend constexpr auto operator<=>(IntId lhs, IntId rhs) + -> std::strong_ordering { + return lhs.id_ <=> rhs.id_; + } + + private: + friend class IntStore; + friend Testing::IntStoreTestPeer; + + // The shift needed when adjusting a between a `TokenIdBits`-width integer and + // a 32-bit integer. + static constexpr int TokenIdBitsShift = 32 - TokenIdBits; + + // The maximum embedded value in an ID. + static constexpr int32_t MaxValue = + std::numeric_limits::max() >> TokenIdBitsShift; + + // The ID value that represents an index of `0`. This is the first ID value + // representing an index, and all indices are `<=` to this. + // + // `ZeroIndexId` is the first index ID, and we encode indices as successive + // negative numbers counting downwards. The setup allows us to both use a + // comparison with this ID to distinguish value and index IDs, and to compute + // the actual index from the ID. + // + // The computation of an index in fact is just a subtraction: + // `ZeroIndexId - id_`. Subtraction is *also* how most CPUs implement the + // comparison, and so all of this ends up carefully constructed to enable very + // small code size when testing for an embedded value and when that test fails + // computing and using the index. + static constexpr int32_t ZeroIndexId = std::numeric_limits::min() >> + (TokenIdBitsShift + 1); + + // The minimum embedded value in an ID. + static constexpr int32_t MinValue = ZeroIndexId + 1; + + // The invalid ID, which needs to be placed after the largest index, which + // count downwards as IDs so below the smallest index ID, in order to optimize + // the code sequence needed to distinguish between integer and value IDs and + // to convert index IDs into actual indices small. + static constexpr int32_t InvalidId = std::numeric_limits::min(); + + // The invalid index. This is the result of converting an invalid ID into an + // index. We ensure that conversion can be done so that we can simplify the + // code that first tries to use an embedded value, then converts to an index + // and checks that the index is valid. + static const int32_t InvalidIndex; + + // Document the specific values of some of these constants to help visualize + // how the bit patterns map from the above computations. + // + // clang-format off: visualizing bit positions + // + // Each bit is either `T` for part of the token or `P` as part + // of the available payload that we use for the ID: + // + // 0bTTTT'TTTT'TPPP'PPPP'PPPP'PPPP'PPPP'PPPP + static_assert(MaxValue == 0b0000'0000'0011'1111'1111'1111'1111'1111); + static_assert(ZeroIndexId == 0b1111'1111'1110'0000'0000'0000'0000'0000); + static_assert(MinValue == 0b1111'1111'1110'0000'0000'0000'0000'0001); + static_assert(InvalidId == 0b1000'0000'0000'0000'0000'0000'0000'0000); + // clang-format on + + constexpr explicit IntId(int32_t id) : id_(id) {} + + int32_t id_; +}; + +constexpr IntId IntId::Invalid(IntId::InvalidId); + +// Note that we initialize the invalid index in a constexpr context which +// ensures there is no UB in forming it. This helps ensure all the ID -> index +// conversions are correct because the invalid ID is at the limit of that range. +constexpr int32_t IntId::InvalidIndex = Invalid.AsIndex(); + +// A canonicalizing value store with deep optimizations for integers. +// +// This stores integers as abstract, signed mathematical integers. The bit width +// of specific `APInt` values, either as inputs or outputs, is disregarded for +// the purpose of canonicalization and the returned integer may use a very +// different bit width `APInt` than was used when adding. There are also +// optimized paths for adding integer values representable using native integer +// types. +// +// Because the integers in the store are canonicalized with only a minimum bit +// width, there are helper functions to coerce them to a specific desired bit +// width for use. +// +// This leverages a significant optimization for small integer values -- rather +// than canonicalizing and making them unique in a `ValueStore`, they are +// directly embedded in the `IntId` itself. Only larger integers are stored in +// an array of `APInt` values and represented as an index in the ID. +class IntStore { + public: + // Accepts a signed `int64_t` and uses the mathematical signed integer value + // of it as the added integer value. + // + // Returns the ID corresponding to this integer value, storing an `APInt` if + // necessary to represent it. + auto Add(int64_t value) -> IntId { + // First try directly making this into an ID. + if (IntId id = TryMakeValue(value); id.is_valid()) [[likely]] { + return id; + } + + // Fallback for larger values. + return AddLarge(value); + } + + // Returns the ID corresponding to this signed integer value, storing an + // `APInt` if necessary to represent it. + auto AddSigned(llvm::APInt value) -> IntId { + // First try directly making this into an ID. + if (IntId id = TryMakeSignedValue(value); id.is_valid()) [[likely]] { + return id; + } + + // Fallback for larger values. + return AddSignedLarge(std::move(value)); + } + + // Returns the ID corresponding to an equivalent signed integer value for the + // provided unsigned integer value, storing an `APInt` if necessary to + // represent it. + auto AddUnsigned(llvm::APInt value) -> IntId { + // First try directly making this into an ID. + if (IntId id = TryMakeUnsignedValue(value); id.is_valid()) [[likely]] { + return id; + } + + // Fallback for larger values. + return AddUnsignedLarge(std::move(value)); + } + + // Returns the value for an ID. + // + // This will always be a signed `APInt` with a canonical bit width for the + // specific integer value in question. + auto Get(IntId id) const -> llvm::APInt { + if (id.is_value()) [[likely]] { + return llvm::APInt(MinAPWidth, id.AsValue(), /*isSigned=*/true); + } + return values_.Get(APIntId(id.AsIndex())); + } + + // Returns the value for an ID adjusted to a specific bit width. + // + // Note that because we store canonical mathematical integers as signed + // integers, this always sign extends or truncates to the target width. The + // caller can then use that as a signed or unsigned integer as needed. + auto GetAtWidth(IntId id, int bit_width) const -> llvm::APInt { + llvm::APInt value = Get(id); + if (static_cast(value.getBitWidth()) != bit_width) { + value = value.sextOrTrunc(bit_width); + } + return value; + } + + // Returns the value for an ID adjusted to the bit width specified with + // another integer ID. + // + // This simply looks up the width integer ID, and then calls the above + // `GetAtWidth` overload using the value found for it. See that overload for + // more details. + auto GetAtWidth(IntId id, IntId bit_width_id) const -> llvm::APInt { + const llvm::APInt bit_width = Get(bit_width_id); + CARBON_CHECK( + bit_width.isStrictlyPositive() && bit_width.isSignedIntN(MinAPWidth), + "Invalid bit width value: {0}", bit_width); + return GetAtWidth(id, bit_width.getSExtValue()); + } + + // Accepts a signed `int64_t` and uses the mathematical signed integer value + // of it as the integer value to lookup. Returns the canonical ID for that + // value or returns invalid if not in the store. + auto Lookup(int64_t value) const -> IntId { + if (IntId id = TryMakeValue(value); id.is_valid()) [[likely]] { + return id; + } + + // Fallback for larger values. + return LookupLarge(value); + } + + // Looks up the canonical ID for this signed integer value, or returns invalid + // if not in the store. + auto LookupSigned(llvm::APInt value) const -> IntId { + if (IntId id = TryMakeSignedValue(value); id.is_valid()) [[likely]] { + return id; + } + + // Fallback for larger values. + return LookupSignedLarge(std::move(value)); + } + + // Output a YAML description of this data structure. Note that this will only + // include the integers that required storing, not those successfully embedded + // into the ID space. + auto OutputYaml() const -> Yaml::OutputMapping; + + auto array_ref() const -> llvm::ArrayRef { + return values_.array_ref(); + } + auto size() const -> size_t { return values_.size(); } + + // Collects the memory usage of the separately stored integers. + auto CollectMemUsage(MemUsage& mem_usage, llvm::StringRef label) const + -> void; + + private: + friend struct Testing::IntStoreTestPeer; + + // Used for `values_`; tracked using `IntId`'s index range. + struct APIntId : IdBase, Printable { + using ValueType = llvm::APInt; + static const APIntId Invalid; + using IdBase::IdBase; + auto Print(llvm::raw_ostream& out) const -> void { + out << "ap_int"; + IdBase::Print(out); + } + }; + + static constexpr int MinAPWidth = 64; + + static auto MakeIndexOrInvalid(int index) -> IntId { + CARBON_DCHECK(index >= 0 && index <= IntId::InvalidIndex); + return IntId(IntId::ZeroIndexId - index); + } + + // Tries to make a signed 64-bit integer into an embedded value in the ID, and + // if unable to do that returns the `Invalid` ID. + static auto TryMakeValue(int64_t value) -> IntId { + if (IntId::MinValue <= value && value <= IntId::MaxValue) { + return IntId(value); + } + + return IntId::Invalid; + } + + // Tries to make a signed APInt into an embedded value in the ID, and if + // unable to do that returns the `Invalid` ID. + static auto TryMakeSignedValue(llvm::APInt value) -> IntId { + if (value.sge(IntId::MinValue) && value.sle(IntId::MaxValue)) { + return IntId(value.getSExtValue()); + } + + return IntId::Invalid; + } + + // Tries to make an unsigned APInt into an embedded value in the ID, and if + // unable to do that returns the `Invalid` ID. + static auto TryMakeUnsignedValue(llvm::APInt value) -> IntId { + if (value.ule(IntId::MaxValue)) { + return IntId(value.getZExtValue()); + } + + return IntId::Invalid; + } + + // Pick a canonical bit width for the provided number of significant bits. + static auto CanonicalBitWidth(int significant_bits) -> int; + + // Canonicalize an incoming signed APInt to the correct bit width. + static auto CanonicalizeSigned(llvm::APInt value) -> llvm::APInt; + + // Canonicalize an incoming unsigned APInt to the correct bit width. + static auto CanonicalizeUnsigned(llvm::APInt value) -> llvm::APInt; + + // Helper functions for handling values that are large enough to require an + // allocated `APInt` for storage. Creating or manipulating that storage is + // only a few lines of code, but we move these out-of-line because the + // generated code is big and harms performance for the non-`Large` common + // case. + auto AddLarge(int64_t value) -> IntId; + auto AddSignedLarge(llvm::APInt value) -> IntId; + auto AddUnsignedLarge(llvm::APInt value) -> IntId; + auto LookupLarge(int64_t value) const -> IntId; + auto LookupSignedLarge(llvm::APInt value) const -> IntId; + + // Stores values which don't fit in an IntId. These are always signed. + CanonicalValueStore values_; +}; + +constexpr IntStore::APIntId IntStore::APIntId::Invalid( + IntId::Invalid.AsIndex()); + +} // namespace Carbon + +#endif // CARBON_TOOLCHAIN_BASE_INT_STORE_H_ diff --git a/toolchain/base/int_store_test.cpp b/toolchain/base/int_store_test.cpp new file mode 100644 index 0000000000000..80c90acc6f765 --- /dev/null +++ b/toolchain/base/int_store_test.cpp @@ -0,0 +1,158 @@ +// Part of the Carbon Language project, under the Apache License v2.0 with LLVM +// Exceptions. See /LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "toolchain/base/int_store.h" + +#include +#include + +#include + +namespace Carbon::Testing { + +struct IntStoreTestPeer { + static constexpr int MinAPWidth = IntStore::MinAPWidth; + + static constexpr int32_t MaxIdEmbeddedValue = IntId::MaxValue; + static constexpr int32_t MinIdEmbeddedValue = IntId::MinValue; +}; + +namespace { + +using ::testing::Eq; + +static constexpr int MinAPWidth = IntStoreTestPeer::MinAPWidth; + +static constexpr int32_t MaxIdEmbeddedValue = + IntStoreTestPeer::MaxIdEmbeddedValue; +static constexpr int32_t MinIdEmbeddedValue = + IntStoreTestPeer::MinIdEmbeddedValue; + +TEST(IntStore, Basic) { + IntStore ints; + IntId id_0 = ints.Add(0); + IntId id_1 = ints.Add(1); + IntId id_2 = ints.Add(2); + IntId id_42 = ints.Add(42); + IntId id_n1 = ints.Add(-1); + IntId id_n42 = ints.Add(-42); + IntId id_nines = ints.Add(999'999'999'999); + IntId id_max64 = ints.Add(std::numeric_limits::max()); + IntId id_min64 = ints.Add(std::numeric_limits::min()); + + for (IntId id : + {id_0, id_1, id_2, id_42, id_n1, id_n42, id_nines, id_max64, id_min64}) { + ASSERT_TRUE(id.is_valid()); + } + + // Small values should be embedded. + EXPECT_THAT(id_0.AsValue(), Eq(0)); + EXPECT_THAT(id_1.AsValue(), Eq(1)); + EXPECT_THAT(id_2.AsValue(), Eq(2)); + EXPECT_THAT(id_42.AsValue(), Eq(42)); + EXPECT_THAT(id_n1.AsValue(), Eq(-1)); + EXPECT_THAT(id_n42.AsValue(), Eq(-42)); + + // Rest should be indices as they don't fit as embedded values. + EXPECT_TRUE(!id_nines.is_value()); + EXPECT_TRUE(id_nines.is_index()); + EXPECT_TRUE(!id_max64.is_value()); + EXPECT_TRUE(id_max64.is_index()); + EXPECT_TRUE(!id_min64.is_value()); + EXPECT_TRUE(id_min64.is_index()); + + // And round tripping all the way through the store should work. + EXPECT_THAT(ints.Get(id_0), Eq(0)); + EXPECT_THAT(ints.Get(id_1), Eq(1)); + EXPECT_THAT(ints.Get(id_2), Eq(2)); + EXPECT_THAT(ints.Get(id_42), Eq(42)); + EXPECT_THAT(ints.Get(id_n1), Eq(-1)); + EXPECT_THAT(ints.Get(id_n42), Eq(-42)); + EXPECT_THAT(ints.Get(id_nines), Eq(999'999'999'999)); + EXPECT_THAT(ints.Get(id_max64), Eq(std::numeric_limits::max())); + EXPECT_THAT(ints.Get(id_min64), Eq(std::numeric_limits::min())); +} + +// Helper struct to hold test values and the resulting IDs. +struct APAndId { + llvm::APInt ap; + IntId id = IntId::Invalid; +}; + +TEST(IntStore, APSigned) { + IntStore ints; + + llvm::APInt big_128_ap = + llvm::APInt(128, 0x1234'abcd'1234'abcd, /*isSigned=*/true) * 0xabcd'0000; + llvm::APInt max_embedded_ap(MinAPWidth, MaxIdEmbeddedValue, + /*isSigned=*/true); + llvm::APInt min_embedded_ap(MinAPWidth, MinIdEmbeddedValue, + /*isSigned=*/true); + + APAndId ap_and_ids[] = { + {.ap = llvm::APInt(MinAPWidth, 1, /*isSigned=*/true)}, + {.ap = llvm::APInt(MinAPWidth, 2, /*isSigned=*/true)}, + {.ap = llvm::APInt(MinAPWidth, 999'999'999'999, /*isSigned=*/true)}, + {.ap = big_128_ap}, + {.ap = -big_128_ap}, + {.ap = + big_128_ap.sext(512) * big_128_ap.sext(512) * big_128_ap.sext(512)}, + {.ap = + -big_128_ap.sext(512) * big_128_ap.sext(512) * big_128_ap.sext(512)}, + {.ap = max_embedded_ap}, + {.ap = max_embedded_ap + 1}, + {.ap = min_embedded_ap}, + {.ap = min_embedded_ap - 1}, + }; + for (auto& [ap, id] : ap_and_ids) { + id = ints.AddSigned(ap); + ASSERT_TRUE(id.is_valid()) << ap; + } + + for (const auto& [ap, id] : ap_and_ids) { + // The sign extend here may be a no-op, but the original bit width is a + // reliable one at which to do the comparison. + EXPECT_THAT(ints.Get(id).sext(ap.getBitWidth()), Eq(ap)); + } +} + +TEST(IntStore, APUnsigned) { + IntStore ints; + + llvm::APInt big_128_ap = + llvm::APInt(128, 0xabcd'abcd'abcd'abcd) * 0xabcd'0000'abcd'0000; + llvm::APInt max_embedded_ap(MinAPWidth, MaxIdEmbeddedValue); + + APAndId ap_and_ids[] = { + {.ap = llvm::APInt(MinAPWidth, 1)}, + {.ap = llvm::APInt(MinAPWidth, 2)}, + {.ap = llvm::APInt(MinAPWidth, 999'999'999'999)}, + {.ap = llvm::APInt(MinAPWidth, std::numeric_limits::max())}, + {.ap = llvm::APInt(MinAPWidth + 1, std::numeric_limits::max()) + + 1}, + {.ap = big_128_ap}, + {.ap = + big_128_ap.zext(512) * big_128_ap.zext(512) * big_128_ap.zext(512)}, + {.ap = max_embedded_ap}, + {.ap = max_embedded_ap + 1}, + }; + for (auto& [ap, id] : ap_and_ids) { + id = ints.AddUnsigned(ap); + ASSERT_TRUE(id.is_valid()) << ap; + } + + for (const auto& [ap, id] : ap_and_ids) { + auto stored_ap = ints.Get(id); + // Pick a bit width wide enough to represent both whatever is returned and + // the original value as a *signed* integer without any truncation. + int width = std::max(stored_ap.getBitWidth(), ap.getBitWidth() + 1); + // We sign extend the stored value and zero extend the original number. This + // ensures that anything added as unsigned ends up stored as a positive + // number even when sign extended. + EXPECT_THAT(stored_ap.sext(width), Eq(ap.zext(width))); + } +} + +} // namespace +} // namespace Carbon::Testing diff --git a/toolchain/base/shared_value_stores.h b/toolchain/base/shared_value_stores.h index 1a9d38d586e9c..2d47b7ba5ff15 100644 --- a/toolchain/base/shared_value_stores.h +++ b/toolchain/base/shared_value_stores.h @@ -5,6 +5,7 @@ #ifndef CARBON_TOOLCHAIN_BASE_SHARED_VALUE_STORES_H_ #define CARBON_TOOLCHAIN_BASE_SHARED_VALUE_STORES_H_ +#include "toolchain/base/int_store.h" #include "toolchain/base/mem_usage.h" #include "toolchain/base/value_ids.h" #include "toolchain/base/value_store.h" @@ -17,7 +18,7 @@ namespace Carbon { class SharedValueStores : public Yaml::Printable { public: // Provide types that can be used by APIs to forward access to these stores. - using IntStore = CanonicalValueStore; + using IntStore = IntStore; using RealStore = ValueStore; using FloatStore = CanonicalValueStore; using IdentifierStore = CanonicalValueStore; diff --git a/toolchain/base/shared_value_stores_test.cpp b/toolchain/base/shared_value_stores_test.cpp index 430d0931fa207..e8872a40b892e 100644 --- a/toolchain/base/shared_value_stores_test.cpp +++ b/toolchain/base/shared_value_stores_test.cpp @@ -40,7 +40,8 @@ TEST(SharedValueStores, PrintEmpty) { TEST(SharedValueStores, PrintVals) { SharedValueStores value_stores; llvm::APInt apint(64, 8, /*isSigned=*/true); - value_stores.ints().Add(apint); + value_stores.ints().AddSigned(apint); + value_stores.ints().AddSigned(llvm::APInt(64, 999'999'999'999)); value_stores.reals().Add( Real{.mantissa = apint, .exponent = apint, .is_decimal = true}); value_stores.identifiers().Add("a"); @@ -50,7 +51,7 @@ TEST(SharedValueStores, PrintVals) { EXPECT_THAT(Yaml::Value::FromText(out.TakeStr()), MatchSharedValues( - ElementsAre(Pair("int0", Yaml::Scalar("8"))), + ElementsAre(Pair("ap_int0", Yaml::Scalar("999999999999"))), ElementsAre(Pair("real0", Yaml::Scalar("8*10^8"))), ElementsAre(Pair("identifier0", Yaml::Scalar("a"))), ElementsAre(Pair("string0", Yaml::Scalar("foo'\"baz"))))); diff --git a/toolchain/base/value_ids.h b/toolchain/base/value_ids.h index e632a5589b0f2..7184a876241a2 100644 --- a/toolchain/base/value_ids.h +++ b/toolchain/base/value_ids.h @@ -41,21 +41,6 @@ class Real : public Printable { bool is_decimal; }; -// Corresponds to an integer value represented by an APInt. This is used both -// for integer literal tokens, which are unsigned and have an unspecified -// bit-width, and integer values in SemIR, which have a signedness and bit-width -// matching their type. -struct IntId : public IdBase, public Printable { - using ValueType = llvm::APInt; - static const IntId Invalid; - using IdBase::IdBase; - auto Print(llvm::raw_ostream& out) const -> void { - out << "int"; - IdBase::Print(out); - } -}; -constexpr IntId IntId::Invalid(IntId::InvalidIndex); - // Corresponds to a float value represented by an APFloat. This is used for // floating-point values in SemIR. struct FloatId : public IdBase, public Printable { diff --git a/toolchain/base/value_store_test.cpp b/toolchain/base/value_store_test.cpp index 569a9fab6cdc2..dcc58f2efda41 100644 --- a/toolchain/base/value_store_test.cpp +++ b/toolchain/base/value_store_test.cpp @@ -15,19 +15,6 @@ namespace { using ::testing::Eq; using ::testing::Not; -TEST(ValueStore, Int) { - CanonicalValueStore ints; - IntId id1 = ints.Add(llvm::APInt(64, 1)); - IntId id2 = ints.Add(llvm::APInt(64, 2)); - - ASSERT_TRUE(id1.is_valid()); - ASSERT_TRUE(id2.is_valid()); - EXPECT_THAT(id1, Not(Eq(id2))); - - EXPECT_THAT(ints.Get(id1), Eq(1)); - EXPECT_THAT(ints.Get(id2), Eq(2)); -} - TEST(ValueStore, Real) { Real real1{.mantissa = llvm::APInt(64, 1), .exponent = llvm::APInt(64, 11), diff --git a/toolchain/check/convert.cpp b/toolchain/check/convert.cpp index f5b8504bddea4..20b3ee538c14d 100644 --- a/toolchain/check/convert.cpp +++ b/toolchain/check/convert.cpp @@ -155,7 +155,7 @@ static auto MakeElementAccessInst(Context& context, SemIR::LocId loc_id, auto index_id = block.template AddInst( loc_id, {.type_id = context.GetBuiltinType(SemIR::BuiltinInstKind::IntType), - .int_id = context.ints().Add(llvm::APInt(32, i))}); + .int_id = context.ints().AddUnsigned(llvm::APInt(32, i))}); return block.template AddInst( loc_id, {elem_type_id, aggregate_id, index_id}); } else { diff --git a/toolchain/check/eval.cpp b/toolchain/check/eval.cpp index 281a2ef9ab8a6..02993ef28ee68 100644 --- a/toolchain/check/eval.cpp +++ b/toolchain/check/eval.cpp @@ -244,7 +244,7 @@ static auto MakeBoolResult(Context& context, SemIR::TypeId bool_type_id, // Converts an APInt value into a ConstantId. static auto MakeIntResult(Context& context, SemIR::TypeId type_id, llvm::APInt value) -> SemIR::ConstantId { - auto result = context.ints().Add(std::move(value)); + auto result = context.ints().AddSigned(std::move(value)); return MakeConstantResult( context, SemIR::IntValue{.type_id = type_id, .int_id = result}, Phase::Template); @@ -674,12 +674,14 @@ static auto PerformBuiltinUnaryIntOp(Context& context, SemIRLoc loc, SemIR::InstId arg_id) -> SemIR::ConstantId { auto op = context.insts().GetAs(arg_id); - auto op_val = context.ints().Get(op.int_id); + auto [is_signed, bit_width_id] = context.sem_ir().GetIntTypeInfo(op.type_id); + CARBON_CHECK(bit_width_id != IntId::Invalid, + "Cannot evaluate a generic bit width integer: {0}", op); + llvm::APInt op_val = context.ints().GetAtWidth(op.int_id, bit_width_id); switch (builtin_kind) { case SemIR::BuiltinFunctionKind::IntSNegate: - if (context.types().IsSignedInt(op.type_id) && - op_val.isMinSignedValue()) { + if (is_signed && op_val.isMinSignedValue()) { CARBON_DIAGNOSTIC(CompileTimeIntegerNegateOverflow, Error, "integer overflow in negation of {0}", TypedInt); context.emitter().Emit(loc, CompileTimeIntegerNegateOverflow, @@ -708,8 +710,6 @@ static auto PerformBuiltinBinaryIntOp(Context& context, SemIRLoc loc, -> SemIR::ConstantId { auto lhs = context.insts().GetAs(lhs_id); auto rhs = context.insts().GetAs(rhs_id); - const auto& lhs_val = context.ints().Get(lhs.int_id); - const auto& rhs_val = context.ints().Get(rhs.int_id); // Check for division by zero. switch (builtin_kind) { @@ -717,7 +717,7 @@ static auto PerformBuiltinBinaryIntOp(Context& context, SemIRLoc loc, case SemIR::BuiltinFunctionKind::IntSMod: case SemIR::BuiltinFunctionKind::IntUDiv: case SemIR::BuiltinFunctionKind::IntUMod: - if (rhs_val.isZero()) { + if (context.ints().Get(rhs.int_id).isZero()) { DiagnoseDivisionByZero(context, loc); return SemIR::ConstantId::Error; } @@ -726,9 +726,58 @@ static auto PerformBuiltinBinaryIntOp(Context& context, SemIRLoc loc, break; } - bool overflow = false; + auto [lhs_is_signed, lhs_bit_width_id] = + context.sem_ir().GetIntTypeInfo(lhs.type_id); + llvm::APInt lhs_val = context.ints().GetAtWidth(lhs.int_id, lhs_bit_width_id); + llvm::APInt result_val; + + // First handle shift, which can directly use the canonical RHS and doesn't + // overflow. + switch (builtin_kind) { + // Bit shift. + case SemIR::BuiltinFunctionKind::IntLeftShift: + case SemIR::BuiltinFunctionKind::IntRightShift: { + const auto& rhs_orig_val = context.ints().Get(rhs.int_id); + if (rhs_orig_val.uge(lhs_val.getBitWidth()) || + (rhs_orig_val.isNegative() && lhs_is_signed)) { + CARBON_DIAGNOSTIC( + CompileTimeShiftOutOfRange, Error, + "shift distance not in range [0, {0}) in {1} {2:<<|>>} {3}", + unsigned, TypedInt, BoolAsSelect, TypedInt); + context.emitter().Emit( + loc, CompileTimeShiftOutOfRange, lhs_val.getBitWidth(), + {.type = lhs.type_id, .value = lhs_val}, + builtin_kind == SemIR::BuiltinFunctionKind::IntLeftShift, + {.type = rhs.type_id, .value = rhs_orig_val}); + // TODO: Is it useful to recover by returning 0 or -1? + return SemIR::ConstantId::Error; + } + + if (builtin_kind == SemIR::BuiltinFunctionKind::IntLeftShift) { + result_val = lhs_val.shl(rhs_orig_val); + } else if (lhs_is_signed) { + result_val = lhs_val.ashr(rhs_orig_val); + } else { + result_val = lhs_val.lshr(rhs_orig_val); + } + return MakeIntResult(context, lhs.type_id, std::move(result_val)); + } + + default: + // Break to do additional setup for other builtin kinds. + break; + } + + // Other operations are already checked to be homogeneous, so we can extend + // the RHS with the LHS bit width. + CARBON_CHECK(rhs.type_id == lhs.type_id, "Heterogeneous builtin integer op!"); + llvm::APInt rhs_val = context.ints().GetAtWidth(rhs.int_id, lhs_bit_width_id); + + // We may also need to diagnose overflow for these operations. + bool overflow = false; Lex::TokenKind op_token = Lex::TokenKind::Not; + switch (builtin_kind) { // Arithmetic. case SemIR::BuiltinFunctionKind::IntSAdd: @@ -789,32 +838,9 @@ static auto PerformBuiltinBinaryIntOp(Context& context, SemIRLoc loc, op_token = Lex::TokenKind::Caret; break; - // Bit shift. case SemIR::BuiltinFunctionKind::IntLeftShift: case SemIR::BuiltinFunctionKind::IntRightShift: - if (rhs_val.uge(lhs_val.getBitWidth()) || - (rhs_val.isNegative() && context.types().IsSignedInt(rhs.type_id))) { - CARBON_DIAGNOSTIC( - CompileTimeShiftOutOfRange, Error, - "shift distance not in range [0, {0}) in {1} {2:<<|>>} {3}", - unsigned, TypedInt, BoolAsSelect, TypedInt); - context.emitter().Emit( - loc, CompileTimeShiftOutOfRange, lhs_val.getBitWidth(), - {.type = lhs.type_id, .value = lhs_val}, - builtin_kind == SemIR::BuiltinFunctionKind::IntLeftShift, - {.type = rhs.type_id, .value = rhs_val}); - // TODO: Is it useful to recover by returning 0 or -1? - return SemIR::ConstantId::Error; - } - - if (builtin_kind == SemIR::BuiltinFunctionKind::IntLeftShift) { - result_val = lhs_val.shl(rhs_val); - } else if (context.types().IsSignedInt(lhs.type_id)) { - result_val = lhs_val.ashr(rhs_val); - } else { - result_val = lhs_val.lshr(rhs_val); - } - break; + CARBON_FATAL("Handled specially above."); default: CARBON_FATAL("Unexpected operation kind."); @@ -840,10 +866,15 @@ static auto PerformBuiltinIntComparison(Context& context, SemIR::TypeId bool_type_id) -> SemIR::ConstantId { auto lhs = context.insts().GetAs(lhs_id); - const auto& lhs_val = context.ints().Get(lhs.int_id); - const auto& rhs_val = - context.ints().Get(context.insts().GetAs(rhs_id).int_id); - bool is_signed = context.types().IsSignedInt(lhs.type_id); + auto rhs = context.insts().GetAs(rhs_id); + CARBON_CHECK(lhs.type_id == rhs.type_id, + "Builtin comparison with mismatched types!"); + + auto [is_signed, bit_width_id] = context.sem_ir().GetIntTypeInfo(lhs.type_id); + CARBON_CHECK(bit_width_id != IntId::Invalid, + "Cannot evaluate a generic bit width integer: {0}", lhs); + llvm::APInt lhs_val = context.ints().GetAtWidth(lhs.int_id, bit_width_id); + llvm::APInt rhs_val = context.ints().GetAtWidth(rhs.int_id, bit_width_id); bool result; switch (builtin_kind) { diff --git a/toolchain/check/handle_literal.cpp b/toolchain/check/handle_literal.cpp index d975de044809e..09925e5d43c50 100644 --- a/toolchain/check/handle_literal.cpp +++ b/toolchain/check/handle_literal.cpp @@ -33,20 +33,23 @@ auto HandleParseNode(Context& context, Parse::BoolLiteralTrueId node_id) static auto MakeI32Literal(Context& context, Parse::NodeId node_id, IntId int_id) -> SemIR::InstId { auto val = context.ints().Get(int_id); - if (val.getActiveBits() > 31) { + CARBON_CHECK(val.isNonNegative(), + "Unexpected negative literal from the lexer: {0}", val); + + // Make sure the value fits in an `i32`. + if (val.getSignificantBits() > 32) { CARBON_DIAGNOSTIC(IntLiteralTooLargeForI32, Error, "integer literal with value {0} does not fit in i32", - llvm::APSInt); - context.emitter().Emit(node_id, IntLiteralTooLargeForI32, - llvm::APSInt(val, /*isUnsigned=*/true)); + llvm::APInt); + context.emitter().Emit(node_id, IntLiteralTooLargeForI32, val); return SemIR::InstId::BuiltinError; } - // Literals are always represented as unsigned, so zero-extend if needed. - auto i32_val = val.zextOrTrunc(32); + + // We directly reuse the integer ID as it represents the canonical value. return context.AddInst( node_id, {.type_id = context.GetBuiltinType(SemIR::BuiltinInstKind::IntType), - .int_id = context.ints().Add(i32_val)}); + .int_id = int_id}); } // Forms an IntValue instruction with type `IntLiteral` for a given literal diff --git a/toolchain/check/import_ref.cpp b/toolchain/check/import_ref.cpp index e15d0b2cbe420..b7100a4bdfb9b 100644 --- a/toolchain/check/import_ref.cpp +++ b/toolchain/check/import_ref.cpp @@ -2135,9 +2135,16 @@ class ImportRefResolver { return Retry(); } + // We can directly reuse the value IDs across file IRs. Otherwise, we need + // to add a new canonical int in this IR. + auto int_id = + inst.int_id.is_value() + ? inst.int_id + : context_.ints().AddSigned(import_ir_.ints().Get(inst.int_id)); + return ResolveAs( {.type_id = context_.GetTypeIdForTypeConstant(type_id), - .int_id = context_.ints().Add(import_ir_.ints().Get(inst.int_id))}); + .int_id = int_id}); } auto TryResolveTypedInst(SemIR::IntType inst) -> ResolveResult { diff --git a/toolchain/check/member_access.cpp b/toolchain/check/member_access.cpp index 4ea3362826c1f..4cc17c1ae292f 100644 --- a/toolchain/check/member_access.cpp +++ b/toolchain/check/member_access.cpp @@ -370,8 +370,8 @@ static auto PerformInstanceBinding(Context& context, SemIR::LocId loc_id, static auto ValidateTupleIndex(Context& context, SemIR::LocId loc_id, SemIR::InstId operand_inst_id, SemIR::IntValue index_inst, int size) - -> const llvm::APInt* { - const auto& index_val = context.ints().Get(index_inst.int_id); + -> std::optional { + llvm::APInt index_val = context.ints().Get(index_inst.int_id); if (index_val.uge(size)) { CARBON_DIAGNOSTIC(TupleIndexOutOfBounds, Error, "tuple element index `{0}` is past the end of type {1}", @@ -379,9 +379,9 @@ static auto ValidateTupleIndex(Context& context, SemIR::LocId loc_id, context.emitter().Emit(loc_id, TupleIndexOutOfBounds, {.type = index_inst.type_id, .value = index_val}, operand_inst_id); - return nullptr; + return std::nullopt; } - return &index_val; + return index_val; } auto PerformMemberAccess(Context& context, SemIR::LocId loc_id, @@ -533,8 +533,8 @@ auto PerformTupleAccess(Context& context, SemIR::LocId loc_id, auto index_literal = context.insts().GetAs( context.constant_values().GetInstId(index_const_id)); auto type_block = context.type_blocks().Get(tuple_type->elements_id); - const auto* index_val = ValidateTupleIndex(context, loc_id, tuple_inst_id, - index_literal, type_block.size()); + std::optional index_val = ValidateTupleIndex( + context, loc_id, tuple_inst_id, index_literal, type_block.size()); if (!index_val) { return SemIR::InstId::BuiltinError; } diff --git a/toolchain/driver/testdata/dump_shared_values.carbon b/toolchain/driver/testdata/dump_shared_values.carbon index b2da163d9de38..e13e938ed4b84 100644 --- a/toolchain/driver/testdata/dump_shared_values.carbon +++ b/toolchain/driver/testdata/dump_shared_values.carbon @@ -22,11 +22,7 @@ var str2: String = "ab'\"c"; // CHECK:STDOUT: --- // CHECK:STDOUT: filename: dump_shared_values.carbon // CHECK:STDOUT: shared_values: -// CHECK:STDOUT: ints: -// CHECK:STDOUT: int0: 32 -// CHECK:STDOUT: int1: 1 -// CHECK:STDOUT: int2: 8 -// CHECK:STDOUT: int3: 64 +// CHECK:STDOUT: ints: {} // CHECK:STDOUT: reals: // CHECK:STDOUT: real0: 10*10^-1 // CHECK:STDOUT: real1: 8*10^7 diff --git a/toolchain/lex/lex.cpp b/toolchain/lex/lex.cpp index fa4c942044a21..f11c2b8fce1b0 100644 --- a/toolchain/lex/lex.cpp +++ b/toolchain/lex/lex.cpp @@ -1013,10 +1013,11 @@ auto Lexer::LexNumericLiteral(llvm::StringRef source_text, ssize_t& position) return VariantMatch( literal->ComputeValue(emitter_), [&](NumericLiteral::IntValue&& value) { - return LexTokenWithPayload( - TokenKind::IntLiteral, - buffer_.value_stores_->ints().Add(std::move(value.value)).index, - byte_offset); + return LexTokenWithPayload(TokenKind::IntLiteral, + buffer_.value_stores_->ints() + .AddUnsigned(std::move(value.value)) + .AsTokenPayload(), + byte_offset); }, [&](NumericLiteral::RealValue&& value) { auto real_id = buffer_.value_stores_->reals().Add(Real{ @@ -1222,10 +1223,13 @@ auto Lexer::LexWordAsTypeLiteralToken(llvm::StringRef word, int32_t byte_offset) suffix_value = suffix_value * 10 + (c - '0'); } - return LexTokenWithPayload( - kind, - buffer_.value_stores_->ints().Add(llvm::APInt(64, suffix_value)).index, - byte_offset); + // Add the bit width to our integer store and get its index. We treat it as + // unsigned as that's less expensive and it can't be negative. + CARBON_CHECK(suffix_value >= 0); + auto bit_width_payload = + buffer_.value_stores_->ints().Add(suffix_value).AsTokenPayload(); + + return LexTokenWithPayload(kind, bit_width_payload, byte_offset); } auto Lexer::LexKeywordOrIdentifier(llvm::StringRef source_text, diff --git a/toolchain/lex/tokenized_buffer.h b/toolchain/lex/tokenized_buffer.h index b958eef20ef59..f3d0f6756fbea 100644 --- a/toolchain/lex/tokenized_buffer.h +++ b/toolchain/lex/tokenized_buffer.h @@ -312,7 +312,7 @@ class TokenizedBuffer : public Printable { kind() == TokenKind::IntTypeLiteral || kind() == TokenKind::UnsignedIntTypeLiteral || kind() == TokenKind::FloatTypeLiteral); - return IntId(token_payload_); + return IntId::MakeFromTokenPayload(token_payload_); } auto real_id() const -> RealId { @@ -363,6 +363,9 @@ class TokenizedBuffer : public Printable { static constexpr int PayloadBits = 23; + // Make sure we have enough payload bits to represent token-associated IDs. + static_assert(PayloadBits >= IntId::TokenIdBits); + // Constructor for a TokenKind that carries no payload, or where the payload // will be set later. // diff --git a/toolchain/lower/constant.cpp b/toolchain/lower/constant.cpp index 6fdaa614ba41f..f7dc290c4d00c 100644 --- a/toolchain/lower/constant.cpp +++ b/toolchain/lower/constant.cpp @@ -210,13 +210,18 @@ static auto EmitAsConstant(ConstantContext& context, SemIR::IntValue inst) // IntLiteral is represented as an empty struct. All other integer types are // represented as an LLVM integer type. - if (!llvm::isa(type)) { + auto* int_type = llvm::dyn_cast(type); + if (!int_type) { auto* struct_type = llvm::dyn_cast(type); CARBON_CHECK(struct_type && struct_type->getNumElements() == 0); return llvm::ConstantStruct::get(struct_type); } - return llvm::ConstantInt::get(type, context.sem_ir().ints().Get(inst.int_id)); + auto val = context.sem_ir().ints().Get(inst.int_id); + int bit_width = int_type->getBitWidth(); + bool is_signed = context.sem_ir().GetIntTypeInfo(inst.type_id).is_signed; + return llvm::ConstantInt::get(type, is_signed ? val.sextOrTrunc(bit_width) + : val.zextOrTrunc(bit_width)); } static auto EmitAsConstant(ConstantContext& context, SemIR::Namespace inst) diff --git a/toolchain/sem_ir/BUILD b/toolchain/sem_ir/BUILD index 7996b9e1cb1d3..8d1ab6e342d6f 100644 --- a/toolchain/sem_ir/BUILD +++ b/toolchain/sem_ir/BUILD @@ -36,6 +36,7 @@ cc_library( "//common:check", "//common:ostream", "//toolchain/base:index_base", + "//toolchain/base:int_store", "//toolchain/base:shared_value_stores", "//toolchain/base:value_ids", "//toolchain/diagnostics:diagnostic_emitter", @@ -55,6 +56,7 @@ cc_library( deps = [ "//common:check", "//common:enum_base", + "//toolchain/base:int_store", "//toolchain/parse:node_kind", "//toolchain/sem_ir:builtin_inst_kind", "//toolchain/sem_ir:ids", @@ -76,6 +78,7 @@ cc_library( "//common:ostream", "//common:struct_reflection", "//toolchain/base:index_base", + "//toolchain/base:int_store", "//toolchain/base:value_store", "@llvm-project//llvm:Support", ], @@ -130,6 +133,7 @@ cc_library( "//common:map", "//common:ostream", "//common:set", + "//toolchain/base:int_store", "//toolchain/base:kind_switch", "//toolchain/base:shared_value_stores", "//toolchain/base:value_ids", diff --git a/toolchain/sem_ir/file.h b/toolchain/sem_ir/file.h index 355c2f856ad44..f5b655d9caefa 100644 --- a/toolchain/sem_ir/file.h +++ b/toolchain/sem_ir/file.h @@ -10,6 +10,7 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/FormatVariadic.h" +#include "toolchain/base/int_store.h" #include "toolchain/base/shared_value_stores.h" #include "toolchain/base/value_store.h" #include "toolchain/base/yaml.h" @@ -35,6 +36,12 @@ namespace Carbon::SemIR { // Provides semantic analysis on a Parse::Tree. class File : public Printable { public: + // Used to return information about an integer type in `GetIntTypeInfo`. + struct IntTypeInfo { + bool is_signed; + IntId bit_width; + }; + // Starts a new file for Check::CheckParseTree. explicit File(CheckIRId check_ir_id, IdentifierId package_id, LibraryNameId library_id, SharedValueStores& value_stores, @@ -71,6 +78,26 @@ class File : public Printable { return types().GetAs(pointer_id).pointee_id; } + // Returns integer type information from a type ID. Abstracts away the + // difference between an `IntType` instruction defined type and a builtin + // instruction defined type. Uses IntId::Invalid for types that have an + // invalid width. + // + // TODO: When we don't have a builtin int type mixed with actual `IntType` + // instructions, clients should directly query the `IntType` instruction to + // compute this information. + auto GetIntTypeInfo(TypeId int_type_id) const -> IntTypeInfo { + auto inst_id = types().GetInstId(int_type_id); + if (inst_id == InstId::BuiltinIntType) { + return {.is_signed = true, .bit_width = ints().Lookup(32)}; + } + auto int_type = insts().GetAs(inst_id); + auto bit_width_inst = insts().TryGetAs(int_type.bit_width_id); + return { + .is_signed = int_type.int_kind.is_signed(), + .bit_width = bit_width_inst ? bit_width_inst->int_id : IntId::Invalid}; + } + auto check_ir_id() const -> CheckIRId { return check_ir_id_; } auto package_id() const -> IdentifierId { return package_id_; } auto library_id() const -> SemIR::LibraryNameId { return library_id_; } diff --git a/toolchain/sem_ir/id_kind.h b/toolchain/sem_ir/id_kind.h index 84e4866ff2b88..9c1a78324d381 100644 --- a/toolchain/sem_ir/id_kind.h +++ b/toolchain/sem_ir/id_kind.h @@ -7,6 +7,7 @@ #include +#include "toolchain/base/int_store.h" #include "toolchain/sem_ir/ids.h" namespace Carbon::SemIR { diff --git a/toolchain/sem_ir/inst.h b/toolchain/sem_ir/inst.h index 5e2c11f3e6a12..a79d59d18cf0e 100644 --- a/toolchain/sem_ir/inst.h +++ b/toolchain/sem_ir/inst.h @@ -13,6 +13,7 @@ #include "common/ostream.h" #include "common/struct_reflection.h" #include "toolchain/base/index_base.h" +#include "toolchain/base/int_store.h" #include "toolchain/base/value_store.h" #include "toolchain/sem_ir/block_value_store.h" #include "toolchain/sem_ir/builtin_inst_kind.h" @@ -265,6 +266,7 @@ class Inst : public Printable { // Convert a field to its raw representation, used as `arg0_` / `arg1_`. static constexpr auto ToRaw(IdBase base) -> int32_t { return base.index; } + static constexpr auto ToRaw(IntId id) -> int32_t { return id.AsRaw(); } static constexpr auto ToRaw(BuiltinInstKind kind) -> int32_t { return kind.AsInt(); } @@ -275,6 +277,10 @@ class Inst : public Printable { return T(raw); } template <> + constexpr auto FromRaw(int32_t raw) -> IntId { + return IntId::MakeRaw(raw); + } + template <> constexpr auto FromRaw(int32_t raw) -> BuiltinInstKind { return BuiltinInstKind::FromInt(raw); } diff --git a/toolchain/sem_ir/type.h b/toolchain/sem_ir/type.h index 5cf5f435a263a..f197e40fecc25 100644 --- a/toolchain/sem_ir/type.h +++ b/toolchain/sem_ir/type.h @@ -99,6 +99,10 @@ class TypeStore : public Yaml::Printable { } // Determines whether the given type is a signed integer type. + // + // TODO: When we don't have a builtin int type mixed with actual `IntType` + // instructions, clients should directly query the `IntType` instruction to + // compute this information. auto IsSignedInt(TypeId int_type_id) const -> bool { auto inst_id = GetInstId(int_type_id); if (inst_id == InstId::BuiltinIntType) { diff --git a/toolchain/sem_ir/typed_insts.h b/toolchain/sem_ir/typed_insts.h index 691df160803db..875f53f0f2048 100644 --- a/toolchain/sem_ir/typed_insts.h +++ b/toolchain/sem_ir/typed_insts.h @@ -5,6 +5,7 @@ #ifndef CARBON_TOOLCHAIN_SEM_IR_TYPED_INSTS_H_ #define CARBON_TOOLCHAIN_SEM_IR_TYPED_INSTS_H_ +#include "toolchain/base/int_store.h" #include "toolchain/parse/node_ids.h" #include "toolchain/sem_ir/builtin_inst_kind.h" #include "toolchain/sem_ir/ids.h"