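Playing with Stable Diffusion on Amazon SageMaker: Dreambooth-Based Model Fine-Tuning
This article uses Stable Diffusion Quick Kit as an example to explain in detail how to fine-tune a Stable Diffusion model with Dreambooth. It covers the basics of Stable Diffusion fine-tuning, introduces Dreambooth fine-tuning, and uses Quick Kit to show the fine-tuning results in a demo.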
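import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

# Load the fine-tuned model in half precision, with the scheduler saved alongside it
model_dir = '/opt/ml/input/fineturned_model/'
model = StableDiffusionPipeline.from_pretrained(
    model_dir,
    scheduler=DPMSolverMultistepScheduler.from_pretrained(model_dir, subfolder="scheduler"),
    torch_dtype=torch.float16,
)
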
images_s3uri = 's3://{0}/dreambooth/images/'.format(bucket)
inputs = {
    'images': images_s3uri
}
estimator = Estimator(
    role=role,
    instance_count=1,
    instance_type=instance_type,
    image_uri=image_uri,
    hyperparameters=hyperparameters,
    environment=environment
)
estimator.fit(inputs)

FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime
ENV PATH="/opt/ml/code:${PATH}"
ENV DEBIAN_FRONTEND noninteractive
RUN apt-get update
RUN apt-get install --assume-yes apt-utils -y
RUN apt update
RUN echo "Y"|apt install vim
RUN apt install wget git -y
RUN apt install libgl1-mesa-glx -y
RUN pip install opencv-python-headless
RUN mkdir -p /opt/ml/code
RUN pip3 install sagemaker-training
COPY train.py /opt/ml/code/
COPY ./sd_code/ /opt/ml/code/
RUN pip install -r /opt/ml/code/extensions/sd_dreambooth_extension/requirements.txt
ENV SAGEMAKER_PROGRAM train.py
RUN export TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6" && export FORCE_CUDA="1" && \
    pip install ninja triton==2.0.0.dev20221120 && \
    git clone https://github.com/xieyongliang/xformers.git /opt/ml/code/repositories/xformers && \
    cd /opt/ml/code/repositories/xformers && \
    git submodule update --init --recursive && \
    pip install -r requirements.txt && \
    pip install -e .
ENTRYPOINT []

algorithm_name=dreambooth-finetuning-v3
account=$(aws sts get-caller-identity --query Account --output text)
# Get the region defined in the current configuration (default to us-west-2 if none defined)
region=$(aws configure get region)
region=${region:-us-west-2}
fullname="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:latest"
# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${algorithm_name}" > /dev/null 2>&1
if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" > /dev/null
fi
# Log into Docker (the password is piped via stdin so it does not appear in the process list)
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com
# Build the docker image locally with the image name and then push it to ECR
# with the full name.
mkdir -p ./sd_code/extensions
cd ./sd_code/extensions/ && git clone https://github.com/qingyuan18/sd_dreambooth_extension.git
cd ../../
docker build -t ${algorithm_name} ./ -f ./dockerfile_v3 > ./docker_build.log
docker tag ${algorithm_name} ${fullname}
docker push ${fullname}
rm -rf ./sd_code

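The image_uri passed to the SageMaker Estimator has to match the full name that the build script pushes to ECR. A minimal sketch of rebuilding that name in the notebook follows; the helper name ecr_image_uri and the account ID in the usage line are placeholders for illustration, not part of Quick Kit.

```python
def ecr_image_uri(account, region, algorithm_name, tag="latest"):
    """Rebuild the ECR image name using the same pattern as `fullname` in the build script."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/{algorithm_name}:{tag}"


# Placeholder account ID; in a notebook it would come from the caller identity.
image_uri = ecr_image_uri("123456789012", "us-west-2", "dreambooth-finetuning-v3")
print(image_uri)
```

Deriving the URI from the same three values as the script avoids the two drifting apart when the algorithm name or region changes.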
if shared.force_cpu:
    import modules.shared
    no_safe = modules.shared.cmd_opts.disable_safe_unpickle
    modules.shared.cmd_opts.disable_safe_unpickle = True

from helpers.mytqdm import mytqdm
hyperparameters = {
    'model_name': 'aws-trained-dreambooth-model',
    'mixed_precision': 'fp16',
    'pretrained_model_name_or_path': model_name,
    'instance_data_dir': instance_dir,
    'class_data_dir': class_dir,
    'with_prior_preservation': True,
    'models_path': '/opt/ml/model/',
    'manul_upload_model_path': s3_model_output_location,
    'instance_prompt': instance_prompt,
    ……
}
estimator = Estimator(
    role=role,
    instance_count=1,
    instance_type=instance_type,
    image_uri=image_uri,
    hyperparameters=hyperparameters
)

python convert_original_stable_diffusion_to_diffusers.py --checkpoint_path ./models_ckpt/768-v-ema.ckpt --dump_path ./models_diffuser

import os

# s3_client (a boto3 S3 client) and get_bucket_and_key are defined elsewhere in the training script.
def upload_directory_to_s3(local_directory, dest_s3_path):
    bucket, s3_prefix = get_bucket_and_key(dest_s3_path)
    # os.walk already descends into every subdirectory, so no extra recursion is needed;
    # recursing again per subdirectory would upload the same files twice.
    for root, dirs, files in os.walk(local_directory):
        for filename in files:
            local_path = os.path.join(root, filename)
            relative_path = os.path.relpath(local_path, local_directory)
            s3_path = os.path.join(s3_prefix, relative_path).replace("\\", "/")
            s3_client.upload_file(local_path, bucket, s3_path)
            print(f'File {local_path} uploaded to s3://{bucket}/{s3_path}')

s_pipeline.save_pretrained(args.models_path)
# Manually upload the trained Dreambooth model directory to the S3 path
# to skip SageMaker's tar packaging step.
print(f"manul_upload_model_path is {args.manul_upload_model_path}")
upload_directory_to_s3(args.models_path, args.manul_upload_model_path)

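The upload function relies on a get_bucket_and_key helper that the snippet does not show. A minimal sketch of what such a helper could look like follows, assuming it splits an s3://bucket/prefix URI into its bucket and key prefix; the actual implementation in Quick Kit may differ.

```python
from urllib.parse import urlparse


def get_bucket_and_key(s3_uri):
    """Split 's3://bucket/some/prefix' into ('bucket', 'some/prefix')."""
    parsed = urlparse(s3_uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not a valid S3 URI: {s3_uri!r}")
    # Drop the leading slash so the key can be joined with relative file paths.
    return parsed.netloc, parsed.path.lstrip("/")


bucket, prefix = get_bucket_and_key("s3://my-bucket/dreambooth/model/")
print(bucket, prefix)
```

Stripping the leading slash matters because S3 keys are not rooted paths; a key that starts with "/" would create an empty top-level folder name in the bucket.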
print(f"Total VRAM: {gb}")
if 24 > gb >= 16:
    attention = "xformers"
    not_cache_latents = False
    train_text_encoder = True
    use_ema = True
if 16 > gb >= 10:
    train_text_encoder = False
    use_ema = False
if gb < 10:
    use_cpu = True
    use_8bit_adam = False
    mixed_precision = 'no'

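The VRAM thresholds in the snippet above can be wrapped in a pure function, which makes the tiers easy to check without a GPU. This is a minimal sketch: the function name settings_for_vram and the returned dict are illustrative and not part of the extension, and only the overrides visible in the snippet are reproduced.

```python
def settings_for_vram(gb):
    """Return the setting overrides the snippet applies for `gb` GB of VRAM."""
    overrides = {}
    if 24 > gb >= 16:
        overrides.update(
            attention="xformers",
            not_cache_latents=False,
            train_text_encoder=True,
            use_ema=True,
        )
    if 16 > gb >= 10:
        overrides.update(train_text_encoder=False, use_ema=False)
    if gb < 10:
        overrides.update(use_cpu=True, use_8bit_adam=False, mixed_precision="no")
    return overrides


for vram in (8, 12, 16, 24):
    print(vram, settings_for_vram(vram))
```

Note that 24 GB and above falls outside all three ranges, so no overrides are applied and the configured values are used unchanged.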
***** Running training *****
Instantaneous batch size per device = 1
Total train batch size (w. parallel, distributed & accumulation) = 1
Gradient Accumulation steps = 1
Total optimization steps = 1000
Training settings: CPU: False Adam: True, Prec: fp16, Grad: True, TextTr: False EM: True, LR: 2e-06 LORA:False
Allocated: 10.5GB
Reserved: 11.7GB

***** Running training *****
Instantaneous batch size per device = 1
Total train batch size (w. parallel, distributed & accumulation) = 1
Gradient Accumulation steps = 1
Total optimization steps = 1000
Training settings: CPU: False Adam: True, Prec: fp16, Grad: True, TextTr: False EM: True, LR: 2e-06 LORA:False
Allocated: 5.5GB
Reserved: 5.6GB

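- 'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:32'. For CUDA OOM errors caused by GPU memory fragmentation, set max_split_size_mb in PYTORCH_CUDA_ALLOC_CONF to a small value.
- 'train_batch_size': 1. The number of images processed per step. When there are only a few instance images or class images (fewer than 10), set this value to 1 to reduce the number of images per batch, which lowers GPU memory usage to some extent.
- 'sample_batch_size': 1. The counterpart of train_batch_size: the batch throughput of one pass of sampling, noising, and denoising. Lowering this value also lowers GPU memory usage.
- not_cache_latents. Stable Diffusion training is based on Latent Diffusion Models, and the original model caches latents. Since we are mainly training the regularization under the instance prompt and class prompt, when GPU memory is tight we can choose not to cache latents to keep memory usage as low as possible.
- 'gradient_accumulation_steps'. The number of batches over which gradients are accumulated before each update. If the number of training steps is large, for example 1000, you can increase this value so that gradients accumulate over several batches and are applied in a single update. The larger the value, the higher the GPU memory usage; to reduce memory, lower the value at the cost of some training time. Note that gradient accumulation is not supported when retraining the text encoder (train_text_encoder) with accelerate multi-GPU distributed training enabled on a multi-GPU machine: in that case gradient_accumulation_steps must be set to 1, otherwise text encoder retraining is disabled.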
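The interaction between train_batch_size, gradient_accumulation_steps, and multi-GPU training can be sketched in a few lines. The function names below are hypothetical; the batch-size product is assumed from the "Total train batch size (w. parallel, distributed & accumulation)" line in the training log, and the text encoder rule mirrors the note on gradient_accumulation_steps rather than the extension's actual code.

```python
def total_train_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Effective batch size per optimizer update across all processes."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def text_encoder_training_enabled(train_text_encoder, gradient_accumulation_steps, num_processes):
    """Text encoder training is dropped when accumulation > 1 is combined with multi-GPU training."""
    if train_text_encoder and gradient_accumulation_steps > 1 and num_processes > 1:
        return False
    return train_text_encoder


# Single GPU, no accumulation: matches the value of 1 reported in the training log.
print(total_train_batch_size(1, 1, 1))
# Two GPUs with accumulation of 4: the text encoder would no longer be trained.
print(text_encoder_training_enabled(True, 4, 2))
```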
# "zwx" is the trigger word; once the model is trained, we use this word to generate images
instance_prompt="photo\ of\ zwx\ toy"
class_prompt="photo\ of\ a\ cat toy"

# Notebook training code
# Set the hyperparameters
environment = {
    'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:32',
    'LD_LIBRARY_PATH': "${LD_LIBRARY_PATH}:/opt/conda/lib/"
}
hyperparameters = {
    'model_name': 'aws-trained-dreambooth-model',
    'mixed_precision': 'fp16',
    'pretrained_model_name_or_path': model_name,
    'instance_data_dir': instance_dir,
    'class_data_dir': class_dir,
    'with_prior_preservation': True,
    'models_path': '/opt/ml/model/',
    'instance_prompt': instance_prompt,
    'class_prompt': class_prompt,
    'resolution': 512,
    'train_batch_size': 1,
    'sample_batch_size': 1,
    'gradient_accumulation_steps': 1,
    'learning_rate': 2e-06,
    'lr_scheduler': 'constant',
    'lr_warmup_steps': 0,
    'num_class_images': 50,
    'max_train_steps': 300,
    'save_steps': 100,
    'attention': 'xformers',
    'prior_loss_weight': 0.5,
    'use_ema': True,
    'train_text_encoder': False,
    'not_cache_latents': True,
    'gradient_checkpointing': True,
    'save_use_epochs': False,
    'use_8bit_adam': False
}
hyperparameters = json_encode_hyperparameters(hyperparameters)

# Launch the SageMaker training job
from sagemaker.estimator import Estimator
inputs = {
    'images': f"s3://{bucket}/dreambooth/images/"
}
estimator = Estimator(
    role=role,
    instance_count=1,
    instance_type=instance_type,
    image_uri=image_uri,
    hyperparameters=hyperparameters,
    environment=environment
)
estimator.fit(inputs)

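The notebook code calls json_encode_hyperparameters without showing its definition. A minimal sketch follows, assuming the helper JSON-encodes each value so that booleans and numbers survive being passed to the training script as strings; the decode function is a hypothetical counterpart for the training side, and the real Quick Kit helper may differ.

```python
import json


def json_encode_hyperparameters(hyperparameters):
    """JSON-encode every value so types can be recovered after string transport."""
    return {str(key): json.dumps(value) for key, value in hyperparameters.items()}


def json_decode_hyperparameters(encoded):
    """Hypothetical inverse for the training script: restore the original Python types."""
    return {key: json.loads(value) for key, value in encoded.items()}


encoded = json_encode_hyperparameters({'use_ema': True, 'resolution': 512, 'mixed_precision': 'fp16'})
print(encoded)
```

Without an encoding step like this, a value such as False would typically arrive as the non-empty string "False", which is truthy in Python and easy to misread on the training side.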
Stable Diffusion Quick Kit Dreambooth fine-tuning documentation:
https://catalog.us-east-1.prod.workshops.aws/workshops/1ac668b1-dbd3-4b45-bf0a-5bc36138fcf1/zh-CN/4-configuration-stablediffusion/4-4-find-tuning-notebook
Dreambooth paper:
https://dreambooth.github.io/
Original open-source Dreambooth training notebook:
https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb#scrollTo=rscg285SBh4M
Hugging Face diffusers format conversion tools:
https://github.com/huggingface/diffusers/tree/main/scripts
Stable Diffusion WebUI Dreambooth extension:
https://github.com/d8ahazard/sd_dreambooth_extension.git
Facebook xformers open-source repository:
https://github.com/facebookresearch/xformers