"""
|
|
2026.2.1
|
|
2026.2.1
|
|
4.57.6
|
|
0.24.0
|
|
__UNSLOTH_VERSIONING__
|
|
"""

# Unsloth auto generated code
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

try:
    from peft.tuners.lora.layer import VARIANT_KWARG_KEYS
except ImportError:
    VARIANT_KWARG_KEYS = ['alora_offsets']

torch_compile_options = {
    'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True,
    'trace.enabled': False, 'triton.cudagraphs': False, 'debug': False,
    'dce': True, 'memory_planning': True, 'coordinate_descent_tuning': False,
    'trace.graph_diagram': False, 'compile_threads': 32, 'group_fusion': True,
    'disable_progress': True, 'verbose_progress': False, 'triton.multi_kernel': 0,
    'triton.use_block_ptr': False, 'triton.enable_persistent_tma_matmul': True,
    'triton.autotune_at_compile_time': False, 'triton.cooperative_reductions': False,
    'cuda.compile_opt_level': '-O2', 'cuda.enable_cuda_lto': True,
    'combo_kernels': False, 'benchmark_combo_kernel': True,
    'combo_kernel_foreach_dynamic_shapes': True,
}
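# The options above are torch._inductor / Triton tuning knobs, meant to be passed
# through torch.compile(..., options = torch_compile_options) as the commented-out
# decorator on lora_forward below suggests; autotuning and CUDA graphs stay disabled.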

import torch._dynamo

# Keep the quantized base-layer call out of TorchDynamo compiled graphs.
@torch._dynamo.disable
def _call_8bit_base_layer(base_layer, x, *args, **kwargs):
    return base_layer(x, *args, **kwargs)

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from unsloth_zoo.temporary_patches.common import torch_compile
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
from peft.tuners.lora.bnb import (VARIANT_KWARG_KEYS, torch)

# Module-level aliases avoid repeated attribute lookups in the hot path.
torch_addmm = torch.addmm
torch_add = torch.add

# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
    # Use result.dtype (bfloat16 from base layer) since x may have been cast to float32
    # by _cast_input_dtype when autocast is disabled
    target_dtype = result.dtype
    xA = dropout(x).to(target_dtype) @ lora_A.weight.to(target_dtype).t()
    # output = result + scaling * xA @ lora_B.weight.t()
    shape = result.shape
    output = torch_addmm(
        result.view(-1, shape[-1]),
        xA.view(-1, xA.shape[-1]),
        lora_B.weight.to(target_dtype).t(),
        alpha = scaling,
        beta = 1,
    ).view(shape)

    bias = lora_B.bias
    if bias is not None:
        output = torch_add(
            output,
            bias.to(target_dtype),
            alpha = scaling,
        )
    return output
pass
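

# Illustrative sanity check (not part of the generated patch): verifies on tiny,
# hypothetical shapes that the fused addmm path above matches the naive LoRA update
# result + scaling * ((x @ A.T) @ B.T + bias). Nothing calls this at import time.
def _lora_forward_reference_check():
    torch.manual_seed(0)
    x       = torch.randn(2, 8)                # (batch, in_features)
    result  = torch.randn(2, 16)               # base-layer output (batch, out_features)
    lora_A  = nn.Linear(8, 4, bias = False)    # rank-4 down-projection
    lora_B  = nn.Linear(4, 16, bias = True)    # up-projection; bias exercises the torch_add path
    dropout = nn.Identity()                    # deterministic stand-in for lora_dropout
    scaling = 0.5
    fused = lora_forward(result, lora_A, lora_B, dropout, x, scaling)
    naive = result + scaling * ((x @ lora_A.weight.t()) @ lora_B.weight.t() + lora_B.bias)
    assert torch.allclose(fused, naive, atol = 1e-6)
pass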


def unsloth_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    adapter_names = kwargs.pop("adapter_names", None)
    variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS}  # don't pass these to base_layer

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = _call_8bit_base_layer(self.base_layer, x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **variant_kwargs, **kwargs)
    elif self.merged:
        result = _call_8bit_base_layer(self.base_layer, x, *args, **kwargs)
    else:
        result = _call_8bit_base_layer(self.base_layer, x, *args, **kwargs)
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
                x = self._cast_input_dtype(x, lora_A.weight.dtype)

            if active_adapter not in self.lora_variant:  # vanilla LoRA
                # lora_forward adds the scaled LoRA update in result's dtype and
                # returns directly; the two statements below are unreachable after
                # this return.
                return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
                if requires_conversion:
                    output = output.to(expected_dtype)
                result = result + output
            else:
                result = self.lora_variant[active_adapter].forward(
                    self,
                    active_adapter=active_adapter,
                    x=x,
                    result=result,
                    **variant_kwargs,
                    **kwargs,
                )
                if requires_conversion:
                    result = result.to(expected_dtype)

    return result
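

# Illustrative application sketch (an assumption, not part of the generated file):
# unsloth_forward matches the signature of PEFT's 8-bit LoRA layer forward, so a
# patch along these lines could swap it in. The helper name is hypothetical, and
# the import requires bitsandbytes so that peft.tuners.lora.bnb exposes Linear8bitLt.
def _apply_unsloth_forward_patch():
    from peft.tuners.lora.bnb import Linear8bitLt
    Linear8bitLt.forward = unsloth_forward
pass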