diff --git a/.github/workflows/benchmark_large.yml b/.github/workflows/benchmark_large.yml index f512c9f6d4ea..e5627e330d7d 100644 --- a/.github/workflows/benchmark_large.yml +++ b/.github/workflows/benchmark_large.yml @@ -23,6 +23,7 @@ on: # Please keep this default value in sync with the jobs.build_e2e_test_artifacts.with.shard-count field below default: a2-highgpu-1g=1,c2-standard-60=2,default=1 type: string + pull_request: concurrency: # A PR number if a pull request and otherwise the commit hash. This cancels diff --git a/build_tools/python/e2e_test_framework/models/jax_models.py b/build_tools/python/e2e_test_framework/models/jax_models.py index 1a8ce4bc8d5d..3faf44e306b8 100644 --- a/build_tools/python/e2e_test_framework/models/jax_models.py +++ b/build_tools/python/e2e_test_framework/models/jax_models.py @@ -86,3 +86,16 @@ ], batch_sizes=[1, 16, 24, 32, 48, 64, 512], ) + +# Derived from https://huggingface.co/docs/transformers/model_doc/gemma#transformers.FlaxGemmaForCausalLM +GEMMA_TAGS = ["gemma", "decoder-only", "pipeline", "generative"] + +GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32 = common_definitions.Model( + name="Gemma2bit_fp32", + id=unique_ids.MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32, + tags=GEMMA_TAGS + ["fp32"], + source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, + source_url="https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.25_1709787220/GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32/stablehlo.mlirbc", + entry_function="main", + input_types=["1x1024xi32"], +) diff --git a/build_tools/python/e2e_test_framework/models/model_groups.py b/build_tools/python/e2e_test_framework/models/model_groups.py index 151f21121bfd..3b977f28cb21 100644 --- a/build_tools/python/e2e_test_framework/models/model_groups.py +++ b/build_tools/python/e2e_test_framework/models/model_groups.py @@ -137,6 +137,9 @@ common_definitions.CpuBenchmarkConfig( model=jax_models.T5_LARGE_FP32_JAX_512XI32_BATCHES[32], threads=[30] ), + common_definitions.CpuBenchmarkConfig( + model=jax_models.GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32, threads=[30] + ), ] # Microkernels. diff --git a/build_tools/python/e2e_test_framework/unique_ids.py b/build_tools/python/e2e_test_framework/unique_ids.py index 0bc49e62e2b9..b77a35dc4d8d 100644 --- a/build_tools/python/e2e_test_framework/unique_ids.py +++ b/build_tools/python/e2e_test_framework/unique_ids.py @@ -120,6 +120,15 @@ def hash_composite_id(keys: Sequence[str]) -> str: MODEL_T5_LARGE_FP32_JAX = f"{MODEL_T5_LARGE_FP32}-JAX" MODEL_T5_LARGE_FP32_JAX_512XI32 = f"{MODEL_T5_LARGE_FP32_JAX}-512xi32" +MODEL_GEMMA_2B_IT_GREEDY = ( + "362aa7fe-f04e-477a-b364-c98c0eaca861-MODEL_GEMMA_2B_IT_GREEDY" +) +MODEL_GEMMA_2B_IT_GREEDY_FP32 = f"{MODEL_GEMMA_2B_IT_GREEDY}-fp32" +MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX = f"{MODEL_GEMMA_2B_IT_GREEDY_FP32}-JAX" +MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX_1024XI32_256I32 = ( + f"{MODEL_GEMMA_2B_IT_GREEDY_FP32_JAX}-1024XI32_256I32" +) + # Microbenchmarks. UB is shorthand for microbenchmark. MICRO_MATMUL_3456X1024X2048_FP16_MLIR = "50a7aece-73f9-47f4-a93a-4a1178f45407" MICRO_MATMUL_3456X1024X2048_FP32_MLIR = "a55afe1c-9410-47a6-b417-04b0d75ee5f4" diff --git a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake index d9d36517e224..655c6edf7bd8 100644 --- a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake +++ b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake @@ -215,6 +215,13 @@ iree_fetch_artifact( UNPACK ) +iree_fetch_artifact( + NAME "model-Gemma2bit_fp32" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.25_1709787220/GEMMA2BIT_GREEDY_FP32_JAX_1X1024XI32_256XI32/stablehlo.mlirbc" + OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Gemma2bit_fp32.mlirbc" + UNPACK +) + iree_fetch_artifact( NAME "model-matmul_3456x1024x2048_f16t_tile_config_default" SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/microbenchmarks/matmul/20230410_1681181224/matmul_3456x1024x2048_f16t_f16t_f16t_tile_config_default.mlirbc" diff --git a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake index df6dc8367f61..84881b964d3d 100644 --- a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake +++ b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake @@ -1179,6 +1179,21 @@ iree_bytecode_module( PUBLIC ) +iree_bytecode_module( + NAME "iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_" + SRC "${ROOT_ARTIFACTS_DIR}/model_Gemma2bit_fp32.mlirbc" + MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb" + FLAGS + "--iree-hal-target-backends=llvm-cpu" + "--iree-input-type=stablehlo" + "--iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu" + "--iree-llvmcpu-target-cpu=cascadelake" + "--iree-opt-data-tiling=true" + "--iree-llvmcpu-enable-ukernels=all" + FRIENDLY_NAME "Gemma2bit_fp32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk]" + PUBLIC +) + iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_" SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc" @@ -3569,6 +3584,25 @@ iree_bytecode_module( PUBLIC ) +iree_bytecode_module( + NAME "iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_" + SRC "${ROOT_ARTIFACTS_DIR}/model_Gemma2bit_fp32.mlirbc" + MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb" + FLAGS + "--iree-hal-target-backends=llvm-cpu" + "--iree-input-type=stablehlo" + "--iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu" + "--iree-llvmcpu-target-cpu=cascadelake" + "--iree-opt-data-tiling=true" + "--iree-llvmcpu-enable-ukernels=all" + "--iree-vm-emit-polyglot-zip=true" + "--iree-llvmcpu-debug-symbols=false" + "--iree-scheduling-dump-statistics-format=json" + "--iree-scheduling-dump-statistics-file=${ROOT_ARTIFACTS_DIR}/iree_module_Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/scheduling_stats.json" + FRIENDLY_NAME "Gemma2bit_fp32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats]" + PUBLIC +) + iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_" SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.mlirbc" @@ -4872,6 +4906,7 @@ add_dependencies(iree-benchmark-import-models-large ${PACKAGE_NAME}_model-BERT_LARGE_JAX_384XI32_BATCH1 ${PACKAGE_NAME}_model-BERT_LARGE_JAX_384XI32_BATCH32 ${PACKAGE_NAME}_model-BERT_LARGE_JAX_384XI32_BATCH64 + ${PACKAGE_NAME}_model-Gemma2bit_fp32 ${PACKAGE_NAME}_model-RESNET50_FP32_JAX_3X224X224XF32_BATCH1 ${PACKAGE_NAME}_model-RESNET50_FP32_JAX_3X224X224XF32_BATCH128 ${PACKAGE_NAME}_model-RESNET50_FP32_JAX_3X224X224XF32_BATCH64 @@ -5072,6 +5107,7 @@ add_dependencies(iree-benchmark-suites-comp-stats-large ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_ ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_ ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_ + ${PACKAGE_NAME}_iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH128_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_ @@ -5239,6 +5275,7 @@ add_dependencies(iree-benchmark-suites-large ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ + ${PACKAGE_NAME}_iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH128_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ @@ -5340,6 +5377,7 @@ add_dependencies(iree-benchmark-suites-x86_64-large ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-BERT_LARGE_JAX_384XI32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ + ${PACKAGE_NAME}_iree-module-Gemma2bit_fp32_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH128_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH1_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_ ${PACKAGE_NAME}_iree-module-RESNET50_FP32_JAX_3X224X224XF32_BATCH64_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_