SFTTrainer Source Code Walkthrough: Prepare Train
- Prepare Train: Overall Logic
- Prepare Train: Code Details
  - _inner_training_loop
  - training_step
  - compute_loss
  - PeftModelForCausalLM.forward
  - Linear4bit.forward
1. Prepare Train: Overall Logic
Overall flow
- Initialize SFTTrainer
- Call trainer.train()
- Call _inner_training_loop()
  - get_train_dataloader
  - Set up the training control variables:
    - number of training epochs: num_train_epochs
    - number of update steps per epoch: num_update_steps_per_epoch
    - total number of update steps: max_steps
      - if max_steps is not set explicitly, max_steps == num_train_epochs * num_update_steps_per_epoch
    - global batch size per update step: total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size (see the worked example after this outline)
  - Start "Running training"
    - if args.eval_on_start is True, run one evaluation pass before training starts
    - Epoch loop
      - load data via train_dataloader
      - Update-step loop (global_batch_size: the global batch, i.e. the amount of data consumed per parameter update)
        - fetch the batch data batch_samples
          - batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
        - Batch loop (micro_batch_size: the micro-batch)
          - forward pass: outputs = model(**inputs)  # **inputs is the batch data
            - PeftModelForCausalLM.forward
              - inputs_embeds = self.word_embeddings(input_ids)
              - return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
              - Linear4bit.forward (called inside base_model for each LoRA-adapted layer)
                - result = self.base_layer(x, *args, **kwargs)
                - output = lora_B(lora_A(dropout(x))) * scaling
                - result = result + output
                - return result
              - control returns to PeftModelForCausalLM.forward
          - compute the loss: loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
          - backward pass: self.accelerator.backward(loss, **kwargs)
        - once gradients are accumulated, update the parameters: self.optimizer.step()
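To make the control variables concrete, here is a small numeric sketch; all settings below are made up for illustration:

```python
# Made-up settings, for illustration only
per_device_train_batch_size = 2          # micro-batch size per device
gradient_accumulation_steps = 8
world_size = 4                           # number of processes/GPUs

# Samples consumed per optimizer update (the global batch)
total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
print(total_train_batch_size)            # 64

# With, say, 1024 micro-batches per epoch from the dataloader:
num_batches_per_epoch = 1024
num_update_steps_per_epoch = num_batches_per_epoch // gradient_accumulation_steps
num_train_epochs = 3
max_steps = num_train_epochs * num_update_steps_per_epoch  # when max_steps is not set explicitly
print(num_update_steps_per_epoch, max_steps)               # 128 384
```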
Core training logic
for epoch in range(epochs_trained, num_train_epochs):  # Epoch loop
    for i, inputs in enumerate(batch_samples):  # Batch loop
        # Run a single training step (forward pass, loss computation, backward pass)
        # and return the loss of the current batch as tr_loss_step
        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
        tr_loss = tr_loss + tr_loss_step  # accumulate the loss
        # do_sync_step is True on the last step of a gradient-accumulation cycle
        # or on the last step of the epoch
        if do_sync_step:
            """
            Gradient accumulation: gradients from several batches are accumulated
            before a single parameter update. The optimizer only steps when
            do_sync_step is True, balancing compute efficiency and memory use.
            """
            self.optimizer.step()  # update the parameters with the accumulated gradients
            if not self.accelerator.optimizer_step_was_skipped:  # skipped e.g. on gradient overflow or NaN
                self.lr_scheduler.step()  # advance the learning rate only after a successful optimizer step
            model.zero_grad()  # zero the gradients for the next accumulation cycle
            self.state.global_step += 1  # increment the global step counter
            # Log, save checkpoints, and run evaluation as configured
            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate)
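The `do_sync_step` flag is what gates gradient accumulation. A standalone sketch of when it fires, with made-up numbers:

```python
# Made-up numbers: 10 micro-batches per epoch, accumulate over 4
gradient_accumulation_steps = 4
steps_in_epoch = 10

for step in range(steps_in_epoch):
    do_sync_step = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
    print(step, do_sync_step)
# True at steps 3 and 7 (accumulation boundaries) and at step 9 (end of epoch),
# so the optimizer updates three times in this epoch.
```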
2. Prepare Train: Code Details
2.1. SFTTrainer.__init__
class SFTTrainer(Trainer):
"""
Trainer for Supervised Fine-Tuning (SFT) method.
"""
    def __init__(self, model, args=None, data_collator=None, train_dataset=None,
                 eval_dataset=None, processing_class=None, compute_loss_func=None,
                 compute_metrics=None, callbacks=None, optimizers=(None, None),
                 preprocess_logits_for_metrics=None, **super_init_kwargs):
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
compute_loss_func=compute_loss_func,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
**super_init_kwargs,
)
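For orientation, a minimal usage sketch of the entry point; the model name, dataset, and hyperparameters below are illustrative placeholders, not taken from the source:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Illustrative dataset and model; substitute your own
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",            # model id or a preloaded model object
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="./sft-out",
        per_device_train_batch_size=2,    # micro-batch size
        gradient_accumulation_steps=8,    # micro-batches per optimizer step
        num_train_epochs=1,
    ),
)
trainer.train()  # enters Trainer.train() -> _inner_training_loop()
```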
2.2. Trainer.train()
class Trainer:
    def train(self, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None, **kwargs):
"""
Main training entry point.
"""
        # inner_training_loop wraps self._inner_training_loop (auto-batch-size handling omitted here)
        return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
2.3. _inner_training_loop
def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
# Setting up training control variables:
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
(
num_train_epochs,
num_update_steps_per_epoch,
num_examples,
num_train_samples,
epoch_based,
len_dataloader,
max_steps,
) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs:,}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
if self.args.per_device_train_batch_size != self._train_batch_size:
logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0, device=args.device)
    # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
grad_norm: Optional[float] = None
learning_rate = None
if args.eval_on_start:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
epochs_trained = 0
for epoch in range(epochs_trained, num_train_epochs):
epoch_dataloader = train_dataloader
steps_in_epoch = (
len(epoch_dataloader)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
step = -1
epoch_iterator = iter(epoch_dataloader)
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
remainder = num_examples % args.gradient_accumulation_steps
if remainder == 0:
remainder = args.gradient_accumulation_steps
update_step = -1
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
if args.gradient_accumulation_steps == 1:
total_updates -= 1
for _ in range(total_updates):
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
for i, inputs in enumerate(batch_samples):
step += 1
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
if (
args.logging_nan_inf_filter
and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
)
tr_loss = tr_loss + tr_loss_step
if do_sync_step:
self.optimizer.step()
                    # get the learning rate before it is updated by the scheduler
learning_rate = self._get_learning_rate()
model.zero_grad()
self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch  # steps_skipped comes from the checkpoint-resume logic (elided here)
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
self._load_best_model()
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
train_loss = self._total_loss_scalar / effective_global_step
return TrainOutput(self.state.global_step, train_loss, metrics)
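The `remainder`/`total_updates` arithmetic above decides how many micro-batches the final update step of an epoch receives. A simplified standalone sketch of the same chunking (numbers are made up; the source derives the remainder from `num_examples`):

```python
gradient_accumulation_steps = 4
steps_in_epoch = 10                      # micro-batches in this epoch

remainder = steps_in_epoch % gradient_accumulation_steps
if remainder == 0:
    remainder = gradient_accumulation_steps

total_updates = steps_in_epoch // gradient_accumulation_steps + 1
if gradient_accumulation_steps == 1:
    total_updates -= 1

for update_step in range(total_updates):
    num_batches = gradient_accumulation_steps if update_step != total_updates - 1 else remainder
    print(update_step, num_batches)      # 0 -> 4, 1 -> 4, 2 -> 2: the last update gets the leftovers
```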
2.4. training_step
def training_step(self, model, inputs, num_items_in_batch=None):
"""
Perform a training step on a batch of inputs.
"""
    model.train()  # switch the model to training mode
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
self.optimizer.train()
inputs = self._prepare_inputs(inputs)
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()  # Apex automatic mixed precision: accelerates training and saves GPU memory
else:
# Finally we need to normalize the loss for reporting
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps
        self.accelerator.backward(loss, **kwargs)  # generic backward call, compatible with distributed backends (DeepSpeed, FSDP, etc.); kwargs is assembled earlier in the full method
return loss.detach()
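The division by `gradient_accumulation_steps` keeps the backpropagated loss on the same scale as one large batch: summing the scaled micro-batch losses equals the mean over the global batch, assuming equal-sized micro-batches. A tiny numeric check with made-up values:

```python
import torch

micro_losses = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(6.0), torch.tensor(8.0)]
grad_accum = len(micro_losses)

# Scale each micro-batch loss before backward(); the accumulated gradient then
# matches the gradient of the mean loss over the whole global batch.
scaled_sum = sum(loss / grad_accum for loss in micro_losses)
print(scaled_sum)        # tensor(5.) == mean of the four micro-batch losses
```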
2.5. compute_loss
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
outputs = model(**inputs)
if labels is not None:
unwrapped_model = self.accelerator.unwrap_model(model)
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
# User-defined compute_loss function
if self.compute_loss_func is not None:
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return loss
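When no label smoother and no custom loss function are configured, the loss is taken straight from the model output, as in the last branch above. A minimal sketch with a Hugging Face causal LM (the model name is just an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # small illustrative model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("hello world", return_tensors="pt")
# For causal LM training the labels are typically the input_ids themselves;
# the model shifts them internally before computing cross-entropy.
outputs = model(**inputs, labels=inputs["input_ids"])

# ModelOutput behaves like a dict, so the same extraction as in compute_loss works:
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
print(loss)              # scalar cross-entropy loss
```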
2.6. PeftModelForCausalLM.forward
class PeftModelForCausalLM(PeftModel):
"""
Peft model for causal language modeling.
"""
def forward(
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
task_ids=None,
**kwargs,
):
        peft_config = self.active_peft_config  # excerpt shows the prompt-learning branch
        batch_size = _get_batch_size(input_ids, inputs_embeds)
if attention_mask is not None:
# concat prompt attention mask
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# concat prompt labels
if labels is not None:
prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
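A small tensor sketch of what this prompt-learning branch does to the attention mask and labels; the shapes are illustrative:

```python
import torch

batch_size, seq_len, num_virtual_tokens = 2, 5, 4

attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
labels = torch.randint(0, 100, (batch_size, seq_len))

# Prepend an all-ones mask so the virtual prompt tokens are attended to...
prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens, dtype=torch.long)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

# ...and prepend -100 labels so the prompt positions are ignored by the loss
prefix_labels = torch.full((batch_size, num_virtual_tokens), -100)
labels = torch.cat((prefix_labels, labels), dim=1)

print(attention_mask.shape, labels.shape)  # torch.Size([2, 9]) torch.Size([2, 9])
```

Note that a plain LoRA adapter does not take this prompt-concatenation path: LoRA is not prompt learning, so forward goes straight to self.base_model, and the adapter shows up inside it as Linear4bit.forward, covered next.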
2.7. Linear4bit.forward
class Linear4bit(torch.nn.Module, LoraLayer):
# Lora implemented in a dense layer
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
# The reason is that in some cases, an error can occur that backprop
# does not work on a manipulated view. This issue may be solved with
# newer PyTorch versions but this would need extensive testing to be
# sure.
result = result.clone()
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
else:
if isinstance(dropout, torch.nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None
output = self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)
if requires_conversion:
output = output.to(expected_dtype)
result = result + output
return result
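Stripped of the 4-bit quantization and DoRA details, the adapter arithmetic is just a low-rank update added to the frozen base layer's output. A minimal sketch with plain nn.Linear layers (dimensions, dropout rate, and scaling are illustrative):

```python
import torch
import torch.nn as nn

in_features, out_features, r = 16, 32, 4
scaling = 0.5                                      # typically lora_alpha / r

base_layer = nn.Linear(in_features, out_features)  # stands in for the frozen 4-bit layer
lora_A = nn.Linear(in_features, r, bias=False)     # down-projection
lora_B = nn.Linear(r, out_features, bias=False)    # up-projection
nn.init.zeros_(lora_B.weight)                      # so the adapter starts as a no-op
dropout = nn.Dropout(p=0.1)

x = torch.randn(2, in_features)
result = base_layer(x)                             # frozen path
result = result + lora_B(lora_A(dropout(x))) * scaling  # low-rank update
print(result.shape)                                # torch.Size([2, 32])
```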