-
Notifications
You must be signed in to change notification settings - Fork 500
/
Copy pathconstants_test.cc
279 lines (219 loc) · 8.48 KB
/
constants_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
/* Copyright 2017 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tests that constants in program memory round trip as expected.
#include "xla/hlo/builder/lib/constants.h"
#include <memory>
#include <vector>
#include <gtest/gtest.h>
#include "xla/array2d.h"
#include "xla/array3d.h"
#include "xla/array4d.h"
#include "xla/client/local_client.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/tests/client_library_test_base.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/test_macros.h"
#include "xla/tests/test_utils.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/types.h"
#include "tsl/platform/ml_dtypes.h"
#include "tsl/platform/test.h"
namespace xla {
namespace {
// Fixture shared by the constant round-trip tests in this file. It exposes a
// single floating-point tolerance used by the approximate comparisons.
class ConstantsTest : public ClientLibraryTestBase {
 protected:
  // Tolerance passed to the ComputeAndCompare* helpers. The two values are
  // presumably (absolute, relative) error bounds — confirm against ErrorSpec.
  const ErrorSpec error_spec_ = ErrorSpec(1e-3, 1e-5);
};
// Typed variant of ConstantsTest: each instantiation runs the typed tests for
// one element type from FloatTypes.
template <typename ElementT>
class ConstantsFloatTest : public ConstantsTest {};
// Element types the typed float tests are instantiated for. The last two
// entries are only included when the TPU backend is not being targeted.
using FloatTypes = ::testing::Types<
    float, half, tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fn,
    tsl::float8_e4m3b11fnuz, tsl::float8_e4m3fnuz, tsl::float8_e5m2,
    tsl::float8_e5m2fnuz
#ifndef XLA_TEST_BACKEND_TPU
    // TODO(b/385004399): Run tests on these types on TPU.
    , tsl::float4_e2m1fn, tsl::float8_e8m0fnu
#endif
    >;
TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes);
// An empty rank-1 F32 constant must round trip to an empty result.
TEST_F(ConstantsTest, ZeroCellF32) {
  XlaBuilder b(TestName());
  const std::vector<float> no_elements;
  ConstantR1<float>(&b, no_elements);
  ComputeAndCompareR1<float>(&b, no_elements, {}, error_spec_);
}
// A one-element rank-1 constant holding 2.0 must round trip for every type in
// FloatTypes. Members of the dependent base are reached through `this->`.
TYPED_TEST(ConstantsFloatTest, OneCellFloat) {
  const TypeParam two{2.0};
  const std::vector<TypeParam> values(/*count=*/1, two);
  XlaBuilder builder(this->TestName());
  ConstantR1<TypeParam>(&builder, values);
  this->template ComputeAndCompareR1<TypeParam>(&builder, values, {},
                                                this->error_spec_);
}
// A one-element S32 constant must round trip exactly.
TEST_F(ConstantsTest, OneCellS32) {
  XlaBuilder builder(TestName());
  const std::vector<int32_t> single(/*count=*/1, /*value=*/2);
  ConstantR1<int32_t>(&builder, single);
  ComputeAndCompareR1<int32_t>(&builder, single, {});
}
// A one-element U32 constant must round trip exactly.
TEST_F(ConstantsTest, OneCellU32) {
  XlaBuilder builder(TestName());
  const std::vector<uint32_t> single(/*count=*/1, /*value=*/2u);
  ConstantR1<uint32_t>(&builder, single);
  ComputeAndCompareR1<uint32_t>(&builder, single, {});
}
// A one-element U4 constant holding 2 must round trip (disabled on CPU/GPU).
TEST_F(ConstantsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(OneCellU4))) {
  const std::vector<u4> nibbles{u4(2)};
  XlaBuilder builder(TestName());
  const auto narrow = ConstantR1<u4>(&builder, nibbles);
  // ComputeAndCompareR1 has no U4 support yet, so widen the result to U8.
  ConvertElementType(narrow, U8);
  ComputeAndCompareR1<uint8_t>(&builder, {2}, {});
}
// A one-element S4 constant holding -2 must round trip (disabled on CPU/GPU).
TEST_F(ConstantsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(OneCellS4))) {
  const std::vector<s4> nibbles{s4(-2)};
  XlaBuilder builder(TestName());
  const auto narrow = ConstantR1<s4>(&builder, nibbles);
  // ComputeAndCompareR1 has no S4 support yet, so widen the result to S8.
  ConvertElementType(narrow, S8);
  ComputeAndCompareR1<int8_t>(&builder, {-2}, {});
}
// An 8-element F32 ramp (0, 1, ..., 7) must round trip.
TEST_F(ConstantsTest, EightCells) {
  std::vector<float> ramp;
  ramp.reserve(8);
  for (int i = 0; i < 8; ++i) {
    ramp.push_back(static_cast<float>(i));
  }
  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, ramp);
  ComputeAndCompareR1<float>(&builder, ramp, {}, error_spec_);
}
// A 16-element F32 ramp (0, 1, ..., 15) must round trip.
TEST_F(ConstantsTest, SixteenCells) {
  std::vector<float> ramp;
  ramp.reserve(16);
  for (int i = 0; i < 16; ++i) {
    ramp.push_back(static_cast<float>(i));
  }
  XlaBuilder builder(TestName());
  ConstantR1<float>(&builder, ramp);
  ComputeAndCompareR1<float>(&builder, ramp, {}, error_spec_);
}
// A 0x2 (zero-row) F32 matrix constant must round trip.
TEST_F(ConstantsTest, Empty_0x2) {
  const Array2D<float> no_rows(0, 2);
  XlaBuilder builder(TestName());
  ConstantR2FromArray2D<float>(&builder, no_rows);
  ComputeAndCompareR2<float>(&builder, no_rows, {}, error_spec_);
}
// A 2x2 F32 matrix filled by MakeLinspaceArray2D(100, 200) must round trip.
TEST_F(ConstantsTest, Small_2x2) {
  const auto linspace = MakeLinspaceArray2D(100.0, 200.0, 2, 2);
  const Array2D<float>& values = *linspace;
  XlaBuilder builder(TestName());
  ConstantR2FromArray2D<float>(&builder, values);
  ComputeAndCompareR2<float>(&builder, values, {}, error_spec_);
}
// A 3x0x2 F32 array, which has no elements, must round trip when supplied as a
// literal.
TEST_F(ConstantsTest, Empty_3x0x2) {
  XlaBuilder builder(TestName());
  const Array3D<float> hollow(3, 0, 2);
  const Literal literal = LiteralUtil::CreateR3FromArray3D<float>(hollow);
  ConstantLiteral(&builder, literal);
  ComputeAndCompareR3<float>(&builder, hollow, {});
}
// A 2x2x2 F32 array holding 1..8 must round trip exactly via a literal.
TEST_F(ConstantsTest, Small_2x2x2) {
  XlaBuilder builder(TestName());
  const Array3D<float> cube({
      //  x0    x1
      {{1.f, 2.f},   // y0
       {3.f, 4.f}},  // y1
      {{5.f, 6.f},   // y0
       {7.f, 8.f}},  // y1
  });
  const Literal literal = LiteralUtil::CreateR3FromArray3D<float>(cube);
  ConstantLiteral(&builder, literal);
  ComputeAndCompareR3<float>(&builder, cube, {});
}
// A 3x2x1x1 F32 array must round trip through two constant-building paths:
// ConstantLiteral on a prebuilt literal, and ConstantR4FromArray4D.
TEST_F(ConstantsTest, Small_3x2x1x1) {
  Array4D<float> expected(3, 2, 1, 1);
  const Array2D<float> pz({
      //  z0     z1
      {-1.0f, 4.1f},  // p0
      {2.0f, 4.1f},   // p1
      {5.0f, 4.4f},   // p2
  });
  expected.FillWithPZ(pz);
  const Literal literal = LiteralUtil::CreateR4FromArray4D(expected);

  // Path 1: from a literal.
  {
    XlaBuilder literal_builder(TestName());
    ConstantLiteral(&literal_builder, literal);
    ComputeAndCompareR4<float>(&literal_builder, expected, {}, error_spec_);
  }
  // Path 2: directly from the Array4D.
  {
    XlaBuilder array_builder(TestName());
    ConstantR4FromArray4D<float>(&array_builder, expected);
    ComputeAndCompareR4<float>(&array_builder, expected, {}, error_spec_);
  }
}
// TODO(b/29263943): Support tuple constants.
// Builds a (f32[2,1], f32[2]) tuple constant and checks each element of the
// transferred result.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
  XlaBuilder builder(TestName());
  const Literal column = LiteralUtil::CreateR2<float>({{1.0}, {2.0}});
  const Literal row = LiteralUtil::CreateR1<float>({2.0, 42});
  ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices({column, row}));

  const Literal result = ExecuteAndTransfer(&builder, {}).value();
  const LiteralSlice first(result, {0});
  const LiteralSlice second(result, {1});
  LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}}, first, error_spec_);
  LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, second, error_spec_);
}
// A token constant can be built and the computation executed successfully.
TEST_F(ConstantsTest, Token) {
  XlaBuilder builder(TestName());
  ConstantLiteral(&builder, LiteralUtil::CreateToken());
  // TODO(b/80000000): tokens cannot be returned from computations.
  // An empty tuple is added afterwards, presumably so that it, rather than the
  // token, becomes the computation's result.
  Tuple(&builder, {});
  const auto status = Execute(&builder, {}).status();
  TF_ASSERT_OK(status);
}
// iota(3) + FullLike(iota, 10) is expected to produce {10, 11, 12}.
TEST_F(ConstantsTest, FullLike) {
  XlaBuilder b(TestName());
  const XlaOp iota = Iota(&b, F32, 3);
  Add(iota, FullLike(iota, 10));
  ComputeAndCompareR1<float>(&b, {10, 11, 12}, {}, error_spec_);
}
// FullLike on a tuple-shaped operand is illegal, so building must fail.
TEST_F(ConstantsTest, IllegalFullLikeOnTuple) {
  XlaBuilder b(TestName());
  const XlaOp first = Iota(&b, F32, 3);
  const XlaOp second = Iota(&b, F32, 1);
  const XlaOp pair = Tuple(&b, {first, second});
  FullLike(pair, 10);
  const auto built = b.Build();
  EXPECT_FALSE(built.ok());
}
// FullLike on a scalar: 1 - FullLike(1, 2) is expected to equal -1.
TEST_F(ConstantsTest, FullLikeScalar) {
  XlaBuilder b(TestName());
  const XlaOp one = ConstantR0WithType(&b, F32, 1);
  Sub(one, FullLike(one, 2));
  ComputeAndCompareR0<float>(&b, -1.0f, {}, error_spec_);
}
// Fixture for tests driven by hand-written HLO text.
struct ConstantsHloTest : public HloTestBase {};
// TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior.
// Bitcasts an s32[1] constant holding 0 to an s32[] scalar and feeds it, along
// with the entry parameter, into a called add computation. With the parameter
// set to 1, the result must equal the parameter.
XLA_TEST_F(ConstantsHloTest,
           DISABLED_ON_TPU(DISABLED_ON_GPU(BitcastOfConstant))) {
  const char* testcase = R"(
HloModule module, is_scheduled=true
func {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT mul = s32[] add(lhs, rhs)
}
ENTRY test {
constant.0 = s32[1]{0} constant({0})
parameter.0 = s32[] parameter(0)
constant-as-scalar = s32[] bitcast(constant.0)
ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func
}
)";
  auto module = ParseAndReturnVerifiedModule(testcase).value();
  auto param = LiteralUtil::CreateR0<int32_t>(1);
  // Arguments are passed as pointers to literals.
  auto result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(param, result));
}
} // namespace
} // namespace xla