fix(deps): Copy KHD imports into scattermoe_utils (#127)
* fix: move necessary khd functions into scattermoe_utils

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* remove requirements-khd, update README

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fix: remove ext dep from benchmarks

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fix:  move files to khd folder

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fmt

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fix lint

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* rm copyright

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* reference cute kernels

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* cute kernels reference

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

---------

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
willmj authored Mar 5, 2025
1 parent 791bdd9 commit de9a4f1
Showing 11 changed files with 1,207 additions and 30 deletions.
2 changes: 1 addition & 1 deletion plugins/accelerated-moe/.pylintrc
@@ -52,7 +52,7 @@ ignore=CVS,protobufs
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
- ignore-paths=.*megablocks
+ ignore-paths=.*megablocks,.*khd

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
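The added `.*khd` entry makes pylint skip the newly vendored kernel sources, just as `.*megablocks` already does. A minimal sketch of how `ignore-paths` matching behaves (the file path below is illustrative):

```python
# a minimal sketch of pylint's ignore-paths matching; the path is illustrative
import re

patterns = [".*megablocks", ".*khd"]
path = "src/fms_acceleration_moe/utils/scattermoe_utils/khd/kernels/ops.py"

# pylint skips files whose path matches any ignore-paths regex
print(any(re.match(p, path) for p in patterns))  # True -> linting skipped
```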
8 changes: 1 addition & 7 deletions plugins/accelerated-moe/README.md
@@ -48,7 +48,6 @@ Run the below in the top-level directory of this repo:

```
tox -e run-benches \
- -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \
-x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \
-- \
"1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-full
@@ -77,12 +76,7 @@ bash scripts/run_benchmarks.sh \

### Triton Kernel Dependencies

- Currently we do not copy the `scattermoe` kernels into this respository, to this is an additional manual install:
-
- ```
- # this will install the kernel-hyperdrive fork with the scattermoe triton kernels
- pip install -r requirements-khd.txt
- ```
+ Triton kernels are copied into [scattermoe_utils](./src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels) from [kernel hyperdrive](https://github.com/fabianlim/kernel-hyperdrive), which is a fork of [cute kernels](https://github.com/mayank31398/cute-kernels).
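For reference, a minimal sketch of importing the vendored kernels from their new home; the absolute module path is inferred from the relative import added in this commit and may differ:

```python
# a minimal sketch; the absolute module path is inferred from this
# commit's relative import (.scattermoe_utils.khd.kernels.ops) and may differ
from fms_acceleration_moe.utils.scattermoe_utils.khd.kernels.ops import (
    padded_block_indices,
    scattered_experts,
)
```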

### Known Issues

2 changes: 0 additions & 2 deletions plugins/accelerated-moe/requirements-khd.txt

This file was deleted.

@@ -32,14 +32,6 @@
# pylint: disable=too-many-instance-attributes
class ScatterMoEAccelerationPlugin(AccelerationPlugin):

-     # NOTE: we cannot do
-     # - require_packages = {"khd"}
-     # this is because the khd fork is not properly packaged as a PyPI project, and so
-     # - "importlib.util.find_spec('khd')" returns, but
-     # - "importlib.metadata.version('kernel-hyperdrive')" does not return
-     # if we decide to extract the kernels, then we do not need to anymore,
-     # https://github.com/foundation-model-stack/fms-acceleration/issues/105
-
    restricted_model_archs = [
        "GraniteMoeForCausalLM",
        "MixtralForCausalLM",
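The NOTE removed above concerns dependency detection: the khd fork was importable but carried no installed distribution metadata. A minimal sketch of the two probes it contrasts (package names are taken from the note itself):

```python
# a minimal sketch of the probes contrasted in the removed NOTE;
# package names come from the note itself
import importlib.metadata
import importlib.util

# resolves whenever the module is importable, even without PyPI metadata
print(importlib.util.find_spec("khd"))

try:
    # requires installed distribution metadata, which the khd fork lacked
    print(importlib.metadata.version("kernel-hyperdrive"))
except importlib.metadata.PackageNotFoundError:
    print("no distribution metadata for kernel-hyperdrive")
```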
@@ -26,21 +26,13 @@
import torch
import torch.nn.functional as F

- try:
-     # Third Party
-     from khd.kernels.scattermoe.triton_implementation.ops import (
-         padded_block_indices,
-         scattered_experts,
-     )
- except ImportError as e:
-     raise ImportError(
-         "kernel-hyperdrive PyPI package not found. Install it: "
-         "pip install -r plugins/accelerated-moe/requirements-khd.txt"
-     ) from e
-
# Local
from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE
from .scattermoe_utils import all_to_all_gather_inputs, scatter_with_routing_weights
+ from .scattermoe_utils.khd.kernels.ops import (
+     padded_block_indices,
+     scattered_experts,
+ )


# helper function to fetch the local tensor if it's a dtensor
@@ -0,0 +1,17 @@
# Copyright The FMS HF Tuning Authors
# Copyright 2024 Databricks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .custom_op import torch_custom_op
@@ -0,0 +1,47 @@
# Standard
from typing import Any, Callable, Iterable

# Third Party
import torch

try:
    # Third Party
    from torch.library import custom_op

    _IS_CUSTOM_OP_IN_PYTORCH = True
except ImportError:
    _IS_CUSTOM_OP_IN_PYTORCH = False


class _IdentityOp:
    """No-op stand-in for custom_op on PyTorch versions without torch.library.custom_op."""

    def __init__(self, fn: Callable) -> None:
        self.fn = fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.fn(*args, **kwargs)

    def register_fake(self, fn: Callable) -> Callable:
        return fn


def torch_custom_op(
    name: str,
    fn: Callable | None = None,
    /,
    *,
    mutates_args: str | Iterable[str],
    device_types: torch.device | None = None,
    schema: str | None = None,
) -> Callable | _IdentityOp:
    # register a real custom op where torch.library supports it,
    # otherwise fall back to the identity wrapper
    if _IS_CUSTOM_OP_IN_PYTORCH:
        op = custom_op(
            name,
            fn,
            mutates_args=mutates_args,
            device_types=device_types,
            schema=schema,
        )
    else:
        op = _IdentityOp if fn is None else _IdentityOp(fn)

    return op
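A hypothetical usage sketch (the op name `mylib::scale` and the function are illustrative, not part of this commit): on PyTorch builds that ship `torch.library.custom_op`, the decorator registers a real custom op; otherwise the `_IdentityOp` fallback simply calls the wrapped function.

```python
# a hypothetical usage sketch; "mylib::scale" and scale() are illustrative
import torch

@torch_custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha

# identical call syntax whether or not the real custom-op path was taken
y = scale(torch.ones(4), 2.0)
```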
@@ -0,0 +1,22 @@
# Copyright The FMS HF Tuning Authors
# Copyright 2024 Databricks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .kernels import (
    group_triton_kernel,
    groupXtY_triton_kernel,
    scatter2scatter_lora_triton_kernel,
    scatter2scatter_triton_kernel,
)