Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi-gpu fix #668

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ quant_path = 'mistral-instruct-v0.2-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False
)
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
Expand Down
1 change: 1 addition & 0 deletions awq/models/aquila.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldAquilaDecoderLayer):
@staticmethod
def move_embed(model: OldAquilaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(
Expand Down
9 changes: 8 additions & 1 deletion awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,21 @@ def from_pretrained(
model_path,
trust_remote_code=True,
safetensors=True,
device_map="auto",
device_map=None,
download_kwargs=None,
low_cpu_mem_usage=True,
use_cache=False,
**model_init_kwargs,
) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(
model_path, trust_remote_code, **model_init_kwargs
)

if model_init_kwargs.get("low_cpu_mem_usage") is None:
model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
if model_init_kwargs.get("use_cache") is None:
model_init_kwargs["use_cache"] = use_cache

return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
model_path,
model_type,
Expand Down
1 change: 1 addition & 0 deletions awq/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_act_for_scaling(module):
@staticmethod
def move_embed(model, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
# def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldCohereDecoderLayer):
@staticmethod
def move_embed(model: OldCohereForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(
Expand Down
1 change: 1 addition & 0 deletions awq/models/exaone.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_act_for_scaling(module: OldExaoneBlock):
@staticmethod
def move_embed(model: OldExaoneForCausalLM, device: str):
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.rotary = model.transformer.rotary.to(device)

@staticmethod
def get_layers_for_scaling(module: OldExaoneBlock, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def get_act_for_scaling(module: OldFalconDecoderLayer):
@staticmethod
def move_embed(model: FalconForCausalLM, device):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
model.transformer.rotary_emb = model.transformer.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(
Expand Down
1 change: 1 addition & 0 deletions awq/models/gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_act_for_scaling(module: GPTNeoXLayer):
@staticmethod
def move_embed(model: GPTNeoXForCausalLM, device: str):
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
model.gpt_neox.rotary_emb = model.gpt_neox.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldLlamaDecoderLayer):
@staticmethod
def move_embed(model: OldLlamaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def move_embed(model: OldLlavaForConditionalGeneration, device: str):
model.language_model.model.embed_tokens = model.get_input_embeddings().to(
device
)
model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def move_embed(model: LlavaNextForConditionalGeneration, device: str):
model.language_model.model.embed_tokens = model.get_input_embeddings().to(
device
)
model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldQwen2DecoderLayer):
@staticmethod
def move_embed(model: OldQwen2ForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/qwen2vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def get_act_for_scaling(module: "Qwen2VLForConditionalGeneration"):
def move_embed(model: "Qwen2VLForConditionalGeneration", device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.visual = model.visual.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldStableLmForCausalLM):
@staticmethod
def move_embed(model: OldStableLmForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(
Expand Down
1 change: 1 addition & 0 deletions awq/models/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
@staticmethod
def move_embed(model: OldStarcoder2ForCausalLM, device):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module: OldStarcoder2DecoderLayer, input_feat, module_kwargs):
Expand Down
1 change: 1 addition & 0 deletions awq/models/yi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_act_for_scaling(module):
@staticmethod
def move_embed(model, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
model.model.rotary_emb = model.model.rotary_emb.to(device)

@staticmethod
def get_layers_for_scaling(module, input_feat, module_kwargs):
Expand Down
54 changes: 28 additions & 26 deletions awq/modules/triton/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,18 @@ def awq_dequantize_triton(
triton.cdiv(X, META["BLOCK_SIZE_X"]),
triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
)
awq_dequantize_kernel[grid](
qweight,
scales,
zeros,
group_size,
result,
X,
Y,
BLOCK_SIZE_X=block_size_x,
BLOCK_SIZE_Y=block_size_y,
)
with torch.cuda.device(qweight.device.index):
awq_dequantize_kernel[grid](
qweight,
scales,
zeros,
group_size,
result,
X,
Y,
BLOCK_SIZE_X=block_size_x,
BLOCK_SIZE_Y=block_size_y,
)

return result

Expand Down Expand Up @@ -332,20 +333,21 @@ def awq_gemm_triton(

# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
awq_gemm_kernel[grid](
input,
qweight,
result,
qzeros,
scales,
M,
N,
K,
group_size,
BLOCK_SIZE_M=block_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
SPLIT_K=split_k_iters,
)
with torch.cuda.device(qweight.device.index):
awq_gemm_kernel[grid](
input,
qweight,
result,
qzeros,
scales,
M,
N,
K,
group_size,
BLOCK_SIZE_M=block_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
SPLIT_K=split_k_iters,
)

return result
13 changes: 13 additions & 0 deletions awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,19 @@ def quantize(self):

self.inps = self.inps.to(common_device)

# We need to move the rotary embedding every time we move to a new module.
# Transformers 4.45.0 moved rotary embedding to model definition as of this PR:
# https://github.com/huggingface/transformers/pull/32617
self.awq_model.move_embed(self.model, common_device)

for k, v in self.module_kwargs.items():
# position embeddings found in tuple
if isinstance(v, tuple):
self.module_kwargs[k] = tuple(
item.to(common_device) if isinstance(item, (torch.Tensor, nn.Module))
else item for item in v
)

# [STEP 1]: Get layer, extract linear modules, extract input features
named_linears = get_named_linears(self.modules[i])

Expand Down
29 changes: 8 additions & 21 deletions docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ quant_path = 'mistral-instruct-v0.2-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False
)
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
Expand Down Expand Up @@ -50,9 +48,7 @@ quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
)
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Define data loading methods
Expand Down Expand Up @@ -107,9 +103,7 @@ quant_path = 'qwen2-7b-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
)
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

def load_cosmopedia():
Expand Down Expand Up @@ -150,9 +144,7 @@ quant_path = 'deepseek-coder-v2-lite-instruct-awq'
quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
)
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

def load_openhermes_coding():
Expand Down Expand Up @@ -197,7 +189,7 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, device_map="cuda",
model_path, low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

Expand Down Expand Up @@ -236,9 +228,7 @@ llama_cpp_path = '/workspace/llama.cpp'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
)
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
Expand Down Expand Up @@ -293,7 +283,7 @@ quant_path = "qwen2-vl-7b-instruct"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(
model_path, use_cache=False, attn_implementation="flash_attention_2"
model_path, attn_implementation="flash_attention_2"
)

# We define our own quantizer by extending the AwqQuantizer.
Expand Down Expand Up @@ -505,9 +495,7 @@ quant_path = 'minicpm3-4b-awq'
quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False, safetensors=False
)
model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=False)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
Expand Down Expand Up @@ -591,7 +579,6 @@ model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="cuda:0"
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

Expand Down
16 changes: 1 addition & 15 deletions examples/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@ def main():
parser.add_argument("--version", type=str, default="GEMM", help="Quantization version")

# Model config arguments
parser.add_argument("--low_cpu_mem_usage", action="store_true", help="Use low CPU memory")
parser.add_argument("--no-low_cpu_mem_usage", action="store_false", dest="low_cpu_mem_usage", help="Don't use low CPU memory")
parser.add_argument("--use_cache", action="store_true", help="Use cache")
parser.add_argument("--no-use_cache", action="store_false", dest="use_cache", help="Don't use cache")
parser.add_argument("--device_map", type=str, default="auto", help="Device map for loading the pretrained model")

parser.set_defaults(zero_point=True, low_cpu_mem_usage=True, use_cache=None)
parser.add_argument("--device_map", type=str, default=None, help="Device map for loading the pretrained model")

args = parser.parse_args()

Expand All @@ -33,18 +27,10 @@ def main():
"version": args.version
}

model_config = {
"low_cpu_mem_usage": args.low_cpu_mem_usage,
}

if args.use_cache is not None:
model_config["use_cache"] = args.use_cache

print(f"Loading model from: {args.hf_model_path}")
model = AutoAWQForCausalLM.from_pretrained(
args.hf_model_path,
device_map=args.device_map,
**model_config
)
tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path, trust_remote_code=True)

Expand Down
1 change: 0 additions & 1 deletion examples/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@
do_sample=True,
max_new_tokens=256,
streamer=streamer,
eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
)
8 changes: 3 additions & 5 deletions examples/quantize.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
quant_path = 'mistral-instruct-v0.2-awq'
model_path = 'Qwen/Qwen2.5-14B-Instruct'
quant_path = 'Qwen2.5-14B-Instruct-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, use_cache=False
)
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
Expand Down
Loading