Chat-LLaMA-8bit-LoRA/merge_adapter_weights.py
2023-03-22 04:10:30 -06:00

54 lines
2.1 KiB
Python

# Taken from https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/merge_peft_adapter.py for ease of use (copy and pasted)
from dataclasses import dataclass, field
from typing import Optional
import peft
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, LlamaForCausalLM, LlamaTokenizer
@dataclass
class ScriptArguments:
"""
The name of the Casual LM model we wish to fine with PPO
"""
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
# models like gpt-neo* models are more suitable
model_name: Optional[str] = field(default="serpdotai/llama-hh-lora-7B", metadata={"help": "the lora model name"})
output_dir: Optional[str] = field(default="llama-7B-hh-adapter-merged", metadata={"help": "the output directory"})
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "push the model to the huggingface hub"})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
peft_model_id = script_args.model_name
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = LlamaForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
return_dict=True,
torch_dtype=torch.float16,
)
tokenizer =LlamaTokenizer.from_pretrained(peft_config.base_model_name_or_path)
# Load the Lora model
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()
key_list = [key for key, _ in model.base_model.model.named_modules() if "lora" not in key]
for key in key_list:
parent, target, target_name = model.base_model._get_submodules(key)
if isinstance(target, peft.tuners.lora.Linear):
bias = target.bias is not None
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
model.base_model._replace_module(parent, target_name, new_module, target)
model = model.base_model.model
if script_args.push_to_hub:
model.push_to_hub(f"{script_args.model_name}-adapter-merged", use_temp_dir=False)
model.save_pretrained(script_args.output_dir)