diff --git a/python/setup.py b/python/setup.py index 19b29b3aa..b91b498c0 100644 --- a/python/setup.py +++ b/python/setup.py @@ -64,10 +64,10 @@ def get_instantiation_cu() -> List[str]: (root / prefix).mkdir(parents=True, exist_ok=True) group_sizes = os.environ.get("FLASHINFER_GROUP_SIZES", "1,4,6,8").split(",") - page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1").split(",") - head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "128").split(",") - kv_layouts = os.environ.get("FLASHINFER_KV_LAYOUTS", "0").split(",") - pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0").split( + page_sizes = os.environ.get("FLASHINFER_PAGE_SIZES", "1,16,32").split(",") + head_dims = os.environ.get("FLASHINFER_HEAD_DIMS", "64,128,256").split(",") + kv_layouts = os.environ.get("FLASHINFER_KV_LAYOUTS", "0,1").split(",") + pos_encoding_modes = os.environ.get("FLASHINFER_POS_ENCODING_MODES", "0,1,2").split( "," ) allow_fp16_qk_reduction_options = os.environ.get(