Skip to main content

Qwen-VL架构

Qwen2.5-VL

组成

self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
self.language_model = Qwen2_5_VLTextModel._from_config(config.text_config)

数据流:

文本根据token_id编码:

inputs_embeds = self.get_input_embeddings()(input_ids)

图像/视频像素值分块后传递到self.visual中

pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos,grid_thw=video_grid_thw)
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
video_embeds = torch.split(video_embeds,split_sizes)

拼接文本与视觉embeddings

_, video_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
)
# 关键拼接操作:使用masked_scatter将视频特征插入到文本embedding中
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

添加位置编码

if position_ids is None:
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do assisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or (past_key_values is None or past_key_values.get_seq_length() == 0)
)
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask,
)
self.rope_deltas = rope_deltas
else:
batch_size, seq_length, _ = inputs_embeds.shape
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
if cache_position is not None:
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
else:
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
position_ids = position_ids + delta.to(position_ids.device)

送入到语言模型中预测next token:

        outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)