Commit

resolve redundancies
rasbt committed May 29, 2024
1 parent f0293b7 commit 077db2d
Showing 13 changed files with 12 additions and 33 deletions.
10 changes: 0 additions & 10 deletions litgpt/__main__.py
@@ -31,16 +31,6 @@
 from jsonargparse import ArgumentParser


-def _new_parser(**kwargs: Any) -> "ArgumentParser":
-    from jsonargparse import ActionConfigFile, ArgumentParser
-
-    parser = ArgumentParser(**kwargs)
-    parser.add_argument(
-        "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
-    )
-    return parser
-
-
 def main() -> None:
     parser_data = {
         "download": {"fn": download_fn, "_help": "Download weights or tokenizer data from the Hugging Face Hub."},
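
The removed `_new_parser` helper, together with the manual `Path(checkpoint_dir)` conversions dropped in the files below, is redundant because jsonargparse derives argument types from function annotations and converts command-line strings accordingly. The following is a minimal, hypothetical sketch of that behavior (the `setup` function and path are placeholders, not litgpt's actual CLI wiring):

from pathlib import Path

from jsonargparse import CLI


def setup(checkpoint_dir: Path, precision: str = "bf16-true") -> None:
    # jsonargparse has already converted the CLI string into a pathlib.Path,
    # so no manual Path(checkpoint_dir) normalization is needed here.
    print(type(checkpoint_dir), checkpoint_dir / "model_config.yaml")


if __name__ == "__main__":
    CLI(setup)  # e.g. python demo.py --checkpoint_dir checkpoints/llama
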
3 changes: 1 addition & 2 deletions litgpt/finetune/adapter.py
@@ -39,7 +39,7 @@


 def setup(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     out_dir: Path = Path("out/finetune/adapter"),
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
@@ -80,7 +80,6 @@ def setup(
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)

-    checkpoint_dir = Path(checkpoint_dir)
     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

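
Since the annotation is now `Path` and the in-function normalization is gone, programmatic callers of these `setup` functions are expected to pass a `pathlib.Path` themselves. A hedged usage sketch (the checkpoint directory below is a placeholder; the same applies to the other finetune and generate entry points changed in this commit):

from pathlib import Path

from litgpt.finetune.adapter import setup

# Pass a Path explicitly; a plain string is no longer converted inside setup().
setup(
    checkpoint_dir=Path("checkpoints/your-model"),  # placeholder directory
    out_dir=Path("out/finetune/adapter"),
)
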
3 changes: 1 addition & 2 deletions litgpt/finetune/adapter_v2.py
@@ -39,7 +39,7 @@


 def setup(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     out_dir: Path = Path("out/finetune/adapter-v2"),
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
@@ -81,7 +81,6 @@ def setup(
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)

-    checkpoint_dir = Path(checkpoint_dir)
     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

3 changes: 1 addition & 2 deletions litgpt/finetune/full.py
@@ -36,7 +36,7 @@


 def setup(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     out_dir: Path = Path("out/finetune/full"),
     precision: Optional[str] = None,
     devices: Union[int, str] = 1,
@@ -78,7 +78,6 @@ def setup(
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)

-    checkpoint_dir = Path(checkpoint_dir)
     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

3 changes: 1 addition & 2 deletions litgpt/finetune/lora.py
@@ -40,7 +40,7 @@


 def setup(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     out_dir: Path = Path("out/finetune/lora"),
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
@@ -99,7 +99,6 @@ def setup(
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)

-    checkpoint_dir = Path(checkpoint_dir)
     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(
         checkpoint_dir / "model_config.yaml",
3 changes: 1 addition & 2 deletions litgpt/generate/adapter.py
@@ -17,7 +17,7 @@


 def main(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     prompt: str = "What food do llamas eat?",
     input: str = "",
     adapter_path: Path = Path("out/finetune/adapter/final/lit_model.pth.adapter"),
@@ -63,7 +63,6 @@ def main(
             samples.
         precision: Indicates the Fabric precision setting to use.
     """
-    checkpoint_dir = Path(checkpoint_dir)
     precision = precision or get_default_supported_precision(training=False)

     plugins = None
3 changes: 1 addition & 2 deletions litgpt/generate/adapter_v2.py
@@ -17,7 +17,7 @@


 def main(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     prompt: str = "What food do llamas eat?",
     input: str = "",
     adapter_path: Path = Path("out/finetune/adapter-v2/final/lit_model.pth.adapter_v2"),
@@ -63,7 +63,6 @@ def main(
             samples.
         precision: Indicates the Fabric precision setting to use.
     """
-    checkpoint_dir = Path(checkpoint_dir)
     precision = precision or get_default_supported_precision(training=False)

     plugins = None
3 changes: 1 addition & 2 deletions litgpt/generate/base.py
@@ -133,7 +133,7 @@ def generate(

 @torch.inference_mode()
 def main(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     prompt: str = "What food do llamas eat?",
     *,
     num_samples: int = 1,
@@ -178,7 +178,6 @@ def main(
         precision: Indicates the Fabric precision setting to use.
         compile: Whether to compile the model.
     """
-    checkpoint_dir = Path(checkpoint_dir)
     precision = precision or get_default_supported_precision(training=False)

     plugins = None
3 changes: 1 addition & 2 deletions litgpt/generate/full.py
@@ -16,7 +16,7 @@


 def main(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     prompt: str = "What food do llamas eat?",
     input: str = "",
     finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"),
@@ -62,7 +62,6 @@ def main(
             samples.
         precision: Indicates the Fabric precision setting to use.
     """
-    checkpoint_dir = Path(checkpoint_dir)
     precision = precision or get_default_supported_precision(training=False)

     plugins = None
3 changes: 1 addition & 2 deletions litgpt/generate/sequentially.py
@@ -112,7 +112,7 @@ def replace_device(module: torch.nn.Module, replace: torch.device, by: torch.dev

 @torch.inference_mode()
 def main(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     prompt: str = "What food do llamas eat?",
     *,
     num_samples: int = 1,
@@ -156,7 +156,6 @@ def main(
         precision: Indicates the Fabric precision setting to use.
         compile: Whether to compile the model.
     """
-    checkpoint_dir = Path(checkpoint_dir)
     precision = precision or get_default_supported_precision(training=False)

     plugins = None
2 changes: 1 addition & 1 deletion litgpt/generate/tp.py
@@ -90,7 +90,7 @@ def tensor_parallel(fabric: L.Fabric, model: GPT) -> GPT:

 @torch.inference_mode()
 def main(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     prompt: str = "What food do llamas eat?",
     *,
     num_samples: int = 1,
3 changes: 1 addition & 2 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -287,7 +287,7 @@ def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype:

 @torch.inference_mode()
 def convert_hf_checkpoint(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     *,
     model_name: Optional[str] = None,
     dtype: Optional[str] = None,
@@ -302,7 +302,6 @@ def convert_hf_checkpoint(
         dtype: The data type to convert the checkpoint files to. If not specified, the weights will remain in the
             dtype they are downloaded in.
     """
-    checkpoint_dir = Path(checkpoint_dir)
     if model_name is None:
         model_name = checkpoint_dir.name
     if dtype is not None:
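
Note that the body of `convert_hf_checkpoint` uses the Path-specific attribute `checkpoint_dir.name` (and other files in this commit join paths with the `/` operator), which is why the argument must already arrive as a `pathlib.Path` once the manual conversion is removed. A small illustration with placeholder paths:

from pathlib import Path

checkpoint_dir = Path("checkpoints/EleutherAI/pythia-160m")  # placeholder
model_name = checkpoint_dir.name                    # "pythia-160m"; a plain str has no .name
config_file = checkpoint_dir / "model_config.yaml"  # "/" joining only works with Path objects
print(model_name, config_file)
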
3 changes: 1 addition & 2 deletions litgpt/scripts/merge_lora.py
@@ -13,7 +13,7 @@


 def merge_lora(
-    checkpoint_dir: str,
+    checkpoint_dir: Path,
     pretrained_checkpoint_dir: Optional[Path] = None,
     precision: Optional[str] = None
 ) -> None:
@@ -34,7 +34,6 @@ def merge_lora(
         precision: Optional precision setting to instantiate the model weights in. By default, this will
             automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
     """
-    checkpoint_dir = Path(checkpoint_dir)
     check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth.lora")
     if pretrained_checkpoint_dir is not None:
         check_valid_checkpoint_dir(pretrained_checkpoint_dir)
