Commit
Remove peft dependency for AMD GPUs
borzunov committed Aug 5, 2023
1 parent a1f7791 commit 6b38bc8
Showing 5 changed files with 7 additions and 33 deletions.
1 change: 0 additions & 1 deletion setup.cfg
@@ -46,7 +46,6 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft>=0.4.0
     safetensors>=0.3.1
     Dijkstar>=2.6.0

14 changes: 3 additions & 11 deletions src/petals/server/backend.py
@@ -35,10 +35,6 @@ def __init__(
         max_chunk_size_bytes: int,
         **kwargs,
     ):
-        import petals.utils.peft as _peft_module
-
-        self._peft_module = _peft_module
-
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -98,13 +94,11 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S

     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().forward(*inputs)
+        return super().forward(*inputs)

     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().backward(*inputs)
+        return super().backward(*inputs)

     @torch.inference_mode()
     def inference_step(
@@ -116,9 +110,7 @@ def inference_step(
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
         seq_len = hidden_states.shape[1]

-        with self.memory_cache.use_cache(
-            *inference_info.cache_handles
-        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)

             # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
12 changes: 1 addition & 11 deletions src/petals/server/server.py
@@ -278,17 +278,7 @@ def _choose_num_blocks(self) -> int:
         block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
         total_memory_per_block = block_size + self._cache_bytes_per_block
         if self.adapters:
-            # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
-            from petals.utils.peft import estimate_adapter_memory_per_block
-
-            total_memory_per_block += estimate_adapter_memory_per_block(
-                self.block_config,
-                self.torch_dtype,
-                self.adapters,
-                token=self.token,
-                cache_dir=self.cache_dir,
-                max_disk_space=self.max_disk_space,
-            )
+            raise RuntimeError("LoRA adapters are not supported on AMD GPUs")

         num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
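Note: `_choose_num_blocks` divides the free GPU memory (minus the autograd reserve) by the per-block footprint, as shown in the hunk above. A back-of-the-envelope sketch with made-up figures, not Petals defaults:

import math

total_memory = 24 * 1024**3            # 24 GiB of GPU memory (assumed)
autograd_memory = 2 * 1024**3          # reserved for backward-pass activations (assumed)
block_size = 1.5 * 1024**3             # weights of one transformer block (assumed)
cache_bytes_per_block = 0.5 * 1024**3  # attention-cache budget per block (assumed)

total_memory_per_block = block_size + cache_bytes_per_block
num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
assert num_blocks >= 1, "Not enough GPU memory to serve at least one block"
print(num_blocks)  # 11 with the figures above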
11 changes: 1 addition & 10 deletions src/petals/utils/convert_block.py
@@ -59,16 +59,7 @@ def convert_block(
         shard.to(device)

     if adapters:
-        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
-
-        create_lora_adapter(block, quant_type=quant_type)
-        for adapter_name in adapters:
-            adapter_config, adapter_state_dict = load_peft(
-                adapter_name,
-                block_idx=block_index,
-                **kwargs,
-            )
-            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
+        raise RuntimeError("LoRA adapters are not supported on AMD GPUs")

     return block

2 changes: 2 additions & 0 deletions tests/test_peft.py
@@ -4,6 +4,8 @@
 import pytest
 from huggingface_hub import snapshot_download

+pytest.skip("LoRA adapters are not supported on AMD GPUs", allow_module_level=True)
+
 from petals.utils.peft import check_peft_repository, load_peft

 UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
