Skip to content

Commit

Permalink
Add Gemma pipeline to benchmark_large
Browse files (browse the repository at this point in the history)
Signed-off-by: mariecwhite <mariewhite@google.com>
  • Loading branch information
mariecwhite committed Jun 27, 2024
1 parent 794a3ca commit 3a56262
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/benchmark_large.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ on:
# Please keep this default value in sync with the jobs.build_e2e_test_artifacts.with.shard-count field below
default: a2-highgpu-1g=1,c2-standard-60=2,default=1
type: string
# NOTE(review): this adds a bare pull_request trigger to the benchmark_large
# workflow, so it appears the large benchmarks would run on every PR.
# Presumably added to exercise the new Gemma pipeline -- confirm it is meant to stay.
pull_request:

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
Expand Down
13 changes: 13 additions & 0 deletions build_tools/python/e2e_test_framework/models/jax_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,16 @@
],
batch_sizes=[1, 16, 24, 32, 48, 64, 512],
)

# Gemma causal-LM pipeline, derived from
# https://huggingface.co/docs/transformers/model_doc/gemma#transformers.FlaxGemmaForCausalLM
GEMMA_TAGS = ["gemma", "decoder-only", "pipeline", "generative"]

# FP32 Gemma entry, sourced from an exported StableHLO MLIR bytecode file.
# NOTE(review): the constant name ends in 1024XI32_256I32, but only a single
# 1x1024xi32 input type is declared below -- confirm that is intentional.
GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32 = common_definitions.Model(
    name="Gemma2bit_fp32",
    id=unique_ids.MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32,
    # Shared Gemma tags plus the precision tag for this variant.
    tags=[*GEMMA_TAGS, "fp32"],
    source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
    source_url="https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.25_1709787220/stablehlo.mlirbc",
    entry_function="main",
    input_types=["1x1024xi32"],
)
3 changes: 3 additions & 0 deletions build_tools/python/e2e_test_framework/models/model_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@
common_definitions.CpuBenchmarkConfig(
model=jax_models.T5_LARGE_FP32_JAX_512XI32_BATCHES[32], threads=[30]
),
common_definitions.CpuBenchmarkConfig(
model=jax_models.GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32, threads=[30]
),
]

# Microkernels.
Expand Down
9 changes: 9 additions & 0 deletions build_tools/python/e2e_test_framework/unique_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ def hash_composite_id(keys: Sequence[str]) -> str:
MODEL_T5_LARGE_FP32_JAX = f"{MODEL_T5_LARGE_FP32}-JAX"
MODEL_T5_LARGE_FP32_JAX_512XI32 = f"{MODEL_T5_LARGE_FP32_JAX}-512xi32"

# Gemma 2B IT (greedy) model IDs. Each constant extends the previous one by
# appending one more "-<suffix>" component.
MODEL_GEMMA_2B_IT_GREEDY = "362aa7fe-f04e-477a-b364-c98c0eaca861-MODEL_GEMMA_2B_IT_GREEDY"
MODEL_GEMMA_2B_IT_GREEDY_FP32 = MODEL_GEMMA_2B_IT_GREEDY + "-fp32"
MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX = MODEL_GEMMA_2B_IT_GREEDY_FP32 + "-JAX"
# NOTE(review): this shape suffix is upper-case, whereas the T5 ID above uses
# lower-case "-512xi32" -- confirm which spelling is intended before relying on it.
MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32 = (
    MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX + "-1024XI32_256I32"
)

# Microbenchmarks. UB is shorthand for microbenchmark.
MICRO_MATMUL_3456X1024X2048_FP16_MLIR = "50a7aece-73f9-47f4-a93a-4a1178f45407"
MICRO_MATMUL_3456X1024X2048_FP32_MLIR = "a55afe1c-9410-47a6-b417-04b0d75ee5f4"
Expand Down
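7 changes: 7 additions & 0 deletions (file header missing from this capture; presumably the generated fetch-models CMake file)
Original file line number Diff line number Diff line change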
Expand Up @@ -215,6 +215,13 @@ iree_fetch_artifact(
UNPACK
)

# Fetch rule for the Gemma2bit_fp32 model: SOURCE_URL is the remote StableHLO
# bytecode file and OUTPUT is its local path under ROOT_ARTIFACTS_DIR.
# NOTE(review): UNPACK is passed although the URL ends in .mlirbc rather than an
# archive -- presumably a generator default; confirm against iree_fetch_artifact.
iree_fetch_artifact(
NAME "model-Gemma2bit_fp32"
SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.25_1709787220/stablehlo.mlirbc"
OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Gemma2bit_fp32.mlirbc"
UNPACK
)

iree_fetch_artifact(
NAME "model-matmul_3456x1024x2048_f16t_tile_config_default"
SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_default.mlirbc"
Expand Down
38 changes: 38 additions & 0 deletions tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,21 @@ iree_bytecode_module(
PUBLIC
)

# Benchmark module for Gemma2bit_fp32: builds module.vmfb from the fetched
# model_Gemma2bit_fp32.mlirbc using the llvm-cpu backend with StableHLO input,
# targeting x86_64 cascadelake with data tiling and all ukernels enabled.
# NOTE(review): this file is named generated_e2e_test_iree_artifacts.cmake, so
# hand edits are presumably overwritten on regeneration -- confirm.
iree_bytecode_module(
NAME "iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_"
SRC "${ROOT_ARTIFACTS_DIR}/model_Gemma2bit_fp32.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
"--iree-input-type=stablehlo"
"--iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu"
"--iree-llvmcpu-target-cpu=cascadelake"
"--iree-opt-data-tiling=true"
"--iree-llvmcpu-enable-ukernels=all"
FRIENDLY_NAME "Gemma2bit_fp32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk]"
PUBLIC
)

iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_"
SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
Expand Down Expand Up @@ -3569,6 +3584,25 @@ iree_bytecode_module(
PUBLIC
)

# Compile-stats variant of the Gemma2bit_fp32 module: same source and llvm-cpu
# cascadelake flags as the default-flags/dt-uk module, plus flags that emit a
# polyglot zip, disable debug symbols, and dump scheduling statistics as JSON
# to scheduling_stats.json next to the module.
iree_bytecode_module(
NAME "iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_"
SRC "${ROOT_ARTIFACTS_DIR}/model_Gemma2bit_fp32.mlirbc"
MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb"
FLAGS
"--iree-hal-target-backends=llvm-cpu"
"--iree-input-type=stablehlo"
"--iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu"
"--iree-llvmcpu-target-cpu=cascadelake"
"--iree-opt-data-tiling=true"
"--iree-llvmcpu-enable-ukernels=all"
"--iree-vm-emit-polyglot-zip=true"
"--iree-llvmcpu-debug-symbols=false"
"--iree-scheduling-dump-statistics-format=json"
"--iree-scheduling-dump-statistics-file=${ROOT_ARTIFACTS_DIR}/iree_module_Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/scheduling_stats.json"
FRIENDLY_NAME "Gemma2bit_fp32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats]"
PUBLIC
)

iree_bytecode_module(
NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_"
SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc"
Expand Down Expand Up @@ -4872,6 +4906,7 @@ add_dependencies(iree-benchmark-import-models-large
${PACKAGE_NAME}_model-BERT_LARGE_JAX_384XI32_BATCH1
${PACKAGE_NAME}_model-BERT_LARGE_JAX_384XI32_BATCH32
${PACKAGE_NAME}_model-BERT_LARGE_JAX_384XI32_BATCH64
${PACKAGE_NAME}_model-Gemma2bit_fp32
${PACKAGE_NAME}_model-RESNET50_FP32_JAX_3X224X224XF32_BATCH1
${PACKAGE_NAME}_model-RESNET50_FP32_JAX_3X224X224XF32_BATCH128
${PACKAGE_NAME}_model-RESNET50_FP32_JAX_3X224X224XF32_BATCH64
Expand Down Expand Up @@ -5072,6 +5107,7 @@ add_dependencies(iree-benchmark-suites-comp-stats-large
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_
${PACKAGE_NAME}_iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH128_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_
Expand Down Expand Up @@ -5239,6 +5275,7 @@ add_dependencies(iree-benchmark-suites-large
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH128_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
Expand Down Expand Up @@ -5340,6 +5377,7 @@ add_dependencies(iree-benchmark-suites-x86_64-large
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH128_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_
Expand Down

0 comments on commit 3a56262

Please sign in to comment.