
SFTTrainer Source Code Walkthrough: Prepare Train

Liz · about 5 min read · LLM · SFTTrainer · Source Code · Prepare Train


  • Prepare Train: Overall Logic
  • Prepare Train: Code Details
    • _inner_training_loop
    • training_step
    • compute_loss
    • PeftModelForCausalLM.forward
    • Linear4bit.forward

1. Prepare Train: Overall Logic

Overall logic

  • Initialize SFTTrainer
  • Call trainer.train()
    • Run inner_training_loop()
      • get_train_dataloader
      • Setting up training control variables
        • Number of training epochs: num_train_epochs
        • Number of update steps per epoch: num_update_steps_per_epoch
        • Total number of update steps: max_steps
          • If max_steps is not set, max_steps == num_train_epochs * num_update_steps_per_epoch
        • Global batch size per update step: total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
      • Start "Running training"
      • If args.eval_on_start is True, run one evaluation before training starts
      • Epoch loop
        • Load data from train_dataloader
        • Update-step loop (global_batch_size: the global batch, i.e. the number of samples consumed per parameter update)
          • Fetch the batch data batch_samples
            • batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
          • Batch loop (micro_batch_size: the micro-batch)
            • Forward pass: outputs = model(**inputs) # **inputs is the batch data
              • PeftModelForCausalLM.forward
                • inputs_embeds = self.word_embeddings(input_ids)
                • return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
              • Linear4bit.forward
                • result = self.base_layer(x, *args, **kwargs)
                • output = lora_B(lora_A(dropout(x))) * scaling
                • result = result + output
                • return result
            • Compute the loss: loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            • Backward pass: self.accelerator.backward(loss, **kwargs)
          • After gradient accumulation, update the parameters: self.optimizer.step()
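The control variables in the outline above can be worked through with plain arithmetic. The following is a minimal sketch, not the library code: the helper name `training_schedule` is hypothetical, and it assumes the number of update steps per epoch is the floor of the dataloader length divided by the accumulation steps (with a minimum of one), which may differ between Transformers versions.

```python
import math

def training_schedule(num_examples, per_device_batch_size, gradient_accumulation_steps,
                      world_size, num_train_epochs, max_steps=-1):
    """Toy re-derivation of the training control variables (hypothetical helper)."""
    # Samples consumed by one optimizer update across all devices.
    total_train_batch_size = per_device_batch_size * gradient_accumulation_steps * world_size
    # Micro-batches each device sees per epoch (the last one may be partial).
    len_dataloader = math.ceil(num_examples / (per_device_batch_size * world_size))
    # Assumption: floor division, with at least one update per epoch.
    num_update_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)
    # If max_steps is not set, derive it from the number of epochs.
    if max_steps <= 0:
        max_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)
    return total_train_batch_size, num_update_steps_per_epoch, max_steps

schedule = training_schedule(
    num_examples=1000,
    per_device_batch_size=4,
    gradient_accumulation_steps=8,
    world_size=2,
    num_train_epochs=3,
)
print(schedule)  # (64, 15, 45)
```

With 1000 examples, 2 devices and a per-device batch of 4, each device sees 125 micro-batches per epoch, which gives 15 updates per epoch at 8 accumulation steps.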

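Core training logic

for epoch in range(epochs_trained, num_train_epochs):  # epoch loop
    for i, inputs in enumerate(batch_samples):  # micro-batch loop
        # Run one training step (forward pass, loss computation, backward pass)
        # and return the loss of the current micro-batch.
        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
        tr_loss = tr_loss + tr_loss_step  # accumulate the loss
        # do_sync_step is True on the last micro-batch of a gradient accumulation
        # window or on the last step of the epoch.
        if do_sync_step:
            # Gradient accumulation: gradients from several micro-batches are summed
            # before the parameters are updated. The optimizer only steps when
            # do_sync_step is True, which balances compute efficiency and memory use.
            self.optimizer.step()  # update the parameters with the accumulated gradients
            if not self.accelerator.optimizer_step_was_skipped:  # skipped, for example, after gradient overflow or NaN
                self.lr_scheduler.step()  # advance the learning rate only after a successful update, so it tracks training progress
            model.zero_grad()  # clear the gradients for the next accumulation window
            self.state.global_step += 1  # increment the global update count
            # Log, save a checkpoint, and evaluate as configured.
            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate)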
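To see how often the optimizer actually steps in a loop of this shape, here is a small, model-free sketch. The function `run_toy_loop` and its counters are hypothetical stand-ins: incrementing `accumulated` replaces `training_step`, and appending to `updates` replaces `optimizer.step()`. The `do_sync_step` condition follows the one in the loop.

```python
def run_toy_loop(num_epochs, micro_batches_per_epoch, gradient_accumulation_steps):
    """Count optimizer updates in a loop shaped like the trainer's (toy, no real model)."""
    global_step = 0   # number of optimizer updates, like self.state.global_step
    accumulated = 0   # micro-batches whose gradients are waiting to be applied
    updates = []      # how many micro-batches each update consumed
    for epoch in range(num_epochs):
        for step in range(micro_batches_per_epoch):
            accumulated += 1  # stands in for training_step(): forward + backward
            do_sync_step = (
                (step + 1) % gradient_accumulation_steps == 0
                or (step + 1) == micro_batches_per_epoch
            )
            if do_sync_step:
                updates.append(accumulated)  # stands in for optimizer.step()
                accumulated = 0              # stands in for model.zero_grad()
                global_step += 1
    return global_step, updates

global_step, updates = run_toy_loop(
    num_epochs=2, micro_batches_per_epoch=10, gradient_accumulation_steps=4
)
print(global_step, updates)  # 6 [4, 4, 2, 4, 4, 2]
```

The trailing update of each epoch consumes only 2 micro-batches, because the end of the epoch forces a sync even when the accumulation window is not full.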

2. Prepare Train: Code Details

2.1. SFTTrainer.__init__

class SFTTrainer(Trainer):
    """
    Trainer for Supervised Fine-Tuning (SFT) method.
    """

    def __init__():
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_loss_func=compute_loss_func,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            **super_init_kwargs,
        )

2.2. Trainer.train()

class Trainer:
    def train(self, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None, **kwargs):
        """
        Main training entry point.
        """

        return inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

2.3. _inner_training_loop

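def _inner_training_loop(self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None):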
    # Data loader and number of training steps
    train_dataloader = self.get_train_dataloader()

    # Setting up training control variables:
    # number of training epochs: num_train_epochs
    # number of training steps per epoch: num_update_steps_per_epoch
    # total number of training steps to execute: max_steps
    total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
    (
        num_train_epochs,
        num_update_steps_per_epoch,
        num_examples,
        num_train_samples,
        epoch_based,
        len_dataloader,
        max_steps,
    ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)

    # Train!
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {num_examples:,}")
    logger.info(f"  Num Epochs = {num_train_epochs:,}")
    logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    if self.args.per_device_train_batch_size != self._train_batch_size:
        logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps:,}")
    logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

    # tr_loss is a tensor to avoid synchronization of TPUs through .item()
    tr_loss = torch.tensor(0.0, device=args.device)
    # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
    self._total_loss_scalar = 0.0
    self._globalstep_last_logged = self.state.global_step
    model.zero_grad()
    grad_norm: Optional[float] = None
    learning_rate = None

    if args.eval_on_start:
        self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

    epochs_trained = 0

    for epoch in range(epochs_trained, num_train_epochs):
        epoch_dataloader = train_dataloader
        
        steps_in_epoch = (
            len(epoch_dataloader)
            if len_dataloader is not None
            else args.max_steps * args.gradient_accumulation_steps
        )
        
        step = -1
        epoch_iterator = iter(epoch_dataloader)
        # We chunkify the epoch iterator into gradient accumulation steps `n` batches
        remainder = num_examples % args.gradient_accumulation_steps
        if remainder == 0:
            remainder = args.gradient_accumulation_steps
        update_step = -1
        total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
        if args.gradient_accumulation_steps == 1:
            total_updates -= 1
            
        for _ in range(total_updates):
            update_step += 1
            num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
            batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
            for i, inputs in enumerate(batch_samples):
                step += 1

                do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch

                tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                
                if (
                    args.logging_nan_inf_filter
                    and not is_torch_xla_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    if tr_loss.device != tr_loss_step.device:
                        raise ValueError(
                            f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                        )
                    tr_loss = tr_loss + tr_loss_step

                if do_sync_step:
                    self.optimizer.step()

                    # get learning rate before update
                    learning_rate = self._get_learning_rate()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
    
    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    
    if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        self._load_best_model()
    
    # add remaining tr_loss
    self._total_loss_scalar += tr_loss.item()
    effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
    train_loss = self._total_loss_scalar / effective_global_step

    return TrainOutput(self.state.global_step, train_loss, metrics)
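The way the epoch iterator is cut into groups of micro-batches, one group per optimizer update, can be sketched without the trainer. This is an illustration under stated assumptions: `take_batch_samples` is a hypothetical stand-in for `get_batch_samples`, the batches are plain dicts with a `labels` list, and `num_items_in_batch` is assumed to count the label tokens that differ from -100.

```python
from itertools import islice

IGNORE_INDEX = -100  # label value that is excluded from the loss

def take_batch_samples(epoch_iterator, num_batches):
    """Pull up to num_batches micro-batches and count their supervised tokens."""
    batch_samples = list(islice(epoch_iterator, num_batches))
    num_items_in_batch = sum(
        1
        for batch in batch_samples
        for row in batch["labels"]
        for token in row
        if token != IGNORE_INDEX
    )
    return batch_samples, num_items_in_batch

# Five toy micro-batches, each holding one sequence with two supervised tokens.
micro_batches = [{"labels": [[IGNORE_INDEX, i, i + 1]]} for i in range(5)]
gradient_accumulation_steps = 2
steps_in_epoch = len(micro_batches)
# Same formula as in the excerpt: one extra update for the trailing partial group.
total_updates = steps_in_epoch // gradient_accumulation_steps + 1

epoch_iterator = iter(micro_batches)
sizes = []
for _ in range(total_updates):
    batch_samples, num_items_in_batch = take_batch_samples(
        epoch_iterator, gradient_accumulation_steps
    )
    sizes.append((len(batch_samples), num_items_in_batch))
print(sizes)  # [(2, 4), (2, 4), (1, 2)]
```

The sketch always asks for a full group and lets the iterator run dry, whereas the excerpt passes `remainder` for the last update; the resulting groups are the same in this toy case.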

2.4. training_step

def training_step(self, model, inputs, num_items_in_batch=None):
    """
    Perform a training step on a batch of inputs.
    """

    model.train()  # switch to training mode
    
    if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
        self.optimizer.train()
    
    inputs = self._prepare_inputs(inputs)
    
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
    
    if self.args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training

    if self.use_apex:
        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()  # Apex automatic mixed precision: faster training with less GPU memory
    else:
        # Finally we need to normalize the loss for reporting
        if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
            loss = loss / self.args.gradient_accumulation_steps

        self.accelerator.backward(loss, **kwargs)  # generic backward call, compatible with distributed training (DeepSpeed, FSDP, etc.)

        return loss.detach()
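Why the loss is divided by `gradient_accumulation_steps` before the backward pass can be checked on a one-parameter model. The sketch below is a toy under stated assumptions: a scalar model `y = w * x` with mean squared error, equally sized micro-batches, and a hand-written gradient function `mse_grad` (hypothetical, not part of any library). Summing the scaled micro-batch gradients then matches the gradient of the full batch.

```python
import math

def mse_grad(w, xs, ys):
    """Gradient with respect to w of mean((w * x - y) ** 2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
gradient_accumulation_steps = 2
micro_batch_size = len(xs) // gradient_accumulation_steps

# Gradient computed on the full batch in one go.
full_batch_grad = mse_grad(w, xs, ys)

# Gradient accumulated over micro-batches, each loss scaled by 1 / accumulation steps.
accumulated_grad = 0.0
for start in range(0, len(xs), micro_batch_size):
    mb_x = xs[start:start + micro_batch_size]
    mb_y = ys[start:start + micro_batch_size]
    accumulated_grad += mse_grad(w, mb_x, mb_y) / gradient_accumulation_steps

print(full_batch_grad, accumulated_grad)  # -22.5 -22.5
```

Without the division, the accumulated gradient would be `gradient_accumulation_steps` times larger, which would act like a silently increased learning rate.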

2.5. compute_loss

def compute_loss(self, model, inputs, num_items_in_batch=None):
    if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
        labels = inputs.pop("labels")
    else:
        labels = None
    
    outputs = model(**inputs)

    if labels is not None:
        unwrapped_model = self.accelerator.unwrap_model(model)
        if _is_peft_model(unwrapped_model):
            model_name = unwrapped_model.base_model.model._get_name()
        else:
            model_name = unwrapped_model._get_name()
        # User-defined compute_loss function
        if self.compute_loss_func is not None:
            loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
        elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)
    else:
        if isinstance(outputs, dict) and "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    
    return loss
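The `shift_labels=True` argument used for causal language models reflects that the logits at position t are scored against the label at position t + 1. The following pure-Python sketch illustrates that shift together with the -100 ignore index. The function `causal_lm_loss` is hypothetical and works on nested lists for a single sequence; it is not the library's label smoother or the model's built-in loss.

```python
import math

IGNORE_INDEX = -100  # positions with this label do not contribute to the loss

def causal_lm_loss(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy for one sequence: position t predicts labels[t + 1]."""
    total, count = 0.0, 0
    # Drop the last logit row and the first label so that the pairs line up.
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue
        log_norm = math.log(sum(math.exp(value) for value in step_logits))
        total += log_norm - step_logits[target]  # negative log-softmax of the target
        count += 1
    return total / count if count else 0.0

# Three positions, vocabulary of two tokens, uniform logits everywhere.
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
labels = [IGNORE_INDEX, 1, 0]  # the first token is prompt and therefore masked

loss = causal_lm_loss(logits, labels)
print(loss)  # log(2), about 0.693
```

Because the first label is dropped by the shift anyway, masking it with -100 changes nothing here; masking matters for prompt tokens further into the sequence.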

2.6. PeftModelForCausalLM.forward

class PeftModelForCausalLM(PeftModel):
    """
    Peft model for causal language modeling.
    """

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            task_ids=None,
            **kwargs,
    ):
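        # Abridged excerpt: this is the prompt-learning path, where virtual tokens are prepended.
        peft_config = self.active_peft_config
        batch_size = _get_batch_size(input_ids, inputs_embeds)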
        if attention_mask is not None:
            # concat prompt attention mask
            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        # concat prompt labels
        if labels is not None:
            prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
            kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
        prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
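The prefix handling in this forward pass is easier to follow on small lists. In the sketch below, nested Python lists stand in for tensors, and `prepend_virtual_tokens` is a hypothetical helper: the attention mask gains a block of ones so the virtual tokens are attended to, and the labels gain a block of -100 so the virtual tokens are ignored by the loss.

```python
IGNORE_INDEX = -100  # label value ignored by the loss

def prepend_virtual_tokens(attention_mask, labels, num_virtual_tokens):
    """Extend the mask with ones and the labels with IGNORE_INDEX, row by row."""
    new_mask = [[1] * num_virtual_tokens + row for row in attention_mask]
    new_labels = [[IGNORE_INDEX] * num_virtual_tokens + row for row in labels]
    return new_mask, new_labels

attention_mask = [[1, 1, 0]]     # one sequence, last position is padding
labels = [[5, 7, IGNORE_INDEX]]  # padding is already excluded from the loss

new_mask, new_labels = prepend_virtual_tokens(attention_mask, labels, num_virtual_tokens=2)
print(new_mask)    # [[1, 1, 1, 1, 0]]
print(new_labels)  # [[-100, -100, 5, 7, -100]]
```

Both outputs have length 5, which matches the sequence length after the prompt embeddings are concatenated in front of `inputs_embeds`.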

2.7. Linear4bit.forward

class Linear4bit(torch.nn.Module, LoraLayer):
    # Lora implemented in a dense layer

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # The reason is that in some cases, an error can occur that backprop
            # does not work on a manipulated view. This issue may be solved with
            # newer PyTorch versions but this would need extensive testing to be
            # sure.
            result = result.clone()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    x = self._cast_input_dtype(x, lora_A.weight.dtype)

                if not self.use_dora[active_adapter]:
                    output = lora_B(lora_A(dropout(x))) * scaling
                else:
                    if isinstance(dropout, torch.nn.Identity) or not self.training:
                        base_result = result
                    else:
                        x = dropout(x)
                        base_result = None

                    output = self.lora_magnitude_vector[active_adapter](
                        x,
                        lora_A=lora_A,
                        lora_B=lora_B,
                        scaling=scaling,
                        base_layer=self.get_base_layer(),
                        base_result=base_result,
                    )
                if requires_conversion:
                    output = output.to(expected_dtype)
                result = result + output

        return result
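The non-DoRA branch boils down to `result = base_layer(x) + lora_B(lora_A(x)) * scaling`. Here is a dependency-free sketch of that formula on a single input vector. It makes several assumptions: plain lists replace tensors, dropout is treated as the identity, the helpers `matvec` and `lora_forward` are hypothetical, and `scaling` is taken to be `lora_alpha / r` (here 2 / 1).

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]

def lora_forward(x, base_weight, lora_A, lora_B, scaling):
    """Base layer output plus the scaled low-rank update B(A(x))."""
    result = matvec(base_weight, x)              # frozen base layer
    update = matvec(lora_B, matvec(lora_A, x))   # project down to rank r, then back up
    return [r + scaling * u for r, u in zip(result, update)]

x = [1.0, 2.0]
base_weight = [[1.0, 0.0], [0.0, 1.0]]  # identity, stands in for the 4-bit base layer
lora_A = [[1.0, 1.0]]                   # shape (r, in_features) with r = 1
lora_B = [[0.5], [0.0]]                 # shape (out_features, r)
scaling = 2.0                           # assumed lora_alpha / r = 2 / 1

y = lora_forward(x, base_weight, lora_A, lora_B, scaling)
print(y)  # [4.0, 2.0]

# With lora_B set to zeros the adapter contributes nothing.
y_zero = lora_forward(x, base_weight, lora_A, [[0.0], [0.0]], scaling)
print(y_zero)  # [1.0, 2.0]
```

The second call shows that an adapter whose B matrix is all zeros leaves the base layer output unchanged, so only the small A and B matrices need gradients while the base weights stay frozen.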