
SFT Fine-tuning of the qwen2-vl-2b Model

Key TRL Parameters

per_device_train_batch_size 1
gradient_accumulation_steps 1

The effective batch size = number of GPUs * per_device_train_batch_size * gradient_accumulation_steps. For example, with 4 GPUs, a per-device batch size of 1, and 8 accumulation steps, the effective batch size is 4 * 1 * 8 = 32.

gradient_checkpointing True

Enabling gradient checkpointing significantly reduces GPU memory usage. The PyTorch documentation describes it as follows:

Checkpoint a model or part of the model. Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, and instead recomputes them in the backward pass. It can be applied on any part of a model. Specifically, in the forward pass, the function will run in torch.no_grad() manner, i.e., not storing the intermediate activations. Instead, the forward pass saves the inputs tuple and the function parameter. In the backward pass, the saved inputs and function are retrieved, and the forward pass is computed on the function again, now tracking the intermediate activations, and then the gradients are calculated using these activation values. The output of the function can contain non-Tensor values, and gradient recording is only performed for the Tensor values. Note that if the output consists of nested structures (e.g. custom objects, lists, dicts, etc.) consisting of Tensors, these Tensors nested in custom structures will not be considered as part of autograd.
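
As a minimal, hedged sketch (not part of the original script), the same trade-off can be exercised directly through torch.utils.checkpoint; the layer sizes and tensor shapes below are arbitrary:

import torch
from torch.utils.checkpoint import checkpoint

# A small block whose intermediate activations we do not want to keep in memory.
block = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
)

x = torch.randn(8, 512, requires_grad=True)
# The forward pass runs without storing intermediate activations; they are
# recomputed during backward. use_reentrant=False mirrors the
# gradient_checkpointing_kwargs set later in the training script.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()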

report_to none

Setting report_to to none disables logging integrations such as wandb.
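
As a hedged illustration (not from the original post), the parameters above map onto fields of TRL's SFTConfig; in the actual script they are supplied through TrlParser from the command line or a config file, and the output_dir value here is hypothetical:

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="qwen2-vl-2b-sft",        # hypothetical output path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    report_to="none",
)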

Code

4-bit quantization is not enabled because it conflicts with DeepSpeed.

The datasets map function is not used to preprocess the data because, after map, every content entry of every conversation ends up with a video field (datasets enforces a single unified schema across examples), and some of those video fields are None; when apply_chat_template is later called, this produces multiple video_pad placeholders and breaks processing. The toy example below illustrates the schema unification.
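
A hedged toy illustration of that behavior (the data below is made up, not the real dataset): any datasets.Dataset column of dicts is stored with a single unified Arrow schema, so a text-only content entry picks up a "video" key set to None.

from datasets import Dataset

toy = Dataset.from_list([
    {"content": [{"type": "text", "text": "hello"}]},
    {"content": [{"type": "video", "video": "clip.mp4"}, {"type": "text", "text": "hi"}]},
])
# Arrow stores "content" as a list of structs with the union of all keys,
# so the text-only entry now carries a "video" key set to None:
print(toy[0]["content"])
# roughly: [{'type': 'text', 'text': 'hello', 'video': None}]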

The code is as follows:

import json
import os
import random
from dataclasses import dataclass, field
from typing import Any, Optional
from functools import partial

import torch
from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor

from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map


# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")

@dataclass
class SFTScriptArguments(ScriptArguments):
    """
    Script arguments for the SFT training script.
    """

    data_path: Optional[str] = field(
        default=None,
        metadata={"help": "root directory of the video files"},
    )
    jsonl_path: Optional[str] = field(
        default=None,
        metadata={"help": "jsonl file path"},
    )

def prepare_dataset(example: dict[str, Any], root_dir: str) -> dict[str, list[dict[str, Any]]]:
    """Prepare dataset example for training."""
    system_message = "You are acting as the grounder now."
    base_prompt = """Given a video and a text query {query}, your goal is to temporally localize the video moment described by the query and
output the reasoning process and the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think><answer> answer here </answer>
"""
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": os.path.join(root_dir, example["video"]),
                    "min_pixels": 16 * 28 * 28,
                    "max_pixels": 32 * 28 * 28,
                    "fps": 1.0,
                },
                {"type": "text", "text": base_prompt.format(query=example["problem"])},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": example["think"] + example["solution"]}]},
    ]
    return {"messages": messages}


def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
    """Collate batch of examples for training."""
    texts = []
    video_inputs = []

    for i, example in enumerate(examples):
        try:
            video_path = next(
                content["video"]
                for message in example["messages"]
                for content in message["content"]
                if content.get("type") == "video"
            )
            print(f"Processing video: {video_path}")
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "video",
                            "video": video_path,
                            # "total_pixels": 3584 * 28 * 28,
                            "min_pixels": 16 * 28 * 28,
                            "max_pixels": 32 * 28 * 28,
                            "fps": 1.0,
                        },
                    ],
                }
            ]
            texts.append(processor.apply_chat_template(example["messages"], tokenize=False))
            # process_vision_info returns (image_inputs, video_inputs); take the first video
            video_input = process_vision_info(messages)[1][0]
            video_inputs.append(video_input)
        except Exception as e:
            raise ValueError(f"Failed to process example {i}: {e}") from e

    inputs = processor(text=texts, videos=video_inputs, return_tensors="pt", padding=True)

    # Mask padding tokens out of the loss
    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Handle visual tokens based on processor type
    # (hardcoded vision-related special token ids for Qwen2-VL)
    visual_tokens = (
        [151652, 151653, 151656]
        if isinstance(processor, Qwen2VLProcessor)
        else [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    )

    for visual_token_id in visual_tokens:
        labels[labels == visual_token_id] = -100

    inputs["labels"] = labels
    return inputs

def create_dataset_from_jsonl_simple(jsonl_path):
    base_dataset = Dataset.from_json(jsonl_path)
    return DatasetDict({
        "train": base_dataset
    })

if __name__ == "__main__":
    # Parse arguments
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Configure training args
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False

    # Load dataset
    if script_args.jsonl_path:
        # load dataset from a local jsonl file
        dataset = Dataset.from_json(script_args.jsonl_path)
    else:
        # load dataset from the Hugging Face Hub
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    # trainset, testset = dataset.train_test_split(test_size=0.1).values()

    # Setup model
    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

    # Quantization configuration for 4-bit training (disabled: conflicts with DeepSpeed)
    # bnb_config = BitsAndBytesConfig(
    #     load_in_4bit=True,
    #     bnb_4bit_use_double_quant=True,
    #     bnb_4bit_quant_type="nf4",
    #     bnb_4bit_compute_dtype=torch.bfloat16,
    # )

    # Model initialization
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        dtype=dtype,
        attn_implementation=model_args.attn_implementation,
        # device_map=get_kbit_device_map(),
        # quantization_config=bnb_config,
    )

    model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # LoRA applied to the attention projections only
    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # 'attn.qkv', 'attn.proj'
    )

    # Configure model modules for gradients
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_reentrant = False
        model.enable_input_require_grads()

    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )

    # Build the message lists with a plain list comprehension instead of datasets.map
    # (see the note above about None video fields producing extra video_pad tokens)
    dataset = [prepare_dataset(example, script_args.data_path) for example in dataset]
    # dataset = dataset.map(
    #     partial(prepare_dataset, root_dir=script_args.data_path),
    #     remove_columns=dataset.column_names,
    #     num_proc=4,  # multiprocessing can speed up preprocessing
    # )

    # Initialize trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        processing_class=processor,
    )

    # Train model
    trainer.train()

    # Save final model
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)

    # Cleanup
    del model
    del trainer
    torch.cuda.empty_cache()
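
For reference, a hypothetical example of a single record in the file passed via --jsonl_path, based on the fields the script actually reads (video, problem, think, solution); the path and text values below are invented, not taken from the real dataset:

# One hypothetical training record (field names from the script; values invented).
example_record = {
    "video": "videos/clip_0001.mp4",   # joined with --data_path inside prepare_dataset
    "problem": "When does the person open the door?",
    "think": "<think>The door stays closed until the person reaches for the handle ...</think>",
    "solution": "<answer>12.5 to 15.0 seconds</answer>",
}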