"""
|
|
2026.2.1
|
|
2026.2.1
|
|
4.57.6
|
|
0.24.0
|
|
__UNSLOTH_VERSIONING__
|
|
"""

# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import os
import torch
import importlib.util
import math
if importlib.util.find_spec("unsloth_studio") is None:
    UNSLOTH_STUDIO_ENABLED = False
else:
    UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
pass
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
import math

UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
UNSLOTH_ENABLE_CCE = os.environ.get("UNSLOTH_ENABLE_CCE", "1") == "1"
UNSLOTH_COMPILE_DISABLE = os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") in ("1", "partial",)

import logging
logger_compiler = logging.getLogger(__name__)
if UNSLOTH_ENABLE_LOGGING:
    logger_compiler.setLevel(logging.DEBUG)

global INFERENCE_RUNS
INFERENCE_RUNS = 0

try:
    import torch._dynamo.eval_frame as torch_dynamo_eval_frame
    # Probe that the dynamo stance API exists before caching torch.compiler.set_stance
    torch_dynamo_eval_frame._stance.stance
    torch_compiler_set_stance = torch.compiler.set_stance
except:
    torch_dynamo_eval_frame = None
    torch_compiler_set_stance = None
pass

from unsloth_zoo import DEVICE_TYPE_TORCH, DEVICE_COUNT

from unsloth_zoo.loss_utils import (
    fused_linear_cross_entropy,
    unsloth_fused_ce_loss,
)

if UNSLOTH_STUDIO_ENABLED:
    from unsloth_zoo.loss_utils import fast_linear_cross_entropy

scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@torch.compiler.disable(recursive = False)
def disable_compile_scaled_dot_product_attention(*args, **kwargs):
    return scaled_dot_product_attention(*args, **kwargs)
pass

from transformers.modeling_flash_attention_utils import is_flash_attn_available

if is_flash_attn_available():
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
    except:
        flash_attn_supports_top_left_mask = None
    try:
        from transformers.modeling_flash_attention_utils import _flash_attention_forward
    except:
        _flash_attention_forward = None
    try:
        from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
    except:
        FlashAttentionKwargs = None
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
    except:
        flash_attn_varlen_func = None
else:
    flash_attn_supports_top_left_mask = None
    _flash_attention_forward = None
    FlashAttentionKwargs = None
    flash_attn_varlen_func = None
pass

torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False, 'debug': False, 'dce': True, 'memory_planning': True, 'coordinate_descent_tuning': False, 'trace.graph_diagram': False, 'compile_threads': 32, 'group_fusion': True, 'disable_progress': True, 'verbose_progress': False, 'triton.multi_kernel': 0, 'triton.use_block_ptr': False, 'triton.enable_persistent_tma_matmul': True, 'triton.autotune_at_compile_time': False, 'triton.cooperative_reductions': False, 'cuda.compile_opt_level': '-O2', 'cuda.enable_cuda_lto': True, 'combo_kernels': False, 'benchmark_combo_kernel': True, 'combo_kernel_foreach_dynamic_shapes': True}

from torch.nn import CrossEntropyLoss

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def normal_cross_entropy_loss(self, hidden_states, labels):
    logits = self.lm_head(hidden_states)
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
    return loss, logits
pass
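
# Illustrative sketch (not part of the generated code above): a toy check of the label
# shifting that normal_cross_entropy_loss performs. The `_DummyHead` class and tensor
# sizes below are hypothetical, chosen only to show that position i of the shifted
# logits is scored against token i + 1 of the labels.
def _example_shifted_cross_entropy():
    vocab_size, hidden_size, seq_len = 16, 8, 5

    class _DummyHead:
        def __init__(self):
            self.lm_head = nn.Linear(hidden_size, vocab_size)
            self.config = type("Cfg", (), {"vocab_size": vocab_size})()

    dummy = _DummyHead()
    hidden_states = torch.randn(1, seq_len, hidden_size)
    labels = torch.randint(0, vocab_size, (1, seq_len))
    loss, logits = normal_cross_entropy_loss(dummy, hidden_states, labels)
    # loss is a scalar; logits keep the full [batch, seq_len, vocab_size] shape
    assert loss.ndim == 0 and logits.shape == (1, seq_len, vocab_size)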

# We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
LOGITS_ERROR_STRING = \
    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1"` BEFORE starting to train ie before `trainer.train()`. For example:\n'\
    "```\nimport os\n"\
    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
    "trainer.train()\n```\n"\
    "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"

def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING)
def return_none(*args, **kwargs): return None
class EmptyLogits:
    def __init__(self): return
    def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error
    __getitem__ = raise_logits_error
    __getattr__ = raise_getattr_error
    def __repr__(self): return LOGITS_ERROR_STRING
    def __str__ (self): return LOGITS_ERROR_STRING
pass
EMPTY_LOGITS = EmptyLogits()
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
    if function.startswith("__") and function.endswith("__"):
        exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
        try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
        except: continue
pass


def mask_attention_mask_out(labels = None, attention_mask = None):
    if labels is not None and attention_mask is not None:
        attention_mask = attention_mask.to(device = labels.device)
        labels[attention_mask == 0] = -100
    return labels
pass
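
# Illustrative sketch (not part of the generated code above): mask_attention_mask_out
# sets labels at padded positions (attention_mask == 0) to -100, the ignore_index of
# CrossEntropyLoss, so padding never contributes to the loss. The toy tensors below are
# hypothetical.
def _example_mask_attention_mask_out():
    labels = torch.tensor([[5, 6, 7, 8]])
    attention_mask = torch.tensor([[1, 1, 0, 0]])
    masked = mask_attention_mask_out(labels = labels.clone(), attention_mask = attention_mask)
    # the two padded positions are now ignored by the loss
    assert masked.tolist() == [[5, 6, -100, -100]]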


from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from transformers.models.siglip.modeling_siglip import (math, warnings, Callable, Optional, np, torch, nn, _calculate_fan_in_and_fan_out, ACT2FN, ALL_ATTENTION_FUNCTIONS, torch_int, SiglipTextConfig, SiglipVisionConfig)

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
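
# Illustrative sketch (not part of the generated code above): how trunc_normal_tf_ is
# meant to be used. A weight is filled with samples from a unit normal truncated to
# [-2, 2] and then rescaled, so every value ends up inside mean ± 2 * std. The tensor
# shape and std below are hypothetical.
def _example_trunc_normal_tf_():
    weight = torch.empty(64, 64)
    trunc_normal_tf_(weight, mean = 0.0, std = 0.02, a = -2.0, b = 2.0)
    # the [-2, 2] bounds are applied on the standard normal before scaling by std
    assert weight.abs().max().item() <= 2.0 * 0.02 + 1e-6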


@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")


@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipVisionEmbeddings_forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
    _, _, height, width = pixel_values.shape
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
    embeddings = patch_embeds.flatten(2).transpose(1, 2)

    if interpolate_pos_encoding:
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
    else:
        embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings

class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        return SiglipVisionEmbeddings_forward(self, pixel_values, interpolate_pos_encoding)
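
# Illustrative sketch (not part of the generated code above): with
# interpolate_pos_encoding=True the pre-trained position grid is bicubically resized, so
# the same weights can embed a larger image. The 336 x 336 input size below is a
# hypothetical example; the default SiglipVisionConfig uses image_size = 224 and
# patch_size = 16.
def _example_interpolate_pos_encoding():
    config = SiglipVisionConfig()
    embeddings = SiglipVisionEmbeddings(config)
    pixel_values = torch.randn(1, 3, 336, 336)
    out = embeddings(pixel_values, interpolate_pos_encoding = True)
    # (336 // 16) ** 2 = 441 patches, each projected to hidden_size
    assert out.shape == (1, 441, config.hidden_size)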


@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def SiglipTextEmbeddings_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
    seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
    max_position_embedding = self.position_embedding.weight.shape[0]

    if seq_length > max_position_embedding:
        raise ValueError(
            f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
            f"{seq_length} and max_position_embeddings: {max_position_embedding}"
        )

    if position_ids is None:
        position_ids = self.position_ids[:, :seq_length]

    if inputs_embeds is None:
        inputs_embeds = self.token_embedding(input_ids)

    position_embeddings = self.position_embedding(position_ids)
    embeddings = inputs_embeds + position_embeddings

    return embeddings


class SiglipTextEmbeddings(nn.Module):
    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        return SiglipTextEmbeddings_forward(self, input_ids, position_ids, inputs_embeds)

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype = torch.float32).to(attn_weights.dtype).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
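
# Illustrative sketch (not part of the generated code above): eager_attention_forward
# consumes [batch, heads, seq, head_dim] tensors, softmaxes the scaled QK^T scores in
# float32, and returns the output already transposed to [batch, seq, heads, head_dim]
# for the reshape in SiglipAttention_forward. The shapes below are hypothetical.
def _example_eager_attention_forward():
    batch, heads, seq, head_dim = 2, 4, 7, 16
    module = nn.Module()                      # only .training is read
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    v = torch.randn(batch, heads, seq, head_dim)
    out, weights = eager_attention_forward(module, q, k, v, None, scaling = head_dim ** -0.5)
    assert out.shape == (batch, seq, heads, head_dim)
    assert weights.shape == (batch, heads, seq, seq)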


@torch.compiler.disable(recursive = False)
def SiglipAttention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Input shape: Batch x Time x Channel"""

    batch_size, seq_length, embed_dim = hidden_states.shape

    queries = self.q_proj(hidden_states)
    keys = self.k_proj(hidden_states)
    values = self.v_proj(hidden_states)

    queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
    keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
    values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        queries,
        keys,
        values,
        attention_mask,
        is_causal=self.is_causal,
        scaling=self.scale,
        dropout=0.0 if not self.training else self.dropout,
    )

    attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights


class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        return SiglipAttention_forward(self, hidden_states, attention_mask, **kwargs)

@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipMLP_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    hidden_states = self.fc1(hidden_states)
    hidden_states = self.activation_fn(hidden_states)
    hidden_states = self.fc2(hidden_states)
    return hidden_states


class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return SiglipMLP_forward(self, hidden_states)


@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state):
    batch_size = hidden_state.shape[0]
    probe = self.probe.repeat(batch_size, 1, 1)

    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

    residual = hidden_state
    hidden_state = self.layernorm(hidden_state)
    hidden_state = residual + self.mlp(hidden_state)

    return hidden_state[:, 0]


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        return SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state)
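
# Illustrative sketch (not part of the generated code above): the pooling head uses a
# single learned probe token as the attention query, so a whole sequence of patch
# features collapses to one pooled vector per image. The sequence length below is
# hypothetical; the default SiglipVisionConfig supplies the sizes.
def _example_attention_pooling():
    config = SiglipVisionConfig()
    head = SiglipMultiheadAttentionPoolingHead(config)
    patch_features = torch.randn(2, 196, config.hidden_size)
    pooled = head(patch_features)
    # one hidden_size vector per image, regardless of how many patches went in
    assert pooled.shape == (2, config.hidden_size)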