SFTTrainer Source Code Walkthrough: Prepare Dataset
- Overall Logic of Prepare Dataset
- Code Details of Prepare Dataset
  - SFTTrainer.__init__
  - DataCollatorForLanguageModeling
  - _prepare_dataset
1. Overall Logic of Prepare Dataset
Overall logic (a hand-worked example of these stages appears after the list):
- 1. If processing_class is None, use the base model's tokenizer.
- 2. Set up the data collator, which pads with pad_token on the right so that all sequences in a batch have the same length.
- 3. Check whether the dataset's column names include "input_ids". If they do, the dataset has already been preprocessed, and the preprocessing steps below will be skipped.
- 4. If the column names include "input_ids" (already preprocessed), formatting_func is ignored; otherwise formatting_func is applied:
  - Whether batched mapping is used is inferred automatically from the return type of formatting_func, and the dataset is mapped so that each sample becomes {"text": formatting_func(example)}.
- 5. If the dataset's column names include both "prompt" and "completion":
  - Determine whether the data is conversational (entries containing "role" and "content") or plain text.
  - Map over the dataset:
    - Conversational: each sample becomes {"messages": example["prompt"] + example["completion"]}.
    - Plain text: each sample becomes {"text": example["prompt"] + example["completion"]}.
- 6. Run the preprocessing steps (skipped when the column names include "input_ids"):
  - Convert conversational data to the unified ChatML format: {'messages': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}]}
  - Apply tokenizer.apply_chat_template to turn the conversational "messages" into plain text: {"text": "xxx"}
  - Tokenize the "text" field with the tokenizer, producing the "input_ids" and "attention_mask" fields.
- 7. Return the processed dataset, which will always contain the three fields 'text', 'input_ids', and 'attention_mask'.
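To make these stages concrete, here is a minimal sketch (not TRL source code; the model name and sample data are assumptions for illustration) that walks one prompt-completion sample through steps 5 and 6 by hand:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed chat model

# Step 5 input: a conversational prompt-completion sample
example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "completion": [{"role": "assistant", "content": "It is blue."}],
}

# Step 5: concatenate prompt and completion into "messages"
example = {"messages": example["prompt"] + example["completion"]}

# Step 6: apply the chat template to get "text", then tokenize it
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
processed = tokenizer(text=text)
print(list(processed.keys()))  # ['input_ids', 'attention_mask']
```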
2. Code Details of Prepare Dataset
2.1. SFTTrainer.__init__
```python
class SFTTrainer(Trainer):
"""
Trainer for Supervised Fine-Tuning (SFT) method.
"""
    def __init__(self, ...):  # signature abridged; only the dataset-preparation logic is shown
# Handle the tokenizer
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model_id)
# Data collator
if data_collator is None:
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
if pad_token_id is None:
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
data_collator = DataCollatorForLanguageModeling(pad_token_id)
# Dataset
train_dataset = self._prepare_dataset(
train_dataset, processing_class, args, args.packing, formatting_func, "train"
)
if eval_dataset is not None:
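            # eval_packing overrides args.packing for the eval split when it is set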
packing = args.packing if args.eval_packing is None else args.eval_packing
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(
eval_dataset, processing_class, args, packing, formatting_func, "eval"
            )
```
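For reference, a minimal usage sketch (the model id, SFTConfig arguments, and toy dataset are assumptions, and SFTConfig requires a recent trl version):

```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

train_dataset = Dataset.from_list([
    {"prompt": [{"role": "user", "content": "Hi"}],
     "completion": [{"role": "assistant", "content": "Hello!"}]},
])

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # processing_class is None, so the tokenizer is loaded from this model
    args=SFTConfig(output_dir="sft-output", max_steps=1),
    train_dataset=train_dataset,  # _prepare_dataset runs here, inside __init__
)
```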
2.2. DataCollatorForLanguageModeling
```python
@dataclass
class DataCollatorForLanguageModeling(DataCollatorMixin):
"""
Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch if
they are not all of the same length.
Args:
pad_token_id (`int`):
Token ID to use for padding.
return_tensors (`str`, *optional*, defaults to `"pt"`):
Type of Tensor to return. Only `"pt"` is currently supported.
Examples:
from trl import DataCollatorForLanguageModeling
collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [
{"input_ids": [1, 2, 3]},
{"input_ids": [4, 5]}
]
collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'labels': tensor([[ 1, 2, 3],
[ 4, 5, -100]])
"""
pad_token_id: int
return_tensors: str = "pt"
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
# Convert to tensor
input_ids = [torch.tensor(example["input_ids"]) for example in examples]
        attention_mask = [torch.ones_like(ids) for ids in input_ids]
        # Labels start as a copy of input_ids; padding positions are filled with -100 below so the loss ignores them
        labels = [torch.tensor(example["input_ids"]) for example in examples]
# Pad
output = {}
output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right")
output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right")
output["labels"] = pad(labels, padding_value=-100, padding_side="right")
        return output
```
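torch_call delegates the actual padding to trl's internal pad helper, which is not shown in this post. A functionally equivalent 1-D reconstruction (my assumption, not the verbatim TRL implementation) looks roughly like this:

```python
import torch

def pad(tensors, padding_value=0, padding_side="right"):
    # Pad a list of 1-D tensors to the length of the longest one
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        if padding_side == "right":
            out[i, : t.size(0)] = t
        else:
            out[i, max_len - t.size(0) :] = t
    return out
```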
2.3. _prepare_dataset
```python
def _prepare_dataset(self, dataset, processing_class, args, packing, formatting_func, dataset_name):
    map_kwargs = {}  # kwargs forwarded to dataset.map (initialization abridged)
# If the dataset is already preprocessed (tokenized), skip the processing steps.
column_names = list(next(iter(dataset)).keys())
is_processed = "input_ids" in column_names
# Apply the formatting function if any
if formatting_func is not None and is_processed:
warnings.warn(
"You passed a dataset that is already processed (contains an `input_ids` field) together with a "
"formatting function. Therefore `formatting_func` will be ignored. Either remove the "
"`formatting_func` or pass a dataset that is not already processed.",
UserWarning,
)
if formatting_func is not None and not is_processed:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
batched = isinstance(formatting_func(next(iter(dataset))), list)
def _func(example):
return {"text": formatting_func(example)}
dataset = dataset.map(_func, batched=batched, **map_kwargs)
# If the dataset is prompt-completion, convert it to language modeling type
first_example = next(iter(dataset))
if "prompt" in first_example.keys() and "completion" in first_example.keys():
key = "messages" if is_conversational(first_example) else "text"
def concat_prompt_completion(example):
return {key: example["prompt"] + example["completion"]}
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
if not is_processed:
# Convert the dataset to ChatML if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
column_names = next(iter(dataset)).keys()
dataset = dataset.map(
maybe_convert_to_chatml,
remove_columns="conversations" if "conversations" in column_names else None,
**map_kwargs,
)
# Apply the chat template if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
column_names = next(iter(dataset)).keys()
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class},
remove_columns="messages" if "messages" in column_names else None, # renamed to "text"
**map_kwargs,
)
# Tokenize the dataset if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
def tokenize(example, processing_class, dataset_text_field):
processed = processing_class(text=example[dataset_text_field])
if (
processing_class.eos_token_id is not None
and processed["input_ids"][-1] != processing_class.eos_token_id
):
processed["input_ids"] = processed["input_ids"] + [processing_class.eos_token_id]
processed["attention_mask"] = processed["attention_mask"] + [1]
return processed
dataset = dataset.map(
tokenize,
fn_kwargs={"processing_class": processing_class, "dataset_text_field": args.dataset_text_field},
**map_kwargs,
)
    return dataset
```
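The eos-appending branch in tokenize is easy to verify in isolation. A small sketch (the tokenizer choice is an assumption):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
processed = tok(text="It is blue.")

# Mirror the eos-appending branch of `tokenize` above
if tok.eos_token_id is not None and processed["input_ids"][-1] != tok.eos_token_id:
    processed["input_ids"] = processed["input_ids"] + [tok.eos_token_id]
    processed["attention_mask"] = processed["attention_mask"] + [1]

assert processed["input_ids"][-1] == tok.eos_token_id
```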