diff --git a/sans titre_diarization.rttm b/sans titre_diarization.rttm
deleted file mode 100644
index e69de29..0000000
diff --git a/sans titre_diarized.txt b/sans titre_diarized.txt
deleted file mode 100644
index 7ccb9ae..0000000
--- a/sans titre_diarized.txt
+++ /dev/null
@@ -1,7386 +0,0 @@
-[00:00:00] VOIX: Tu sais, moi, je suis toujours là, toi, tu es toujours là.
-[00:00:03] VOIX: Il faut que petit à petit, on connecte les différents fils d'une manière ou d'une autre.
-[00:00:09] VOIX: Et j'ai envie, même si je n'ai pas fait tout le boulot que je voulais faire, je l'ai commencé, même après 14 heures.
-[00:00:20] VOIX: Et je voulais juste, par rapport à T2A, par rapport à ce que tu demandes, parce que ce que tu demandes, je vais te montrer les référentiels.
-[00:00:28] VOIX: Que nous, on utilise.
-[00:00:30] VOIX: Je démarre, je partage.
-[00:00:32] VOIX: Vas-y, vas-y.
-[00:00:33] VOIX: Et je t'enregistre à l'insu de ton plein gré.
-[00:00:37] VOIX: Oui, ça marche.
-[00:00:39] VOIX: En fait, tu vois, ça, c'est ce que j'ai commencé tout à l'heure.
-[00:00:43] VOIX: Et l'idée, c'était de dire, raisonner bien sûr sur le contrôle T2A, parce que, petit à petit,
-[00:00:51] VOIX: les différentes fonctionnalités qu'on va utiliser, elles vont être les mêmes pour la suite.
-[00:01:00] VOIX: Donc, c'est bien de réfléchir sur tout ça.
-[00:01:02] VOIX: Et en fait, j'ai découpé.
-[00:01:05] VOIX: En fait, normalement, c'est avant le contrôle, pendant le contrôle et il y a après le contrôle.
-[00:01:11] VOIX: Mais tu vois, on n'en est pas là l'objectif.
-[00:01:13] VOIX: Ce n'est pas d'être exhaustif.
-[00:01:14] VOIX: C'est juste comprendre le principe.
-[00:01:17] VOIX: Et avant le contrôle, tu vois, on commence par l'établissement qui reçoit la notification ARS.
-[00:01:23] VOIX: Ça, c'est le document, en fait, que j'ai, que UCR, c'est ça ? Ou URC ?
-[00:01:29] VOIX: Non, pas du tout.
-[00:01:31] VOIX: L'UCR, il est bien après.
-[00:01:32] VOIX: C'est pour ça que c'est important que je t'explique cette logique-là.
-[00:01:36] VOIX: Et oui, c'est pour ça que je, en fait, je sais que tu, même si tu me dis, je connais la mécanique, il y a...
-[00:01:43] VOIX: Non, mais c'est les noms que je ne connais pas, donc...
-[00:01:46] VOIX: C'est une étape.
-[00:01:47] VOIX: Et l'UCR, il est...
-[00:01:48] VOIX: Il est bien après.
-[00:01:50] VOIX: Oui, d'accord, OK.
-[00:01:51] VOIX: Il est pendant, il est bien après, en fait.
-[00:01:54] VOIX: C'est ça qu'il faut...
-[00:01:56] VOIX: Oui, mais c'est...
-[00:01:58] VOIX: Là, j'ai très bien compris que c'était après, en fait, mais je ne sais pas comment les appeler.
-[00:02:02] VOIX: En fait, moi, j'ai un problème, en fait, avec la sémantique, en fait, on va dire.
-[00:02:07] VOIX: C'est...
-[00:02:07] VOIX: Mais ce n'est pas grave.
-[00:02:08] VOIX: C'est pour ça que je suis en train d'écrire un document, de manière...
-[00:02:13] VOIX: Je vais le partager avec Guy aussi, de manière à ce que, petit à petit, tout le monde...
-[00:02:18] VOIX: Parce que même Khalid, il ne connaît pas ce parcours, au fait.
-[00:02:21] VOIX: C'est un parcours du contrôle.
-[00:02:24] VOIX: La notification...
-[00:02:24] VOIX: Bon, je vais te montrer l'exemple de l'ARS, mais l'objectif, ce n'est pas de rentrer dans le détail.
-[00:02:28] VOIX: C'est juste que l'ARS, elle va écrire à l'établissement pour leur dire, bonjour, vous êtes santé, sa séance du novembre, machin, ce programme, votre établissement est inclus dans ce programme.
-[00:02:44] VOIX: Ça veut dire, vous êtes contrôlé.
-[00:02:46] VOIX: Et à partir de là, elle va leur dire le champ 1, le champ 2, de la période, etc.
-[00:02:54] VOIX: D'accord.
-[00:02:55] VOIX: Le champ 1...
-[00:02:56] VOIX: D'accord.
-[00:02:58] VOIX: Oui.
-[00:02:58] VOIX: Tu vois, ce que tu retrouves par la suite, quand on dit champ 1, champ 2, etc., c'est à partir de ce courrier.
-[00:03:05] VOIX: Et donc, à partir de là, nous, dans EVA, on a un requêteur dans la phase initiée, d'accord, qui est dans EVA.
-[00:03:13] VOIX: J'ai mis 80% parce qu'il n'est pas ici à 100%.
-[00:03:18] VOIX: On est à 70-80%.
-[00:03:19] VOIX: Il y a des cas qu'on n'arrive pas encore à faire.
-[00:03:22] VOIX: On a immédiatement l'automatisation d'un tableau de synthèse, des dossiers potentiels.
-[00:03:26] VOIX: C'est ce que j'appelle les dossiers potentiels.
-[00:03:28] VOIX: Le nombre, les euros, les types.
-[00:03:30] VOIX: Bon, ça, je te montrerai tout ça après.
-[00:03:32] VOIX: Ce n'est pas l'objet.
-[00:03:33] VOIX: La liste des dossiers.
-[00:03:35] VOIX: Et l'avenir, c'est ce qu'il a travaillé Khalid.
-[00:03:38] VOIX: Lecture, je ne sais pas, est-ce qu'il a fait l'IA ou d'autres méthodes.
-[00:03:42] VOIX: Tu vois, moi, j'ai dit lecture par l'IA du document.
-[00:03:44] VOIX: Et requêtage, l'IA, synthèse des documents.
-[00:03:47] VOIX: Et ça, c'est testé, approuvé par Khalid et par l'équipe.
-[00:03:51] VOIX: C'est-à-dire, à partir de ce document, il n'y a même plus besoin de passer par l'étape initiée.
-[00:03:57] VOIX: Il sort la liste des dossiers potentiels qu'on peut exploiter sous un format Excel, etc.
-[00:04:04] VOIX: Pour l'instant, tout ça, c'est bien sûr en rappelie.
-[00:04:07] VOIX: Tout ce que je raconte ici.
-[00:04:09] VOIX: Mais en tout cas, Khalid a travaillé sur ça.
-[00:04:11] VOIX: Ça, entre cette étape et cette étape, il peut se passer un mois, comme une semaine, comme...
-[00:04:17] VOIX: Bon, ça, c'est juste, il n'y a rien ici à faire.
-[00:04:21] VOIX: C'est l'assurance maladie.
-[00:04:22] VOIX: Le MRC, c'est médecin responsable du contrôle.
-[00:04:25] VOIX: Le MDIM, c'est médecin DIM, qui crée des comptes.
-[00:04:28] VOIX: Bon, ça, c'est le document de création, etc.
-[00:04:31] VOIX: Après, il y a un fichier de dossier candidat.
-[00:04:35] VOIX: On ne reploite pas dans EVA.
-[00:04:39] VOIX: Point d'interrogation.
-[00:04:41] VOIX: On n'a pas intégré, nous, dans le process habituel.
-[00:04:44] VOIX: Mais c'est vrai que l'année dernière, certains établissements nous ont dit qu'ils commençaient à préparer dès les dossiers candidats,
-[00:04:50] VOIX: voire même potentiels, voire même là.
-[00:04:52] VOIX: Ils commencent à regarder les dossiers, à préparer leurs argumentaires, etc.
-[00:04:58] VOIX: Préparer, c'est ça.
-[00:04:59] VOIX: C'est regarder le codage, préparer les argumentaires, etc.
-[00:05:04] VOIX: Mais à ce moment-là, on n'a pas les fiches OGC, parce que, de toute façon, on n'a pas de numéro OGC.
-[00:05:13] VOIX: Tout ça, c'est le RSSDEF qui a les NUM OGC, le téléchargement de RSSDEF de Druid,
-[00:05:19] VOIX: déclenchement de la phase préparée et de la fiche OGC.
-[00:05:25] VOIX: Et ça, franchement, pour l'instant, je ne vois pas en quoi il y a.
-[00:05:32] VOIX: J'ai une question à te poser.
-[00:05:33] VOIX: Pourquoi tu mets une étape dossier candidat non exploité aujourd'hui ?
-[00:05:37] VOIX: Et après, comment les établissements commencent à préparer des dossiers alors qu'ils n'ont pas les fiches ?
-[00:05:43] VOIX: Parce qu'à partir de la deuxième étape-là, il y a des échanges de fichiers entre l'assurance maladie et l'établissement.
-[00:05:51] VOIX: Et l'établissement va envoyer à l'assurance maladie le fichier PMSI.
-[00:05:57] VOIX: L'assurance maladie va dire, vous allez avoir, pour le champ 1, par exemple, vous avez 200 dossiers candidats, d'accord ?
-[00:06:08] VOIX: Et en fait, après, si le nombre de dossiers sur un champ pour un établissement est très élevé, ça dépend cas par cas,
-[00:06:16] VOIX: ils vont faire un tirage au sort.
-[00:06:18] VOIX: C'est ça qu'ils appellent les dossiers qui vont être réellement contrôlés, avec un fichier qui s'appelle RSSCOND.
-[00:06:25] VOIX: Nous, pour l'instant, on ne l'utilise pas, parce qu'on attend que le contrôleur décide des dossiers qui seront réellement contrôlés.
-[00:06:35] VOIX: Mais certains établissements qui disent, oui, mais moi, j'ai que 200-300 dossiers pour l'établissement,
-[00:06:40] VOIX: je sais que le nombre de dossiers candidats va être le même que le dossier définitif du contrôle,
-[00:06:48] VOIX: ils commencent déjà à préparer, mais ils n'ont pas de numéro OGC, parce que le numéro OGC, il est dans ce fichier RSSDEF.
-[00:06:56] VOIX: Et nous, l'appli, pour la déclencher vraiment avec les fiches OGC, tout le bazar, on part du RSSDEF.
-[00:07:05] VOIX: Avant le RSSDEF, on peut faire des simulations avec un requêteur, mais ça ne déclenche pas la fiche OGC.
-[00:07:12] VOIX: Quand on dit fiche OGC, est-ce que ça te parle ?
-[00:07:15] VOIX: Oui, oui, oui, oui, ça je l'ai compris.
-[00:07:18] VOIX: Voilà.
-[00:07:19] VOIX: Ok.
-[00:07:20] VOIX: Ça, Jordan, il pourra te montrer ce que c'est la fiche OGC.
-[00:07:24] VOIX: C'est la fiche, en fait, dans laquelle on va modifier un codage et qui intègre dans le groupeur de la TIH
-[00:07:31] VOIX: et qui va dire, vous allez perdre, gagner tel ou tel euro.
-[00:07:36] VOIX: Tu vois, c'est elle qui va historiser les simulations.
-[00:07:40] VOIX: Comme l'objet, ce n'est pas de te montrer ça dans le détail, je ne vais pas me connecter.
-[00:07:44] VOIX: Sinon, je peux très bien me connecter et te montrer un petit peu tout ça.
-[00:07:49] VOIX: Non, on le verra après.
-[00:07:51] VOIX: Je veux bien, en fait, que tu continues.
-[00:07:54] VOIX: Là, après, en fait, tu passes sur...
-[00:07:55] VOIX: Oui, vas-y, pardon.
-[00:07:58] VOIX: Non, non, vas-y, pose ta question.
-[00:07:59] VOIX: Non, c'est simplement, en fait, dans le dossier de contrôle, tu télécharges.
-[00:08:04] VOIX: Ok, et de druide, c'est...
-[00:08:07] VOIX: Et de druide, qu'est-ce que c'est, druide ?
-[00:08:09] VOIX: Druide, c'est fichier pémicyde.
-[00:08:12] VOIX: C'est fichier pémicyde de l'établissement.
-[00:08:14] VOIX: Mais c'est déjà pas assez compliqué, vous rajoutez, en fait, des mots.
-[00:08:19] VOIX: Oui, mais ça, c'est le...
-[00:08:20] VOIX: Malheureusement, c'est...
-[00:08:22] VOIX: C'est pour ça que c'est important que je t'explique, que je crasse,
-[00:08:26] VOIX: que tu puisses un peu, quand on dit UCR,
-[00:08:29] VOIX: tu sais dans quelle étape...
-[00:08:31] VOIX: Oui, ça, c'est l'affaire.
-[00:08:31] VOIX: Qu'est-ce qu'il reste et quelle étape, etc.
-[00:08:34] VOIX: Tu vois, RSS-DEF, nous, on en parle beaucoup de druide.
-[00:08:39] VOIX: Dans les démos, on leur dit, vous téléchargez votre RSS-DEF, votre druide,
-[00:08:43] VOIX: et vous allez avoir immédiatement tout ce qui s'affiche dans l'étape préparée.
-[00:08:49] VOIX: Et à partir de cette étape,
-[00:08:52] VOIX: l'établissement peut commencer à regarder sans codage.
-[00:08:56] VOIX: Mais pour regarder sans codage,
-[00:08:58] VOIX: soit il se connecte à son DPI,
-[00:09:00] VOIX: soit en parallèle.
-[00:09:02] VOIX: Ça, c'est pas une étape après,
-[00:09:03] VOIX: parce qu'elle peut se faire en parallèle de tout ça.
-[00:09:06] VOIX: Elle peut se faire après ça,
-[00:09:08] VOIX: c'est-à-dire extraction et exploitation des documents justificatifs
-[00:09:12] VOIX: depuis le DPI,
-[00:09:13] VOIX: c'est depuis le dossier patient.
-[00:09:14] VOIX: D'accord ?
-[00:09:15] VOIX: Aujourd'hui, nada, on n'en fait rien.
-[00:09:20] VOIX: Extraction, hier, dans une démo,
-[00:09:22] VOIX: il y a une nana du CHU de...
-[00:09:26] VOIX: C'est pas le CHU de Nice,
-[00:09:27] VOIX: c'était quel CHU ?
-[00:09:28] VOIX: Je ne me rappelle plus.
-[00:09:29] VOIX: Qui nous a dit qu'ils ont mis à 5,
-[00:09:32] VOIX: un mois
-[00:09:35] VOIX: pour extraire
-[00:09:37] VOIX: et classer,
-[00:09:38] VOIX: renommer,
-[00:09:39] VOIX: organiser les fichiers
-[00:09:41] VOIX: pour les transmettre à l'équipe de contrôle.
-[00:09:43] VOIX: Je ne parle pas d'exploiter
-[00:09:45] VOIX: qui est lire.
-[00:09:46] VOIX: Je parle juste de sortir les documents,
-[00:09:49] VOIX: les organiser,
-[00:09:51] VOIX: les renommer, etc.
-[00:09:52] VOIX: Parce que soit le service informatique,
-[00:09:54] VOIX: il est futé,
-[00:09:55] VOIX: et le DPI permet d'extraire
-[00:09:57] VOIX: de manière automatique,
-[00:09:58] VOIX: comme nous,
-[00:09:59] VOIX: on a fait à Bayonne,
-[00:10:01] VOIX: mais il fallait que je tape du poing
-[00:10:03] VOIX: pour qu'il le fasse.
-[00:10:04] VOIX: Soit l'établissement,
-[00:10:05] VOIX: l'informatique,
-[00:10:06] VOIX: ils disent à nous,
-[00:10:07] VOIX: nous, on n'en sait rien.
-[00:10:08] VOIX: Et donc, typiquement,
-[00:10:09] VOIX: cet établissement,
-[00:10:10] VOIX: ils ont mis 5 personnes
-[00:10:11] VOIX: pendant un mois,
-[00:10:12] VOIX: ce qui je trouve extrêmement aberrant.
-[00:10:15] VOIX: Et là,
-[00:10:16] VOIX: c'est pour ça que l'avenir,
-[00:10:18] VOIX: tu vois,
-[00:10:18] VOIX: en fait,
-[00:10:19] VOIX: ici, c'est Eva,
-[00:10:21] VOIX: et demain,
-[00:10:22] VOIX: c'est avenir.
-[00:10:24] VOIX: C'est pour ça que j'ai séparé.
-[00:10:25] VOIX: Et tu vas voir
-[00:10:26] VOIX: que toutes ces fonctionnalités,
-[00:10:27] VOIX: je te dirai après
-[00:10:28] VOIX: comment on peut les travailler après.
-[00:10:32] VOIX: Parce que si tu as un agent
-[00:10:34] VOIX: d'IAVision
-[00:10:35] VOIX: qui connaît les DPI
-[00:10:36] VOIX: et que tu lui dis
-[00:10:38] VOIX: OK, je branche,
-[00:10:40] VOIX: tu sais,
-[00:10:40] VOIX: on est toujours dans la logique
-[00:10:41] VOIX: est-ce qu'on arrivera un jour
-[00:10:42] VOIX: à brancher
-[00:10:45] VOIX: l'IAVision
-[00:10:45] VOIX: pour extraire
-[00:10:47] VOIX: des documents,
-[00:10:48] VOIX: des informations,
-[00:10:49] VOIX: etc.
-[00:10:50] VOIX: Ben ça,
-[00:10:51] VOIX: tu vois,
-[00:10:51] VOIX: j'ai mis un point
-[00:10:52] VOIX: d'interrogation
-[00:10:52] VOIX: parce que c'est pas
-[00:10:54] VOIX: d'actualité,
-[00:10:55] VOIX: mais en tout cas,
-[00:10:56] VOIX: il faut savoir
-[00:10:57] VOIX: que ça peut être un point
-[00:11:00] VOIX: pour les établissements
-[00:11:03] VOIX: parce que exploiter les documents
-[00:11:06] VOIX: c'est une chose
-[00:11:07] VOIX: mais les extraire
-[00:11:08] VOIX: c'est important
-[00:11:09] VOIX: dans un contrôle
-[00:11:10] VOIX: parce qu'il faut
-[00:11:12] VOIX: les transmettre
-[00:11:13] VOIX: via Blue File
-[00:11:14] VOIX: à l'équipe du contrôle.
-[00:11:15] VOIX: Tu sais pourquoi
-[00:11:18] VOIX: Christophe,
-[00:11:19] VOIX: ah non,
-[00:11:20] VOIX: t'étais pas là,
-[00:11:21] VOIX: quand Jordan lui a créé
-[00:11:23] VOIX: des répertoires vides
-[00:11:24] VOIX: au GC1,
-[00:11:25] VOIX: numéro de dossier,
-[00:11:26] VOIX: machin,
-[00:11:27] VOIX: ben justement,
-[00:11:28] VOIX: il aide
-[00:11:28] VOIX: à classer,
-[00:11:30] VOIX: à renommer
-[00:11:31] VOIX: ces fichiers
-[00:11:32] VOIX: pour qu'il les envoie
-[00:11:33] VOIX: au médecin contrôleur.
-[00:11:35] VOIX: Et l'exploitation,
-[00:11:36] VOIX: c'est ce que tu vas me montrer,
-[00:11:37] VOIX: c'est l'objet
-[00:11:38] VOIX: de ce qu'on va voir.
-[00:11:40] VOIX: Projet IA Dominique
-[00:11:41] VOIX: en cours,
-[00:11:42] VOIX: la synthèse
-[00:11:43] VOIX: du dossier patient,
-[00:11:44] VOIX: le codage
-[00:11:45] VOIX: PMSI groupage,
-[00:11:47] VOIX: argumentaire
-[00:11:47] VOIX: de chaque codage,
-[00:11:48] VOIX: etc.
-[00:11:49] VOIX: Et donc,
-[00:11:49] VOIX: la synthèse,
-[00:11:50] VOIX: c'est ta question.
-[00:11:51] VOIX: Comment je différencie
-[00:11:53] VOIX: l'histoire ancienne
-[00:11:54] VOIX: de la réalité
-[00:11:55] VOIX: de ce qui est pris en charge
-[00:11:56] VOIX: pendant le séjour
-[00:11:57] VOIX: qui m'intéresse ?
-[00:11:59] VOIX: Quel est le motif
-[00:12:00] VOIX: de l'hospitalisation ?
-[00:12:02] VOIX: Quelles sont
-[00:12:03] VOIX: les pathologies
-[00:12:04] VOIX: prises en charge ?
-[00:12:05] VOIX: Et donc,
-[00:12:06] VOIX: c'est hyper important
-[00:12:08] VOIX: de dire que,
-[00:12:09] VOIX: avant de coder
-[00:12:10] VOIX: le PMSI,
-[00:12:11] VOIX: c'est ça le cerveau
-[00:12:13] VOIX: dont tu dis,
-[00:12:14] VOIX: ça manque.
-[00:12:15] VOIX: C'est-à-dire,
-[00:12:15] VOIX: avant de dire,
-[00:12:16] VOIX: je code,
-[00:12:17] VOIX: tous ceux qui commencent
-[00:12:18] VOIX: par dire,
-[00:12:18] VOIX: ah, c'est quoi
-[00:12:19] VOIX: comme code,
-[00:12:19] VOIX: c'est quoi comme code,
-[00:12:20] VOIX: c'est des gens
-[00:12:21] VOIX: qui ne vont pas forcément
-[00:12:22] VOIX: bien raisonner
-[00:12:23] VOIX: sur le dossier.
-[00:12:24] VOIX: Ceux qui raisonnent,
-[00:12:26] VOIX: c'est ceux qui vont
-[00:12:27] VOIX: comprendre le dossier
-[00:12:28] VOIX: et l'organiser,
-[00:12:31] VOIX: de dire,
-[00:12:31] VOIX: ok,
-[00:12:31] VOIX: ça c'est l'histoire ancienne,
-[00:12:33] VOIX: ça a commencé
-[00:12:34] VOIX: telle année,
-[00:12:35] VOIX: ça arrive aujourd'hui,
-[00:12:37] VOIX: voici le contexte,
-[00:12:38] VOIX: l'histoire du patient,
-[00:12:39] VOIX: mais aujourd'hui,
-[00:12:40] VOIX: il vient en hospitalisation,
-[00:12:42] VOIX: pourquoi ?
-[00:12:43] VOIX: Et qu'est-ce qui a été
-[00:12:44] VOIX: pris en charge ?
-[00:12:46] VOIX: Ça,
-[00:12:47] VOIX: mine de rien,
-[00:12:47] VOIX: c'est la clé du codage.
-[00:12:49] VOIX: Si tu n'as pas
-[00:12:50] VOIX: cette logique de synthèse,
-[00:12:52] VOIX: de compréhension,
-[00:12:53] VOIX: en fait,
-[00:12:54] VOIX: c'est compréhension,
-[00:12:55] VOIX: je vais l'appeler comme ça,
-[00:12:57] VOIX: compréhension,
-[00:13:00] VOIX: compréhension,
-[00:13:03] VOIX: organisation,
-[00:13:05] VOIX: aïe,
-[00:13:06] VOIX: aïe,
-[00:13:06] VOIX: aïe,
-[00:13:06] VOIX: aïe,
-[00:13:07] VOIX: aïe,
-[00:13:09] VOIX: organization,
-[00:13:14] VOIX: et synthèse du DPI,
-[00:13:16] VOIX: pour moi,
-[00:13:17] VOIX: franchement,
-[00:13:18] VOIX: on va le voir,
-[00:13:19] VOIX: parce que c'est l'exemple
-[00:13:21] VOIX: que tu demandes,
-[00:13:22] VOIX: mais on va travailler
-[00:13:23] VOIX: sur ça.
-[00:13:24] VOIX: Pour que le codage
-[00:13:25] VOIX: PMC,
-[00:13:26] VOIX: après,
-[00:13:26] VOIX: c'est facile,
-[00:13:27] VOIX: le groupage,
-[00:13:28] VOIX: c'est facile,
-[00:13:29] VOIX: et si on comprend bien,
-[00:13:32] VOIX: on organise bien,
-[00:13:32] VOIX: on synthétise bien,
-[00:13:33] VOIX: l'argumentaire devient logique.
-[00:13:36] VOIX: D'accord ?
-[00:13:38] VOIX: Donc,
-[00:13:39] VOIX: on va dire,
-[00:13:40] VOIX: c'est le sujet du moment.
-[00:13:44] VOIX: Et une fois que j'ai fait ça,
-[00:13:47] VOIX: l'idéal,
-[00:13:48] VOIX: évidemment,
-[00:13:49] VOIX: c'est qu'on crée
-[00:13:50] VOIX: une préparation IA,
-[00:13:52] VOIX: c'est-à-dire,
-[00:13:53] VOIX: on dit préparation du codage
-[00:13:54] VOIX: des dossiers par IA,
-[00:13:55] VOIX: aujourd'hui,
-[00:13:55] VOIX: on n'a pas la possibilité
-[00:13:56] VOIX: de faire ça,
-[00:13:58] VOIX: mais on peut ajouter une phase
-[00:14:00] VOIX: qui s'appelle
-[00:14:00] VOIX: créer une phase préparée IA,
-[00:14:02] VOIX: c'est un agent IA
-[00:14:03] VOIX: vision ou autre,
-[00:14:04] VOIX: je n'en sais rien,
-[00:14:05] VOIX: pour saisir directement
-[00:14:06] VOIX: les codes IA
-[00:14:07] VOIX: dans la phase préparée IA.
-[00:14:09] VOIX: Ce codage IA
-[00:14:10] VOIX: peut même concerner
-[00:14:11] VOIX: les dossiers potentiels,
-[00:14:12] VOIX: les dossiers candidats,
-[00:14:13] VOIX: puis les dossiers du RSSDF,
-[00:14:14] VOIX: mais bon,
-[00:14:14] VOIX: ça,
-[00:14:14] VOIX: c'est du détail.
-[00:14:17] VOIX: Et ensuite,
-[00:14:18] VOIX: la phase qui existe aujourd'hui,
-[00:14:20] VOIX: c'est préparer des dossiers
-[00:14:21] VOIX: par l'établissement.
-[00:14:23] VOIX: C'est ça,
-[00:14:24] VOIX: comme j'ai dit ici,
-[00:14:25] VOIX: le RSSDF,
-[00:14:26] VOIX: il déclenche la phase préparée.
-[00:14:28] VOIX: Nous,
-[00:14:28] VOIX: ces étapes-là,
-[00:14:29] VOIX: aujourd'hui,
-[00:14:29] VOIX: elles n'existent pas.
-[00:14:30] VOIX: On bascule vers ça.
-[00:14:32] VOIX: Et la phase préparée,
-[00:14:33] VOIX: aujourd'hui,
-[00:14:34] VOIX: c'est l'humain.
-[00:14:34] VOIX: C'est l'humain
-[00:14:35] VOIX: qui va ouvrir
-[00:14:36] VOIX: d'une manière ou d'une autre,
-[00:14:37] VOIX: soit directement
-[00:14:38] VOIX: son logiciel des pays,
-[00:14:40] VOIX: soit ses fiches PDF,
-[00:14:43] VOIX: voilà,
-[00:14:43] VOIX: et qui va dire,
-[00:14:45] VOIX: je décide de tel codage
-[00:14:46] VOIX: et qui va le saisir
-[00:14:48] VOIX: dans la phase préparée
-[00:14:49] VOIX: de Eva.
-[00:14:49] VOIX: mais on peut imaginer
-[00:14:51] VOIX: une étape POC,
-[00:14:53] VOIX: parce que là aussi,
-[00:14:54] VOIX: c'est un POC,
-[00:14:55] VOIX: exploitation POC,
-[00:14:57] VOIX: ça aussi,
-[00:14:58] VOIX: il y aura qui vont le lire,
-[00:15:00] VOIX: mais si on a l'étape POC
-[00:15:01] VOIX: pour valider la phase préparée,
-[00:15:03] VOIX: il y a,
-[00:15:03] VOIX: ça va rejoindre ça.
-[00:15:09] VOIX: puisque l'humain,
-[00:15:10] VOIX: et d'ailleurs,
-[00:15:11] VOIX: il y a,
-[00:15:12] VOIX: tu vois,
-[00:15:12] VOIX: Frédéric,
-[00:15:13] VOIX: etc.,
-[00:15:14] VOIX: c'est des gens,
-[00:15:15] VOIX: soit on l'intègre
-[00:15:16] VOIX: d'emblée
-[00:15:17] VOIX: dans Eva,
-[00:15:19] VOIX: soit on le fait
-[00:15:20] VOIX: à part,
-[00:15:21] VOIX: et on fait valider
-[00:15:23] VOIX: par les médecins,
-[00:15:24] VOIX: par les établissements,
-[00:15:25] VOIX: ce que l'IA propose ici,
-[00:15:27] VOIX: réfléchir,
-[00:15:28] VOIX: ça c'est le premier truc,
-[00:15:30] VOIX: valider déjà
-[00:15:30] VOIX: ou dévalider
-[00:15:31] VOIX: les propositions,
-[00:15:33] VOIX: les synthèses,
-[00:15:33] VOIX: etc.,
-[00:15:34] VOIX: et puis réfléchir
-[00:15:35] VOIX: à comment articuler
-[00:15:36] VOIX: toutes ces nouvelles fonctionnalités
-[00:15:38] VOIX: et les intégrer
-[00:15:39] VOIX: dans le logiciel Eva,
-[00:15:41] VOIX: et les exploiter
-[00:15:43] VOIX: pour EvaNov
-[00:15:44] VOIX: et pour tout ce qui va suivre,
-[00:15:46] VOIX: parce que c'est,
-[00:15:47] VOIX: il faut les raisonner
-[00:15:50] VOIX: comme des petites fonctionnalités
-[00:15:51] VOIX: qui puissent être intégrées
-[00:15:56] VOIX: à d'autres logiciels
-[00:15:58] VOIX: ou à d'autres projets.
-[00:16:04] VOIX: Oui,
-[00:16:04] VOIX: je vois l'idée en fait,
-[00:16:06] VOIX: Amina,
-[00:16:07] VOIX: de ce côté-là,
-[00:16:07] VOIX: j'ai bien vu l'idée.
-[00:16:11] VOIX: OK.
-[00:16:11] VOIX: Est-ce que jusque-là,
-[00:16:12] VOIX: déjà,
-[00:16:13] VOIX: sur tes étapes,
-[00:16:14] VOIX: c'est bon.
-[00:16:15] VOIX: C'est bon.
-[00:16:16] VOIX: Pour moi,
-[00:16:17] VOIX: j'ai commencé.
-[00:16:17] VOIX: Oui,
-[00:16:18] VOIX: vas-y,
-[00:16:18] VOIX: vas-y,
-[00:16:18] VOIX: continue.
-[00:16:19] VOIX: Non,
-[00:16:19] VOIX: non,
-[00:16:19] VOIX: vas-y.
-[00:16:20] VOIX: L'étape pendant,
-[00:16:21] VOIX: je l'ai commencé,
-[00:16:24] VOIX: en fait,
-[00:16:24] VOIX: c'est les documents
-[00:16:26] VOIX: justificatifs du DPI,
-[00:16:28] VOIX: avant même de dire,
-[00:16:29] VOIX: bonjour Guy,
-[00:16:31] VOIX: il vient de rentrer,
-[00:16:32] VOIX: avant même de dire,
-[00:16:34] VOIX: je les lis,
-[00:16:35] VOIX: je les comprends,
-[00:16:36] VOIX: je les organise,
-[00:16:37] VOIX: etc.,
-[00:16:38] VOIX: il y a quand même
-[00:16:38] VOIX: une petite étape
-[00:16:40] VOIX: de dire,
-[00:16:41] VOIX: tu vois,
-[00:16:42] VOIX: directement dans ton tracker
-[00:16:43] VOIX: que tu as dit,
-[00:16:44] VOIX: tu vas avoir un truc
-[00:16:45] VOIX: qui s'appelle observation,
-[00:16:47] VOIX: un autre qui s'appelle
-[00:16:48] VOIX: transmission,
-[00:16:49] VOIX: un autre qui s'appelle
-[00:16:50] VOIX: croc en train d'opératoire,
-[00:16:52] VOIX: un autre qui s'appelle
-[00:16:52] VOIX: CRH ou pas.
-[00:16:54] VOIX: Et donc,
-[00:16:55] VOIX: en fait,
-[00:16:55] VOIX: il faut les classer,
-[00:16:57] VOIX: peut-être les renommer,
-[00:16:59] VOIX: avoir,
-[00:16:59] VOIX: ça c'est un point d'interrogation,
-[00:17:01] VOIX: les organiser,
-[00:17:02] VOIX: etc.,
-[00:17:02] VOIX: pour les transmettre
-[00:17:04] VOIX: via Blufa
-[00:17:05] VOIX: et la liquide du contrôle,
-[00:17:07] VOIX: plus,
-[00:17:11] VOIX: identifier,
-[00:17:11] VOIX: peut-être,
-[00:17:12] VOIX: les documents manquants,
-[00:17:14] VOIX: parce que souvent,
-[00:17:16] VOIX: ça,
-[00:17:17] VOIX: on le fait à la main.
-[00:17:18] VOIX: aujourd'hui,
-[00:17:19] VOIX: Eva Nada,
-[00:17:21] VOIX: à venir,
-[00:17:22] VOIX: fonctionnalité,
-[00:17:23] VOIX: à réfléchir.
-[00:17:24] VOIX: Question,
-[00:17:24] VOIX: en fait,
-[00:17:25] VOIX: parce que ça fait plusieurs fois
-[00:17:26] VOIX: que je le vois,
-[00:17:27] VOIX: Blufaile,
-[00:17:28] VOIX: c'est quoi,
-[00:17:29] VOIX: en fait,
-[00:17:29] VOIX: exactement ?
-[00:17:30] VOIX: Blufaile,
-[00:17:31] VOIX: c'est juste un...
-[00:17:32] VOIX: C'est un système
-[00:17:34] VOIX: de transmission.
-[00:17:35] VOIX: C'est le dépôt,
-[00:17:35] VOIX: c'est le...
-[00:17:36] VOIX: c'est pas...
-[00:17:37] VOIX: c'est comme un S
-[00:17:37] VOIX: ou comme Dropbox.
-[00:17:39] VOIX: D'accord.
-[00:17:40] VOIX: Sécurisé de l'assurance maladie.
-[00:17:42] VOIX: C'est-à-dire,
-[00:17:43] VOIX: c'est eux qui gèrent ça,
-[00:17:44] VOIX: ils ouvrent,
-[00:17:45] VOIX: ils ouvrent à l'établissement,
-[00:17:46] VOIX: ou médecins d'île
-[00:17:47] VOIX: de l'établissement,
-[00:17:47] VOIX: un compte.
-[00:17:48] VOIX: C'est ce truc-là,
-[00:17:50] VOIX: je vais l'ouvrir,
-[00:17:50] VOIX: ça va mieux te parler.
-[00:17:53] VOIX: Reception,
-[00:17:55] VOIX: tu vois,
-[00:17:55] VOIX: c'est l'assurance maladie,
-[00:17:56] VOIX: d'un mail,
-[00:17:57] VOIX: Blufaile,
-[00:17:57] VOIX: création du compte,
-[00:17:58] VOIX: Blufaile,
-[00:17:58] VOIX: etc.
-[00:17:59] VOIX: Et donc,
-[00:18:00] VOIX: c'est eux qui t'ouvrent,
-[00:18:01] VOIX: tu valides par ton mail,
-[00:18:03] VOIX: etc.
-[00:18:03] VOIX: Et derrière,
-[00:18:04] VOIX: c'est l'outil de communication
-[00:18:06] VOIX: sécurisé,
-[00:18:07] VOIX: parce que tu vas déposer
-[00:18:08] VOIX: des dossiers patients,
-[00:18:10] VOIX: nominatifs,
-[00:18:11] VOIX: tu vas échanger
-[00:18:12] VOIX: les fichiers du contrôle,
-[00:18:13] VOIX: etc.
-[00:18:14] VOIX: C'est comme nous,
-[00:18:15] VOIX: on fait avec S,
-[00:18:16] VOIX: avec notre répertoire,
-[00:18:20] VOIX: tu vois,
-[00:18:20] VOIX: ShareFile,
-[00:18:21] VOIX: c'est un peu le ShareFile
-[00:18:23] VOIX: de l'assurance maladie.
-[00:18:26] VOIX: Ok.
-[00:18:29] VOIX: Après,
-[00:18:30] VOIX: bien sûr,
-[00:18:30] VOIX: il y a d'autres étapes
-[00:18:31] VOIX: que je travaillerai,
-[00:18:32] VOIX: mais tu vois,
-[00:18:33] VOIX: en fait,
-[00:18:37] VOIX: toi,
-[00:18:38] VOIX: ce que tu m'as demandé,
-[00:18:39] VOIX: c'est ça,
-[00:18:42] VOIX: on va se concentrer sur ça,
-[00:18:44] VOIX: tu vois.
-[00:18:44] VOIX: parce qu'il y a plusieurs choses
-[00:18:46] VOIX: et je vais essayer
-[00:18:47] VOIX: de les lister petit à petit,
-[00:18:49] VOIX: les revoir ensemble,
-[00:18:51] VOIX: parce que c'est important
-[00:18:52] VOIX: qu'on soit tous
-[00:18:52] VOIX: sur la même longueur d'onde,
-[00:18:54] VOIX: parce que peut-être
-[00:18:55] VOIX: que moi aussi,
-[00:18:56] VOIX: j'ai oublié des choses,
-[00:18:57] VOIX: etc.
-[00:18:57] VOIX: pour que l'on voit
-[00:19:00] VOIX: de quoi on parle
-[00:19:01] VOIX: et quand je t'ai dit
-[00:19:02] VOIX: la lecture
-[00:19:02] VOIX: et les argumentaires,
-[00:19:04] VOIX: les argumentaires
-[00:19:05] VOIX: de chaque codage,
-[00:19:06] VOIX: tu vas voir
-[00:19:07] VOIX: que ça va dépendre
-[00:19:08] VOIX: des étapes
-[00:19:09] VOIX: parce que ton argumentaire
-[00:19:10] VOIX: ici
-[00:19:12] VOIX: de la phase avant,
-[00:19:14] VOIX: il va être différent
-[00:19:16] VOIX: de ton argumentaire
-[00:19:17] VOIX: de la phase après.
-[00:19:19] VOIX: Ben oui,
-[00:19:20] VOIX: ça,
-[00:19:21] VOIX: là,
-[00:19:21] VOIX: il y a une logique.
-[00:19:23] VOIX: Là,
-[00:19:24] VOIX: pour moi,
-[00:19:24] VOIX: en fait,
-[00:19:24] VOIX: il y a une logique.
-[00:19:26] VOIX: OK.
-[00:19:27] VOIX: Bien sûr,
-[00:19:28] VOIX: parce que ici,
-[00:19:29] VOIX: l'établissement,
-[00:19:29] VOIX: il est avec lui-même,
-[00:19:31] VOIX: si tu veux.
-[00:19:32] VOIX: Il prépare,
-[00:19:33] VOIX: mais il n'a aucune idée
-[00:19:34] VOIX: de comment
-[00:19:35] VOIX: le médecin contrôleur
-[00:19:36] VOIX: va réagir,
-[00:19:37] VOIX: en fait.
-[00:19:46] VOIX: quelqu'un qui va lui valider
-[00:19:47] VOIX: ou pas son codage.
-[00:19:50] VOIX: Mais après,
-[00:19:51] VOIX: le retour du médecin contrôleur
-[00:19:53] VOIX: ou après le retour
-[00:19:54] VOIX: de l'UCR,
-[00:19:55] VOIX: etc.,
-[00:19:55] VOIX: tu argumentes
-[00:19:57] VOIX: comme un avocat.
-[00:19:58] VOIX: Tu contre-argumentes,
-[00:19:59] VOIX: en fait.
-[00:20:00] VOIX: C'est-à-dire,
-[00:20:01] VOIX: tu utilises
-[00:20:02] VOIX: les refus
-[00:20:03] VOIX: ou les arguments
-[00:20:04] VOIX: du médecin contrôleur
-[00:20:05] VOIX: pour les contre-balancer.
-[00:20:08] VOIX: Et à l'étape avant,
-[00:20:11] VOIX: tu ne les as pas,
-[00:20:12] VOIX: les arguments
-[00:20:12] VOIX: du médecin contrôleur.
-[00:20:13] VOIX: Tu peux les imaginer.
-[00:20:15] VOIX: Et c'est pour ça
-[00:20:15] VOIX: que le travail
-[00:20:17] VOIX: sur l'après,
-[00:20:18] VOIX: ça permet aussi
-[00:20:18] VOIX: d'imaginer
-[00:20:19] VOIX: les contre-argumentaires
-[00:20:21] VOIX: que va utiliser
-[00:20:22] VOIX: le médecin contrôleur
-[00:20:24] VOIX: et l'UCR.
-[00:20:30] VOIX: Ça te va un peu
-[00:20:32] VOIX: la manière
-[00:20:32] VOIX: pour commencer à...
-[00:20:34] VOIX: Mais le document
-[00:20:35] VOIX: que tu as fait,
-[00:20:36] VOIX: en fait,
-[00:20:37] VOIX: quand tu penses
-[00:20:38] VOIX: l'avoir fini,
-[00:20:38] VOIX: c'est bien
-[00:20:39] VOIX: si tu me le transmets.
-[00:20:41] VOIX: Bien sûr,
-[00:20:42] VOIX: c'est fait pour ça.
-[00:20:42] VOIX: Voilà,
-[00:20:43] VOIX: parce qu'en fait,
-[00:20:43] VOIX: moi,
-[00:20:43] VOIX: ça me guide
-[00:20:45] VOIX: et puis avec Ralid,
-[00:20:46] VOIX: comme ça,
-[00:20:46] VOIX: on pourra travailler dessus
-[00:20:47] VOIX: parce que si
-[00:20:47] VOIX: ce n'est pas lui non plus,
-[00:20:49] VOIX: voilà,
-[00:20:50] VOIX: ça me guide.
-[00:20:50] VOIX: Bon,
-[00:20:51] VOIX: ce qu'il y a de bien,
-[00:20:52] VOIX: en fait,
-[00:20:52] VOIX: c'est que la mécanique
-[00:20:53] VOIX: en fait que tu m'as décrite,
-[00:20:55] VOIX: alors en dehors,
-[00:20:55] VOIX: en fait,
-[00:20:56] VOIX: de tout acronyme,
-[00:20:57] VOIX: de choses comme ça,
-[00:20:57] VOIX: en fait,
-[00:20:58] VOIX: je l'avais...
-[00:20:58] VOIX: Voilà,
-[00:21:00] VOIX: je l'avais bien saisi.
-[00:21:02] VOIX: Je l'ai bien saisi
-[00:21:04] VOIX: et là,
-[00:21:04] VOIX: maintenant,
-[00:21:05] VOIX: en fait,
-[00:21:05] VOIX: est-ce que j'ai le droit
-[00:21:05] VOIX: de te poser des questions ?
-[00:21:07] VOIX: Bien sûr,
-[00:21:08] VOIX: c'est là,
-[00:21:09] VOIX: parce que,
-[00:21:09] VOIX: évidemment,
-[00:21:10] VOIX: mais c'était aussi
-[00:21:11] VOIX: pour te montrer
-[00:21:12] VOIX: que moi,
-[00:21:12] VOIX: j'ai compris aussi
-[00:21:13] VOIX: sur quoi tu travailles,
-[00:21:15] VOIX: comment tu raisons,
-[00:21:16] VOIX: etc.
-[00:21:17] VOIX: C'est pour ça
-[00:21:18] VOIX: que l'autre jour,
-[00:21:18] VOIX: j'ai voulu te montrer
-[00:21:19] VOIX: un peu la mécanique.
-[00:21:22] VOIX: Je ne cherchais pas,
-[00:21:23] VOIX: en fait,
-[00:21:23] VOIX: que le code,
-[00:21:24] VOIX: en fait,
-[00:21:24] VOIX: parce que je n'étais pas là,
-[00:21:26] VOIX: c'était la mécanique.
-[00:21:27] VOIX: Voilà.
-[00:21:28] VOIX: Une fois,
-[00:21:29] VOIX: en fait,
-[00:21:29] VOIX: que la mécanique,
-[00:21:30] VOIX: en fait,
-[00:21:30] VOIX: elle est bonne.
-[00:21:30] VOIX: Après,
-[00:21:31] VOIX: effectivement,
-[00:21:31] VOIX: en fait,
-[00:21:31] VOIX: il y a tout le travail
-[00:21:32] VOIX: que je suis en train de faire.
-[00:21:34] VOIX: Je ne l'ai pas encore fini,
-[00:21:35] VOIX: donc on verra ça par la suite,
-[00:21:37] VOIX: mais j'ai des questions
-[00:21:38] VOIX: justement pour le finir
-[00:21:39] VOIX: parce qu'en fait,
-[00:21:40] VOIX: j'ai besoin,
-[00:21:41] VOIX: en fait,
-[00:21:41] VOIX: d'un parti métier.
-[00:21:42] VOIX: C'est ce que je disais
-[00:21:43] VOIX: tout à l'heure.
-[00:21:43] VOIX: Alors,
-[00:21:44] VOIX: j'ai un raisonnement
-[00:21:45] VOIX: qui est basé uniquement,
-[00:21:47] VOIX: en fait,
-[00:21:47] VOIX: sur la documentation,
-[00:21:48] VOIX: en fait,
-[00:21:48] VOIX: que j'ai.
-[00:22:00] VOIX: Oui.
-[00:22:01] VOIX: Et moi,
-[00:22:01] VOIX: par exemple,
-[00:22:02] VOIX: si on prend un exemple
-[00:22:03] VOIX: de dossier,
-[00:22:06] VOIX: on pourra le lire ensemble.
-[00:22:07] VOIX: Je ne sais pas comment
-[00:22:08] VOIX: tu veux procéder
-[00:22:09] VOIX: et tu montrerai les règles.
-[00:22:11] VOIX: Alors,
-[00:22:12] VOIX: voyons,
-[00:22:12] VOIX: comment on peut faire ?
-[00:22:14] VOIX: Toi,
-[00:22:14] VOIX: tu en as un sous la main
-[00:22:16] VOIX: ou alors
-[00:22:16] VOIX: je me fournis un,
-[00:22:17] VOIX: moi ?
-[00:22:19] VOIX: Fournis un.
-[00:22:19] VOIX: Fournis un
-[00:22:21] VOIX: que tu trouves
-[00:22:22] VOIX: au hasard,
-[00:22:23] VOIX: en fait,
-[00:22:24] VOIX: parce qu'on n'est pas obligé
-[00:22:24] VOIX: de travailler
-[00:22:25] VOIX: sur des dossiers complexes.
-[00:22:27] VOIX: Ce qui soit complexe
-[00:22:28] VOIX: ou simple,
-[00:22:28] VOIX: l'essentiel,
-[00:22:29] VOIX: c'est que
-[00:22:31] VOIX: je puisse te montrer
-[00:22:32] VOIX: avec le guide méthodologique
-[00:22:33] VOIX: parce que je ne sais pas
-[00:22:34] VOIX: si tu l'as exploité.
-[00:22:35] VOIX: Oui,
-[00:22:36] VOIX: je l'ai exploité.
-[00:22:37] VOIX: En fait,
-[00:22:37] VOIX: je l'ai intégré
-[00:22:38] VOIX: dans le système
-[00:22:39] VOIX: mais ce qui me manque,
-[00:22:40] VOIX: en fait,
-[00:22:40] VOIX: c'est le guide méthodologique,
-[00:22:43] VOIX: il est très bien,
-[00:22:44] VOIX: certainement.
-[00:22:44] VOIX: Sauf que le problème,
-[00:22:45] VOIX: c'est que si je n'ai pas
-[00:22:46] VOIX: en fait ta lecture à toi,
-[00:22:48] VOIX: la lecture d'un professionnel,
-[00:22:50] VOIX: d'un expert,
-[00:22:51] VOIX: ça ne marche pas,
-[00:22:52] VOIX: en fait,
-[00:22:52] VOIX: en réalité.
-[00:22:53] VOIX: Je me suis rendu compte
-[00:22:53] VOIX: d'une chose,
-[00:22:54] VOIX: c'est que si je suis
-[00:22:55] VOIX: pile-poil l'air,
-[00:22:56] VOIX: ça ne marche pas.
-[00:22:57] VOIX: Alors,
-[00:22:57] VOIX: attends,
-[00:22:57] VOIX: j'essaye de te partager
-[00:22:58] VOIX: mon écran.
-[00:22:58] VOIX: Attends,
-[00:22:59] VOIX: cherche un dossier.
-[00:23:01] VOIX: Attends,
-[00:23:01] VOIX: attends,
-[00:23:02] VOIX: est-ce que je ne peux pas
-[00:23:03] VOIX: en fait partager mon écran ?
-[00:23:05] VOIX: C'est en haut à droite.
-[00:23:07] VOIX: Non,
-[00:23:07] VOIX: mais je l'ai trouvé
-[00:23:08] VOIX: mais quand je clique dessus,
-[00:23:10] VOIX: en fait,
-[00:23:10] VOIX: je ne peux pas.
-[00:23:11] VOIX: Est-ce que toi,
-[00:23:12] VOIX: tu n'as pas un truc
-[00:23:14] VOIX: en fait qui se...
-[00:23:16] VOIX: Ou alors,
-[00:23:16] VOIX: c'est mon ordinateur
-[00:23:17] VOIX: en fait qui a un problème.
-[00:23:18] VOIX: Ah,
-[00:23:19] VOIX: putain,
-[00:23:19] VOIX: ce n'est pas vrai.
-[00:23:21] VOIX: Et puis,
-[00:23:22] VOIX: en plus,
-[00:23:22] VOIX: je veux le faire
-[00:23:22] VOIX: à partir de mon truc.
-[00:23:24] VOIX: Je vais fermer
-[00:23:25] VOIX: en fait la fenêtre,
-[00:23:26] VOIX: enfin,
-[00:23:27] VOIX: pas celle d'en bas,
-[00:23:28] VOIX: et je reviens de suite.
-[00:23:30] VOIX: Ok,
-[00:23:31] VOIX: il n'y a pas de sens.
-[00:23:31] VOIX: Je reviens de suite.
-[00:23:33] VOIX: Allez,
-[00:23:34] VOIX: de suite.
-[00:23:35] VOIX: Ah oui,
-[00:23:35] VOIX: mais il a planté mon truc.
-[00:23:38] VOIX: Ah ben,
-[00:23:39] VOIX: c'est ballot
-[00:23:39] VOIX: pour un informaticien.
-[00:23:45] VOIX: Ah.
-[00:24:22] VOIX: me revoilà.
-[00:24:23] VOIX: Alors,
-[00:24:23] VOIX: alors voyons si ça marche.
-[00:24:24] VOIX: Oui,
-[00:24:24] VOIX: ça marche.
-[00:24:25] VOIX: Allez,
-[00:24:26] VOIX: hop,
-[00:24:26] VOIX: je te partage tout mon écran.
-[00:24:30] VOIX: Voilà.
-[00:24:31] VOIX: Voilà,
-[00:24:31] VOIX: normalement.
-[00:24:32] VOIX: Nickel.
-[00:24:33] VOIX: Voilà,
-[00:24:33] VOIX: là,
-[00:24:33] VOIX: maintenant,
-[00:24:33] VOIX: tu me vois.
-[00:24:34] VOIX: Et là,
-[00:24:34] VOIX: on va prendre un dossier.
-[00:24:35] VOIX: Alors,
-[00:24:36] VOIX: je prends un dossier au pif.
-[00:24:37] VOIX: D'accord ?
-[00:24:38] VOIX: D'ailleurs,
-[00:24:39] VOIX: en passant,
-[00:24:40] VOIX: en fait,
-[00:24:40] VOIX: je te montre.
-[00:24:41] VOIX: Là,
-[00:24:41] VOIX: moi,
-[00:24:41] VOIX: ça,
-[00:24:41] VOIX: c'est le dossier complet.
-[00:24:42] VOIX: Donc,
-[00:24:43] VOIX: en fait,
-[00:24:43] VOIX: j'ai le CRH
-[00:24:45] VOIX: et le tracker.
-[00:24:46] VOIX: D'accord ?
-[00:24:47] VOIX: Et moi,
-[00:24:47] VOIX: ce que je fais...
-[00:24:48] VOIX: et parfois,
-[00:24:49] VOIX: pardon,
-[00:24:50] VOIX: parfois,
-[00:24:50] VOIX: dans le tracker,
-[00:24:52] VOIX: tu peux avoir le CRH intégré
-[00:24:54] VOIX: dans le truc complet.
-[00:24:56] VOIX: J'ai remarqué...
-[00:24:57] VOIX: Là,
-[00:24:58] VOIX: ça s'appelle CRH à part
-[00:24:59] VOIX: parce qu'on l'a ajouté
-[00:25:00] VOIX: a posteriori.
-[00:25:01] VOIX: Il n'a pas été extrait
-[00:25:04] VOIX: initialement.
-[00:25:05] VOIX: Mais ça,
-[00:25:05] VOIX: à la rigueur,
-[00:25:06] VOIX: c'est du détail
-[00:25:06] VOIX: parce qu'on n'aura pas
-[00:25:07] VOIX: cette méthode partout
-[00:25:08] VOIX: dans les autres établissements.
-[00:25:10] VOIX: Ce n'est pas du tout ça
-[00:25:11] VOIX: l'essentiel.
-[00:25:12] VOIX: D'accord.
-[00:25:13] VOIX: Bon,
-[00:25:13] VOIX: je fais,
-[00:25:13] VOIX: en fait,
-[00:25:14] VOIX: une petite passe
-[00:25:14] VOIX: pour que tu visualises.
-[00:25:16] VOIX: Ici,
-[00:25:17] VOIX: en fait,
-[00:25:17] VOIX: tu vois,
-[00:25:18] VOIX: le dossier,
-[00:25:19] VOIX: c'est ça,
-[00:25:20] VOIX: c'est d'autres documents
-[00:25:21] VOIX: et l'IA,
-[00:25:22] VOIX: en fait,
-[00:25:23] VOIX: ne fonctionne que sur ça.
-[00:25:25] VOIX: Donc,
-[00:25:25] VOIX: qui sont des documents
-[00:25:26] VOIX: qui sont anonymisés.
-[00:25:28] VOIX: Et je te montre,
-[00:25:28] VOIX: tu vois,
-[00:25:29] VOIX: le PDF.
-[00:25:30] VOIX: Donc,
-[00:25:30] VOIX: en fait,
-[00:25:31] VOIX: c'est ça.
-[00:25:32] VOIX: Tu vois,
-[00:25:33] VOIX: en fait,
-[00:25:33] VOIX: tous les noms,
-[00:25:34] VOIX: les prénoms,
-[00:25:35] VOIX: les dates,
-[00:25:35] VOIX: certaines dates,
-[00:25:36] VOIX: certaines choses,
-[00:25:37] VOIX: en fait,
-[00:25:37] VOIX: sont anonymisés.
-[00:25:38] VOIX: Ça,
-[00:25:38] VOIX: il faut que j'aie un retour
-[00:25:39] VOIX: de la part de Jordan
-[00:25:40] VOIX: parce qu'en fait,
-[00:25:40] VOIX: il avait des règles
-[00:25:41] VOIX: à me passer.
-[00:25:42] VOIX: Et donc,
-[00:25:43] VOIX: si tu veux,
-[00:25:43] VOIX: tout le système...
-[00:25:45] VOIX: Oui.
-[00:25:46] VOIX: Ce qu'il faudra faire
-[00:25:47] VOIX: si tu fais ça,
-[00:25:50] VOIX: c'est insérer
-[00:25:53] VOIX: le numéro OGC.
-[00:25:56] VOIX: D'accord.
-[00:25:59] VOIX: D'accord.
-[00:26:00] VOIX: Parce que si jamais,
-[00:26:02] VOIX: typiquement,
-[00:26:04] VOIX: tu vois,
-[00:26:05] VOIX: moi,
-[00:26:05] VOIX: si je veux envoyer
-[00:26:06] VOIX: les documents
-[00:26:07] VOIX: à l'ARS
-[00:26:08] VOIX: ou à la saisie
-[00:26:09] VOIX: d'un atelier,
-[00:26:10] VOIX: etc.,
-[00:26:11] VOIX: c'est vrai que
-[00:26:12] VOIX: stabiloter,
-[00:26:13] VOIX: on le faisait manuellement
-[00:26:14] VOIX: il y a un moment.
-[00:26:15] VOIX: après,
-[00:26:16] VOIX: on ajoutait
-[00:26:17] VOIX: OGC temps
-[00:26:18] VOIX: pour que ça soit
-[00:26:19] VOIX: identifié...
-[00:26:20] VOIX: Pour moi,
-[00:26:20] VOIX: en fait,
-[00:26:20] VOIX: ce n'est pas très compliqué
-[00:26:21] VOIX: parce que j'ai le numéro,
-[00:26:23] VOIX: en fait,
-[00:26:23] VOIX: les documents sont numérotés
-[00:26:24] VOIX: puisque ça,
-[00:26:25] VOIX: c'est le dossier,
-[00:26:26] VOIX: en fait,
-[00:26:26] VOIX: OGC,
-[00:26:27] VOIX: en fait...
-[00:26:27] VOIX: Non,
-[00:26:27] VOIX: c'est l'OGC 1.
-[00:26:28] VOIX: Tu sais,
-[00:26:29] VOIX: dans 1,
-[00:26:30] VOIX: tu as deux infos.
-[00:26:31] VOIX: Tu as l'OGC,
-[00:26:33] VOIX: tu as le numéro dossier.
-[00:26:35] VOIX: D'accord.
-[00:26:36] VOIX: L'OGC,
-[00:26:36] VOIX: c'est le chiffre
-[00:26:38] VOIX: au départ.
-[00:26:39] VOIX: 1,
-[00:26:39] VOIX: 2,
-[00:26:39] VOIX: 3,
-[00:26:39] VOIX: 4,
-[00:26:40] VOIX: 5,
-[00:26:40] VOIX: 6,
-[00:26:40] VOIX: 7,
-[00:26:41] VOIX: 7,
-[00:26:41] VOIX: 7,
-[00:26:42] VOIX: 7,
-[00:26:42] VOIX: 7,
-[00:27:04] VOIX: D'accord.
-[00:27:05] VOIX: Ça,
-[00:27:05] VOIX: c'est fait,
-[00:27:06] VOIX: en fait.
-[00:27:06] VOIX: C'est...
-[00:27:09] VOIX: Le classement,
-[00:27:10] VOIX: numérotation,
-[00:27:11] VOIX: ce genre de choses-là,
-[00:27:12] VOIX: ça va.
-[00:27:13] VOIX: Donc,
-[00:27:14] VOIX: ah oui,
-[00:27:15] VOIX: il fallait que j'ouvre,
-[00:27:16] VOIX: en fait,
-[00:27:16] VOIX: un document.
-[00:27:16] VOIX: Alors,
-[00:27:16] VOIX: je vais ouvrir,
-[00:27:17] VOIX: en fait,
-[00:27:17] VOIX: le document propre.
-[00:27:19] VOIX: D'accord?
-[00:27:20] VOIX: Ouvre le CRH
-[00:27:21] VOIX: parce que parfois,
-[00:27:21] VOIX: les CRH,
-[00:27:22] VOIX: quand ils sont bien faits,
-[00:27:24] VOIX: ils peuvent résumer
-[00:27:26] VOIX: la situation.
-[00:27:27] VOIX: Donc là,
-[00:27:28] VOIX: on est sur le CRH.
-[00:27:30] VOIX: Attends,
-[00:27:31] VOIX: j'essaie,
-[00:27:31] VOIX: en fait,
-[00:27:32] VOIX: de lire le truc.
-[00:27:33] VOIX: C'est le 2304,
-[00:27:38] VOIX: 2753.pdf.
-[00:27:39] VOIX: Je le dis
-[00:27:39] VOIX: parce que
-[00:27:39] VOIX: comment j'enregistre,
-[00:27:40] VOIX: en fait,
-[00:27:40] VOIX: comme ça,
-[00:27:41] VOIX: je vais retrouver le dossier.
-[00:27:43] VOIX: D'accord.
-[00:27:44] VOIX: OK.
-[00:27:44] VOIX: Je comprends.
-[00:27:45] VOIX: Donc,
-[00:27:46] VOIX: voilà.
-[00:27:46] VOIX: Donc là,
-[00:27:47] VOIX: par quoi,
-[00:27:48] VOIX: en fait,
-[00:27:48] VOIX: tu commences...
-[00:27:50] VOIX: Pardon,
-[00:27:51] VOIX: Dominique,
-[00:27:51] VOIX: il y a Jordan qui m'appelle
-[00:27:52] VOIX: parce que...
-[00:27:53] VOIX: Deux minutes,
-[00:27:54] VOIX: s'il te plaît.
-[00:27:54] VOIX: Vas-y,
-[00:27:55] VOIX: vas-y,
-[00:27:55] VOIX: vas-y.
-[00:27:56] VOIX: Oui,
-[00:27:57] VOIX: Jordan.
-[00:28:01] VOIX: Je vais dire...
-[00:28:02] VOIX: Voilà.
-[00:28:03] VOIX: Qu'est-ce que tu regardes
-[00:28:05] VOIX: en premier ?
-[00:28:06] VOIX: En premier,
-[00:28:07] VOIX: je vais regarder
-[00:28:08] VOIX: le nom du patient
-[00:28:10] VOIX: ou de la patiente
-[00:28:11] VOIX: pour vérifier que...
-[00:28:14] VOIX: Ça,
-[00:28:15] VOIX: je te parle
-[00:28:15] VOIX: du primo-codage
-[00:28:16] VOIX: habituel.
-[00:28:17] VOIX: D'accord ?
-[00:28:18] VOIX: Je regarde
-[00:28:18] VOIX: le numéro de dossier,
-[00:28:19] VOIX: ce qu'on appelle
-[00:28:20] VOIX: l'identité de vigilance.
-[00:28:21] VOIX: C'est-à-dire,
-[00:28:22] VOIX: on s'assure
-[00:28:22] VOIX: qu'on est en train
-[00:28:24] VOIX: de coder
-[00:28:24] VOIX: parce que parfois,
-[00:28:26] VOIX: dans des dossiers,
-[00:28:27] VOIX: on a des documents
-[00:28:28] VOIX: qui sont des patients
-[00:28:30] VOIX: erronés.
-[00:28:31] VOIX: Donc,
-[00:28:32] VOIX: on regarde les noms,
-[00:28:34] VOIX: on regarde les dates
-[00:28:34] VOIX: de naissance,
-[00:28:35] VOIX: on regarde les dates
-[00:28:36] VOIX: d'hospitalisation
-[00:28:37] VOIX: pour vérifier
-[00:28:39] VOIX: qu'on est en cohérence
-[00:28:40] VOIX: entre les différents documents,
-[00:28:41] VOIX: en fait.
-[00:28:41] VOIX: Donc,
-[00:28:42] VOIX: il y a un contrôle
-[00:28:44] VOIX: un peu
-[00:28:45] VOIX: de s'assurer
-[00:28:45] VOIX: qu'il n'y a pas
-[00:28:47] VOIX: d'erreur
-[00:28:47] VOIX: sur le fichier.
-[00:28:51] VOIX: De suite,
-[00:28:53] VOIX: moi,
-[00:28:53] VOIX: en tout cas,
-[00:28:53] VOIX: je vais dire
-[00:28:54] VOIX: OK,
-[00:28:55] VOIX: dans le service
-[00:28:56] VOIX: du 25 février
-[00:28:57] VOIX: ou 3 mars,
-[00:28:59] VOIX: donc,
-[00:28:59] VOIX: je sais
-[00:28:59] VOIX: qu'il est resté
-[00:29:00] VOIX: 7 jours
-[00:29:01] VOIX: et donc,
-[00:29:02] VOIX: je sais
-[00:29:02] VOIX: que mon enjeu,
-[00:29:04] VOIX: il va être
-[00:29:05] VOIX: autour du diagnostic
-[00:29:06] VOIX: principal,
-[00:29:07] VOIX: du diagnostic associé
-[00:29:08] VOIX: et des axes.
-[00:29:09] VOIX: Donc,
-[00:29:09] VOIX: il faut que je regarde
-[00:29:09] VOIX: le dossier
-[00:29:10] VOIX: dans sa globalité.
-[00:29:12] VOIX: Tu vois,
-[00:29:13] VOIX: la durée de ses jours,
-[00:29:14] VOIX: il me donne aussi
-[00:29:14] VOIX: une indication
-[00:29:17] VOIX: sur les objectifs
-[00:29:18] VOIX: parce que si j'avais
-[00:29:19] VOIX: un dossier de 2 jours,
-[00:29:20] VOIX: je vais aller
-[00:29:21] VOIX: beaucoup plus vite
-[00:29:22] VOIX: et je vais raisonner
-[00:29:23] VOIX: uniquement sur
-[00:29:23] VOIX: pourquoi il a été
-[00:29:24] VOIX: hospitalisé.
-[00:29:25] VOIX: Je ne vais pas
-[00:29:26] VOIX: raisonner sur
-[00:29:27] VOIX: les niveaux de sévérité,
-[00:29:28] VOIX: etc.
-[00:29:28] VOIX: Même si ça,
-[00:29:29] VOIX: je le mets
-[00:29:30] VOIX: en aparté,
-[00:29:31] VOIX: tu vois,
-[00:29:31] VOIX: ce n'est pas
-[00:29:31] VOIX: le raisonnement
-[00:29:32] VOIX: qu'il faut apprendre
-[00:29:35] VOIX: à l'IA,
-[00:29:36] VOIX: il faut lui apprendre
-[00:29:36] VOIX: à raisonner sur le dossier
-[00:29:38] VOIX: dans sa globalité.
-[00:29:40] VOIX: Si j'en fiche un peu
-[00:29:40] VOIX: sur certains dossiers,
-[00:29:42] VOIX: on sait qu'on peut
-[00:29:43] VOIX: les valoriser
-[00:29:44] VOIX: uniquement avec le DP.
-[00:29:46] VOIX: Donc,
-[00:29:47] VOIX: ça,
-[00:29:47] VOIX: c'est la première ligne
-[00:29:48] VOIX: que je vais regarder,
-[00:29:50] VOIX: vérifier,
-[00:29:50] VOIX: etc.
-[00:29:52] VOIX: Je vais regarder le MH,
-[00:29:56] VOIX: c'est motif d'hospitalisation.
-[00:29:58] VOIX: En fait,
-[00:29:59] VOIX: soit qu'il est clairement,
-[00:30:02] VOIX: d'accord,
-[00:30:03] VOIX: soit,
-[00:30:03] VOIX: il faut le deviner,
-[00:30:05] VOIX: c'est toute la difficulté.
-[00:30:08] VOIX: Donc là,
-[00:30:08] VOIX: j'ai une douleur.
-[00:30:10] VOIX: Une douleur,
-[00:30:11] VOIX: ça veut dire,
-[00:30:12] VOIX: c'est un symptôme.
-[00:30:13] VOIX: D'accord ?
-[00:30:14] VOIX: Tu te rappelleras
-[00:30:14] VOIX: de ce terme
-[00:30:15] VOIX: pour que tu vois
-[00:30:16] VOIX: le raisonnement après
-[00:30:17] VOIX: et pourquoi,
-[00:30:18] VOIX: quelles règles
-[00:30:19] VOIX: on applique
-[00:30:19] VOIX: dans ce cas-là.
-[00:30:20] VOIX: Donc,
-[00:30:21] VOIX: en fait,
-[00:30:22] VOIX: j'ose imaginer
-[00:30:24] VOIX: que ce patient-là,
-[00:30:26] VOIX: il vient
-[00:30:27] VOIX: parce qu'il a
-[00:30:28] VOIX: un diagnostic inconnu,
-[00:30:30] VOIX: il a mal
-[00:30:31] VOIX: et l'établissement,
-[00:30:34] VOIX: le médecin,
-[00:30:34] VOIX: il va faire
-[00:30:35] VOIX: en sorte
-[00:30:36] VOIX: d'identifier
-[00:30:37] VOIX: pourquoi le patient,
-[00:30:38] VOIX: il a mal
-[00:30:39] VOIX: et de traiter
-[00:30:41] VOIX: en fonction
-[00:30:42] VOIX: de la pathologie
-[00:30:43] VOIX: ou de traiter
-[00:30:44] VOIX: la douleur.
-[00:30:45] VOIX: Ça,
-[00:30:45] VOIX: c'est la démarche
-[00:30:46] VOIX: habituelle.
-[00:30:47] VOIX: En même temps,
-[00:30:49] VOIX: tu vas avoir
-[00:30:51] VOIX: antécédent,
-[00:30:53] VOIX: c'est un patient
-[00:30:55] VOIX: qui est jeune,
-[00:30:55] VOIX: déjà,
-[00:30:56] VOIX: tu vois,
-[00:30:56] VOIX: ah oui,
-[00:30:57] VOIX: je regarde aussi l'âge
-[00:30:58] VOIX: parce que si
-[00:30:59] VOIX: tu as un patient
-[00:30:59] VOIX: âgé,
-[00:31:01] VOIX: tu sauras
-[00:31:02] VOIX: qu'il va avoir
-[00:31:03] VOIX: une liste
-[00:31:03] VOIX: d'antécédents,
-[00:31:04] VOIX: d'histoires
-[00:31:05] VOIX: qui est extrêmement
-[00:31:08] VOIX: riche
-[00:31:08] VOIX: alors que 80,
-[00:31:10] VOIX: tu dis bon,
-[00:31:10] VOIX: il est jeune,
-[00:31:11] VOIX: il a mal au ventre,
-[00:31:13] VOIX: ça doit être
-[00:31:13] VOIX: une appendicite
-[00:31:15] VOIX: ou un truc
-[00:31:15] VOIX: qui va être opéré,
-[00:31:17] VOIX: tu vois.
-[00:31:18] VOIX: Moi,
-[00:31:18] VOIX: ça,
-[00:31:19] VOIX: c'est mon raisonnement
-[00:31:19] VOIX: un peu médical
-[00:31:20] VOIX: en disant
-[00:31:21] VOIX: qu'il est jeune
-[00:31:22] VOIX: et qu'il a déjà
-[00:31:24] VOIX: fait une crise
-[00:31:25] VOIX: de colique hépatique
-[00:31:26] VOIX: avec vésiculitis,
-[00:31:28] VOIX: lithiasique,
-[00:31:28] VOIX: donc c'est peut-être
-[00:31:29] VOIX: qu'il a une colécystite
-[00:31:31] VOIX: aiguë,
-[00:31:31] VOIX: colécystite,
-[00:31:32] VOIX: c'est la vésicule
-[00:31:33] VOIX: biliaire
-[00:31:34] VOIX: qui est inflammée
-[00:31:35] VOIX: d'autant plus
-[00:31:35] VOIX: que là,
-[00:31:37] VOIX: on est en février,
-[00:31:38] VOIX: en janvier,
-[00:31:39] VOIX: je regarde ça aussi,
-[00:31:40] VOIX: c'est-à-dire là,
-[00:31:41] VOIX: on est en février,
-[00:31:42] VOIX: il est rentré
-[00:31:42] VOIX: en février 2023
-[00:31:44] VOIX: et un mois avant,
-[00:31:45] VOIX: même pas un mois,
-[00:31:46] VOIX: il a eu une colique hépatique
-[00:31:49] VOIX: avec vésicule lithiasique
-[00:31:50] VOIX: sur les côtes.
-[00:31:51] VOIX: Donc probablement,
-[00:31:52] VOIX: le médecin,
-[00:31:53] VOIX: en tout cas,
-[00:31:53] VOIX: qui le prend en charge,
-[00:31:54] VOIX: il va raisonner
-[00:31:54] VOIX: que c'est ce qu'on appelle
-[00:31:55] VOIX: une récidive,
-[00:31:56] VOIX: c'est-à-dire la lithiase
-[00:32:00] VOIX: qui a migré
-[00:32:01] VOIX: ou voilà,
-[00:32:01] VOIX: un problème en tout cas
-[00:32:02] VOIX: digestif.
-[00:32:03] VOIX: Il n'a pas de traitement
-[00:32:04] VOIX: en cours,
-[00:32:05] VOIX: il n'a pas d'allergie,
-[00:32:07] VOIX: histoire de la maladie,
-[00:32:08] VOIX: donc fin janvier,
-[00:32:10] VOIX: premier épisode
-[00:32:11] VOIX: de colique hépatique
-[00:32:12] VOIX: prise en charge
-[00:32:13] VOIX: aux urgences
-[00:32:14] VOIX: de Saint-Palais,
-[00:32:16] VOIX: lithiase,
-[00:32:16] VOIX: tu vois,
-[00:32:17] VOIX: ils avaient déjà
-[00:32:17] VOIX: diagnostiqué
-[00:32:18] VOIX: des lithiases
-[00:32:19] VOIX: vésiculaires,
-[00:32:19] VOIX: ça veut dire
-[00:32:20] VOIX: lithiase de la vésicule
-[00:32:21] VOIX: biliaire
-[00:32:22] VOIX: sur éco-échographie
-[00:32:24] VOIX: en ville,
-[00:32:25] VOIX: la patiente
-[00:32:26] VOIX: devait revoir
-[00:32:27] VOIX: son médecin traitant
-[00:32:28] VOIX: début mars
-[00:32:28] VOIX: pour évoquer
-[00:32:29] VOIX: la suite
-[00:32:29] VOIX: de la prise en charge.
-[00:32:31] VOIX: Mais le 24 février
-[00:32:32] VOIX: vers 20 heures,
-[00:32:33] VOIX: récidive,
-[00:32:34] VOIX: récidive,
-[00:32:35] VOIX: c'est-à-dire de nouveau,
-[00:32:37] VOIX: récidive
-[00:32:39] VOIX: des douleurs
-[00:32:40] VOIX: ou de la maladie,
-[00:32:41] VOIX: etc.
-[00:32:42] VOIX: En hippocondre droit,
-[00:32:44] VOIX: l'hippocondre droit,
-[00:32:45] VOIX: c'est la partie
-[00:32:45] VOIX: haute droite
-[00:32:46] VOIX: du ventre,
-[00:32:47] VOIX: ne cédant pas,
-[00:32:49] VOIX: arrive hyperalgique,
-[00:32:51] VOIX: c'est-à-dire très douloureux
-[00:32:53] VOIX: aux urgences,
-[00:32:54] VOIX: pas de signe
-[00:32:55] VOIX: d'irritation péritoniale.
-[00:32:57] VOIX: L'irritation péritoniale,
-[00:32:59] VOIX: ça peut évoquer
-[00:33:00] VOIX: une péritonite,
-[00:33:01] VOIX: c'est-à-dire
-[00:33:01] VOIX: quand ils vont palper
-[00:33:02] VOIX: le ventre,
-[00:33:03] VOIX: ils vont trouver
-[00:33:03] VOIX: qu'il est dur.
-[00:33:04] VOIX: Donc là,
-[00:33:05] VOIX: bon,
-[00:33:05] VOIX: il n'y a rien
-[00:33:06] VOIX: qui les alerte
-[00:33:07] VOIX: d'un point de vue
-[00:33:08] VOIX: examen de l'abdomen.
-[00:33:10] VOIX: Hémodynamique stable,
-[00:33:11] VOIX: c'est-à-dire la tension,
-[00:33:12] VOIX: la fréquence cardiaque,
-[00:33:13] VOIX: etc.,
-[00:33:14] VOIX: c'est ça,
-[00:33:14] VOIX: l'hémodynamique.
-[00:33:15] VOIX: Donc la patiente stable.
-[00:33:18] VOIX: Ils ont fait
-[00:33:18] VOIX: un bilan biologique,
-[00:33:20] VOIX: tout ça,
-[00:33:20] VOIX: et regarde,
-[00:33:23] VOIX: ça,
-[00:33:23] VOIX: il te dit
-[00:33:24] VOIX: fin janvier,
-[00:33:25] VOIX: donc du coup,
-[00:33:26] VOIX: les deux premières phrases,
-[00:33:28] VOIX: c'était
-[00:33:29] VOIX: aux urgences
-[00:33:30] VOIX: de Saint-Palais,
-[00:33:31] VOIX: il est rentré
-[00:33:32] VOIX: chez lui
-[00:33:32] VOIX: et il devait voir
-[00:33:34] VOIX: la suite
-[00:33:34] VOIX: avec son médecin traitant,
-[00:33:36] VOIX: donc c'est
-[00:33:36] VOIX: l'ancien,
-[00:33:38] VOIX: c'est ancien.
-[00:33:39] VOIX: Le 24 février,
-[00:33:41] VOIX: il a été hospitalisé
-[00:33:42] VOIX: au CHCB,
-[00:33:43] VOIX: le 25 février,
-[00:33:45] VOIX: il est passé
-[00:33:46] VOIX: aux urgences
-[00:33:49] VOIX: du CHCB,
-[00:33:50] VOIX: je pense,
-[00:33:52] VOIX: donc il est rentré
-[00:33:53] VOIX: via les urgences,
-[00:33:54] VOIX: le 24,
-[00:33:56] VOIX: c'est ça ou pas ?
-[00:33:57] VOIX: Oui,
-[00:33:58] VOIX: oui,
-[00:33:58] VOIX: oui.
-[00:33:59] VOIX: Point d'interrogation.
-[00:34:00] VOIX: Oui.
-[00:34:00] VOIX: Parce que là,
-[00:34:01] VOIX: la date d'entrée,
-[00:34:02] VOIX: c'est le 25,
-[00:34:04] VOIX: et là,
-[00:34:04] VOIX: il parle du 24 à 20 heures.
-[00:34:06] VOIX: Oui,
-[00:34:06] VOIX: c'est ça.
-[00:34:07] VOIX: En fait,
-[00:34:07] VOIX: en fait,
-[00:34:07] VOIX: ça,
-[00:34:08] VOIX: 24,
-[00:34:08] VOIX: 25,
-[00:34:09] VOIX: il faut savoir
-[00:34:10] VOIX: que nous,
-[00:34:10] VOIX: quand on code,
-[00:34:11] VOIX: on va vérifier
-[00:34:12] VOIX: aussi le dossier
-[00:34:13] VOIX: administratif
-[00:34:13] VOIX: pour vérifier
-[00:34:15] VOIX: est-ce que la date
-[00:34:16] VOIX: d'entrée du séjour,
-[00:34:17] VOIX: c'est le 24 février
-[00:34:17] VOIX: ou c'est le 25 février.
-[00:34:19] VOIX: Et est-ce que
-[00:34:20] VOIX: quand il part
-[00:34:20] VOIX: d'hyper-algie
-[00:34:22] VOIX: aux urgences,
-[00:34:22] VOIX: est-ce que c'est
-[00:34:23] VOIX: les urgences
-[00:34:24] VOIX: du centre hospitalier
-[00:34:25] VOIX: de la Côte-Basque
-[00:34:26] VOIX: ou est-ce que
-[00:34:27] VOIX: ce sont les urgences
-[00:34:28] VOIX: d'un autre établissement
-[00:34:29] VOIX: comme Saint-Palais
-[00:34:30] VOIX: qui l'a vu,
-[00:34:31] VOIX: il est resté
-[00:34:32] VOIX: jusqu'au 25.
-[00:34:33] VOIX: Et ça,
-[00:34:34] VOIX: c'est peut-être
-[00:34:34] VOIX: la suite du dossier
-[00:34:35] VOIX: qui va l'expliquer.
-[00:34:36] VOIX: Peut-être que
-[00:34:37] VOIX: je ne le saurais pas.
-[00:34:39] VOIX: Tu vois,
-[00:34:40] VOIX: je reviens un peu
-[00:34:40] VOIX: pour te dire
-[00:34:41] VOIX: comment il a analysé
-[00:34:43] VOIX: les choses.
-[00:34:44] VOIX: C'est parfait.
-[00:34:45] VOIX: Franchement,
-[00:34:45] VOIX: en fait,
-[00:34:46] VOIX: c'est parfait.
-[00:34:46] VOIX: Mais tu m'en avais déjà parlé.
-[00:34:47] VOIX: On en a parlé,
-[00:34:48] VOIX: je ne sais plus avec qui.
-[00:34:50] VOIX: Effectivement,
-[00:34:50] VOIX: en fait,
-[00:34:50] VOIX: qu'on n'a pas
-[00:34:51] VOIX: de suivi nécessairement
-[00:34:52] VOIX: entre les urgences
-[00:34:55] VOIX: et l'hospitalisation même.
-[00:34:57] VOIX: Non.
-[00:34:58] VOIX: Parce que le patient,
-[00:34:59] VOIX: il peut aller aux urgences,
-[00:34:59] VOIX: de Saint-Palais,
-[00:35:00] VOIX: Saint-Palais,
-[00:35:01] VOIX: il trouve un problème
-[00:35:02] VOIX: et il l'envoie au CHTB
-[00:35:04] VOIX: pour hospitalisation.
-[00:35:05] VOIX: Donc,
-[00:35:05] VOIX: quand il dit
-[00:35:06] VOIX: arriver aux urgences,
-[00:35:08] VOIX: sans préciser,
-[00:35:09] VOIX: le plus probable
-[00:35:10] VOIX: c'est les urgences
-[00:35:11] VOIX: du CHTB,
-[00:35:12] VOIX: mais la date m'alerte.
-[00:35:14] VOIX: Tu vois,
-[00:35:15] VOIX: si c'était le 25,
-[00:35:16] VOIX: je l'aurais été
-[00:35:17] VOIX: moins alertée.
-[00:35:19] VOIX: Donc,
-[00:35:20] VOIX: soit qu'entre le 25 et le 25,
-[00:35:21] VOIX: il est resté
-[00:35:22] VOIX: dans les urgences du CHTB
-[00:35:23] VOIX: et il est monté
-[00:35:25] VOIX: à l'étage,
-[00:35:26] VOIX: là,
-[00:35:26] VOIX: je ne sais pas
-[00:35:27] VOIX: quel service,
-[00:35:28] VOIX: il n'y a pas
-[00:35:28] VOIX: de gastro-entéraux
-[00:35:29] VOIX: et il est monté.
-[00:35:30] VOIX: Donc,
-[00:35:30] VOIX: le médecin gastro-entéraux,
-[00:35:32] VOIX: il résonne par rapport
-[00:35:33] VOIX: à l'heure
-[00:35:34] VOIX: et à la date
-[00:35:34] VOIX: d'entrée dans le service
-[00:35:35] VOIX: et donc,
-[00:35:36] VOIX: le patient,
-[00:35:36] VOIX: il a passé
-[00:35:37] VOIX: une nuit aux urgences.
-[00:35:39] VOIX: C'est le plus probable
-[00:35:40] VOIX: à la lecture
-[00:35:41] VOIX: de...
-[00:35:41] VOIX: Oui, parce qu'en fait,
-[00:35:42] VOIX: oui, parce que ça dit,
-[00:35:44] VOIX: arrive,
-[00:35:47] VOIX: arrive hyperalgique
-[00:35:48] VOIX: aux urgences,
-[00:35:49] VOIX: pas de signe
-[00:35:49] VOIX: d'irritation
-[00:35:52] VOIX: péritoniale.
-[00:35:52] VOIX: Par contre,
-[00:35:53] VOIX: en fait,
-[00:35:53] VOIX: il y a suffisamment
-[00:35:54] VOIX: de précision,
-[00:35:57] VOIX: en fait,
-[00:35:57] VOIX: le bilan biologique,
-[00:35:59] VOIX: en fait,
-[00:35:59] VOIX: s'ils l'ont,
-[00:36:00] VOIX: en fait,
-[00:36:00] VOIX: ça,
-[00:36:01] VOIX: c'est parce qu'en fait,
-[00:36:01] VOIX: ils ont les informations,
-[00:36:02] VOIX: j'imagine.
-[00:36:04] VOIX: Oui,
-[00:36:05] VOIX: normalement,
-[00:36:05] VOIX: il doit y avoir
-[00:36:06] VOIX: dans le dossier,
-[00:36:07] VOIX: donc,
-[00:36:08] VOIX: l'IA,
-[00:36:08] VOIX: les documents,
-[00:36:09] VOIX: ça,
-[00:36:10] VOIX: c'est facile,
-[00:36:11] VOIX: si tu lui donnes l'info,
-[00:36:13] VOIX: vérifier que la patiente
-[00:36:14] VOIX: est passée par les urgences
-[00:36:15] VOIX: de l'établissement
-[00:36:17] VOIX: et combien de temps
-[00:36:18] VOIX: elle est restée,
-[00:36:19] VOIX: qu'est-ce qui a été fait,
-[00:36:20] VOIX: etc.
-[00:36:20] VOIX: Parce que ça,
-[00:36:21] VOIX: c'est la synthèse du médecin,
-[00:36:22] VOIX: mais normalement,
-[00:36:23] VOIX: le passage aux urgences,
-[00:36:24] VOIX: il est tracé aussi
-[00:36:25] VOIX: dans le dossier patient.
-[00:36:26] VOIX: D'accord.
-[00:36:27] VOIX: Et peut-être,
-[00:36:27] VOIX: quand on va regarder
-[00:36:28] VOIX: Tracker,
-[00:36:29] VOIX: on aura plus d'infos,
-[00:36:31] VOIX: ça nous donnera,
-[00:36:32] VOIX: mais là,
-[00:36:32] VOIX: je te dis,
-[00:36:32] VOIX: c'est-à-dire,
-[00:36:34] VOIX: si je n'ai que ce document,
-[00:36:36] VOIX: voilà les questions
-[00:36:37] VOIX: que je vais me poser
-[00:36:37] VOIX: et que je vais chercher
-[00:36:39] VOIX: ailleurs
-[00:36:40] VOIX: dans le dossier patient.
-[00:36:41] VOIX: Parce que rien de ce document
-[00:36:42] VOIX: ne permet pas
-[00:36:43] VOIX: de répondre
-[00:36:43] VOIX: à toutes les questions.
-[00:36:44] VOIX: Oui.
-[00:36:44] VOIX: Alors là,
-[00:36:45] VOIX: par contre,
-[00:36:45] VOIX: dans ce document,
-[00:36:47] VOIX: en fait,
-[00:36:47] VOIX: il y a marqué
-[00:36:47] VOIX: donc au total,
-[00:36:48] VOIX: premier épisode,
-[00:36:49] VOIX: donc en fait,
-[00:36:50] VOIX: j'imagine qu'il y a marqué
-[00:36:51] VOIX: synthèse,
-[00:36:52] VOIX: en fait,
-[00:36:52] VOIX: du paragraphe précédent
-[00:36:53] VOIX: et c'est premier épisode
-[00:36:55] VOIX: de créatite aiguë
-[00:36:57] VOIX: d'origine...
-[00:36:57] VOIX: Créatite aiguë
-[00:36:58] VOIX: d'origine lithéatique.
-[00:36:59] VOIX: D'accord,
-[00:37:00] VOIX: OK.
-[00:37:01] VOIX: OK.
-[00:37:01] VOIX: Voilà.
-[00:37:01] VOIX: Ils ont fait le bilan
-[00:37:03] VOIX: donc ils ont
-[00:37:04] VOIX: la lipazémie
-[00:37:05] VOIX: qui est élevée
-[00:37:05] VOIX: à 6000
-[00:37:06] VOIX: donc supérieure
-[00:37:07] VOIX: à 3N
-[00:37:08] VOIX: c'est-à-dire
-[00:37:08] VOIX: supérieure
-[00:37:09] VOIX: à 3 fois
-[00:37:10] VOIX: la normale.
-[00:37:10] VOIX: Oui.
-[00:37:12] VOIX: Tu vois ?
-[00:37:13] VOIX: Asat,
-[00:37:14] VOIX: 8 fois
-[00:37:14] VOIX: la normale.
-[00:37:15] VOIX: Alad,
-[00:37:16] VOIX: 9 fois
-[00:37:16] VOIX: la normale.
-[00:37:18] VOIX: Donc le bilan
-[00:37:18] VOIX: hépatique,
-[00:37:19] VOIX: pancréatique
-[00:37:20] VOIX: est complètement perturbé.
-[00:37:22] VOIX: Phosphatase
-[00:37:23] VOIX: alcaline,
-[00:37:23] VOIX: donc le gastroenterot
-[00:37:25] VOIX: à partir
-[00:37:26] VOIX: de ce bilan,
-[00:37:27] VOIX: il dit
-[00:37:28] VOIX: premier épisode
-[00:37:29] VOIX: parce que la patiente
-[00:37:30] VOIX: avait déjà
-[00:37:31] VOIX: une lithiase
-[00:37:32] VOIX: au niveau
-[00:37:32] VOIX: de la vésicule
-[00:37:33] VOIX: biliaire
-[00:37:34] VOIX: et ces lithiases
-[00:37:35] VOIX: peuvent migrer
-[00:37:36] VOIX: au niveau
-[00:37:36] VOIX: du pancréas
-[00:37:37] VOIX: et provoquer
-[00:37:38] VOIX: des pancréatites
-[00:37:39] VOIX: aiguës.
-[00:37:40] VOIX: Et sa pancréatite
-[00:37:41] VOIX: aiguë
-[00:37:42] VOIX: d'origine lithiasique,
-[00:37:43] VOIX: c'est un diagnostic
-[00:37:45] VOIX: qui est facile
-[00:37:46] VOIX: quelque part.
-[00:37:47] VOIX: Mais avant,
-[00:37:48] VOIX: tu vas le trouver,
-[00:37:49] VOIX: mais est-ce que c'est
-[00:37:49] VOIX: des dépées ou pas ?
-[00:37:50] VOIX: Je t'expliquerai
-[00:37:51] VOIX: pourquoi après,
-[00:37:53] VOIX: est-ce qu'on met
-[00:37:53] VOIX: la douleur,
-[00:37:54] VOIX: est-ce qu'on met
-[00:37:54] VOIX: la pancréatite ?
-[00:37:56] VOIX: Là,
-[00:37:57] VOIX: moi je sais,
-[00:37:57] VOIX: mais je vais aller
-[00:37:58] VOIX: pousser la réflexion
-[00:38:00] VOIX: jusqu'au bout.
-[00:38:02] VOIX: Donc suspecté
-[00:38:03] VOIX: chez cette patiente
-[00:38:04] VOIX: ayant fait un épisode
-[00:38:05] VOIX: de colique hépatique
-[00:38:06] VOIX: le 23 janvier
-[00:38:08] VOIX: avec indication
-[00:38:09] VOIX: chirurgicale
-[00:38:10] VOIX: retenue.
-[00:38:13] VOIX: Donc titration morphine,
-[00:38:15] VOIX: comme elle est
-[00:38:15] VOIX: hyperalgique
-[00:38:16] VOIX: et elle a très mal,
-[00:38:17] VOIX: donc ils l'ont mis
-[00:38:17] VOIX: sous morphine,
-[00:38:19] VOIX: relais en PCA,
-[00:38:21] VOIX: c'est-à-dire
-[00:38:21] VOIX: ils font,
-[00:38:23] VOIX: en fait,
-[00:38:24] VOIX: ils font une perf
-[00:38:25] VOIX: et c'est le patient
-[00:38:26] VOIX: qui gère
-[00:38:27] VOIX: la dose.
-[00:38:28] VOIX: Oui,
-[00:38:28] VOIX: qui a pris
-[00:38:29] VOIX: sur le bouton.
-[00:38:30] VOIX: Voilà.
-[00:38:31] VOIX: Et transfert,
-[00:38:32] VOIX: donc tout ça,
-[00:38:33] VOIX: c'est la synthèse
-[00:38:33] VOIX: des urgences
-[00:38:34] VOIX: et ils la transfèrent
-[00:38:35] VOIX: dans le service
-[00:38:35] VOIX: de gastro-entérologie.
-[00:38:37] VOIX: Ils la mettent
-[00:38:37] VOIX: sous morphine,
-[00:38:38] VOIX: ils la montent
-[00:38:39] VOIX: vers un spécialiste.
-[00:38:41] VOIX: rassurante
-[00:38:41] VOIX: sur le plan
-[00:38:42] VOIX: abdominal
-[00:38:42] VOIX: parce qu'ici,
-[00:38:43] VOIX: il n'a pas de signe
-[00:38:45] VOIX: d'irritation péritoniale.
-[00:38:47] VOIX: Donc,
-[00:38:48] VOIX: s'il y avait
-[00:38:48] VOIX: une péritonite,
-[00:38:50] VOIX: c'est-à-dire
-[00:38:50] VOIX: la patiente,
-[00:38:51] VOIX: soit elle a une infection,
-[00:38:53] VOIX: soit qu'elle a
-[00:38:53] VOIX: une perforation,
-[00:38:55] VOIX: soit...
-[00:38:55] VOIX: Donc,
-[00:38:55] VOIX: ça devient
-[00:38:55] VOIX: une urgence.
-[00:38:57] VOIX: Là,
-[00:38:58] VOIX: ils ont un peu
-[00:38:59] VOIX: le temps.
-[00:38:59] VOIX: Ils vont calmer
-[00:39:00] VOIX: la douleur,
-[00:39:01] VOIX: ils vont compléter
-[00:39:02] VOIX: le bilan,
-[00:39:02] VOIX: etc.
-[00:39:03] VOIX: Donc,
-[00:39:04] VOIX: ce n'est pas
-[00:39:04] VOIX: l'extrême urgence.
-[00:39:06] VOIX: Sous PCA de morphine,
-[00:39:07] VOIX: elle n'est pas douloureuse
-[00:39:08] VOIX: et pas dileus,
-[00:39:09] VOIX: c'est-à-dire
-[00:39:10] VOIX: pas d'occlusion.
-[00:39:12] VOIX: En fait,
-[00:39:13] VOIX: ça se code aussi
-[00:39:13] VOIX: l'ilus,
-[00:39:14] VOIX: mais tu vois,
-[00:39:15] VOIX: il n'y a pas
-[00:39:15] VOIX: non douloureuse
-[00:39:16] VOIX: et pas...
-[00:39:18] VOIX: Ça,
-[00:39:18] VOIX: c'est tout ça,
-[00:39:19] VOIX: c'est important.
-[00:39:20] VOIX: L'essai agent initialement
-[00:39:22] VOIX: parce qu'il pense
-[00:39:23] VOIX: l'opérer.
-[00:39:23] VOIX: L'opérer,
-[00:39:25] VOIX: évolution rapidement favorable,
-[00:39:27] VOIX: permettant une reprise
-[00:39:28] VOIX: alimentaire progressive
-[00:39:29] VOIX: et arrêt de la PCA de morphine.
-[00:39:32] VOIX: Le TDM,
-[00:39:33] VOIX: c'est le scanner à J3,
-[00:39:34] VOIX: donc il est resté hospitalisé
-[00:39:36] VOIX: trois jours.
-[00:39:38] VOIX: Absence...
-[00:39:38] VOIX: Ah oui,
-[00:39:38] VOIX: pardon.
-[00:39:40] VOIX: La phrase là,
-[00:39:41] VOIX: juste au-dessus,
-[00:39:42] VOIX: qu'on a dit
-[00:39:42] VOIX: « transfert dans le service
-[00:39:44] VOIX: du gastron entero »,
-[00:39:45] VOIX: c'est plus en faveur
-[00:39:47] VOIX: que c'était
-[00:39:47] VOIX: les urgences
-[00:39:48] VOIX: de l'hôpital.
-[00:39:49] VOIX: Tu vois,
-[00:39:49] VOIX: ils n'ont pas dit
-[00:39:50] VOIX: « transfert à l'hôpital »
-[00:39:51] VOIX: ou ils ont dit
-[00:39:52] VOIX: « dans le service »,
-[00:39:53] VOIX: donc à mon avis,
-[00:39:53] VOIX: chez eux.
-[00:39:55] VOIX: Donc,
-[00:39:56] VOIX: il faut un scanner
-[00:39:57] VOIX: le troisième jour
-[00:39:58] VOIX: d'hospitalisation.
-[00:40:00] VOIX: Absence
-[00:40:01] VOIX: de signes de gravité,
-[00:40:02] VOIX: pas d'anomalie significative
-[00:40:03] VOIX: de la glande pancréatique,
-[00:40:05] VOIX: ni d'infiltration,
-[00:40:06] VOIX: pas de coulée,
-[00:40:07] VOIX: pas d'argument
-[00:40:08] VOIX: pour une complication
-[00:40:10] VOIX: vasculaire,
-[00:40:10] VOIX: pas de pseudo-anévrisme
-[00:40:12] VOIX: décelé,
-[00:40:13] VOIX: minime infiltration
-[00:40:14] VOIX: péri-vésiculaire
-[00:40:15] VOIX: évocatrice
-[00:40:16] VOIX: de pop-at-signe.
-[00:40:18] VOIX: Ça,
-[00:40:19] VOIX: je ne sais pas
-[00:40:19] VOIX: ce que c'est,
-[00:40:21] VOIX: mais sans
-[00:40:21] VOIX: hydrocholéciste
-[00:40:22] VOIX: le jour,
-[00:40:23] VOIX: score de Balthazar
-[00:40:24] VOIX: à zéro.
-[00:40:25] VOIX: Le score de Balthazar,
-[00:40:26] VOIX: c'est pour évaluer
-[00:40:28] VOIX: la gravité
-[00:40:29] VOIX: de la pancréatite.
-[00:40:30] VOIX: OK.
-[00:40:32] VOIX: Donc,
-[00:40:33] VOIX: ils organisent
-[00:40:34] VOIX: une cholécistectomie
-[00:40:36] VOIX: par cellioscopie
-[00:40:37] VOIX: le 1er mars
-[00:40:38] VOIX: par docteur,
-[00:40:39] VOIX: voilà.
-[00:40:41] VOIX: Colangio retrouve
-[00:40:42] VOIX: une lithiase
-[00:40:43] VOIX: du bacolédoc,
-[00:40:44] VOIX: petite taille,
-[00:40:46] VOIX: opacification initiale
-[00:40:47] VOIX: du diodénum.
-[00:40:48] VOIX: après plusieurs injections
-[00:40:49] VOIX: de produits
-[00:40:50] VOIX: de contraste,
-[00:40:50] VOIX: on arrive finalement
-[00:40:51] VOIX: à pousser le calcul
-[00:40:53] VOIX: qui réussit
-[00:40:54] VOIX: à franchir
-[00:40:55] VOIX: la papille,
-[00:40:55] VOIX: papille du diodénum,
-[00:40:57] VOIX: aux pacifications
-[00:40:58] VOIX: franches
-[00:40:58] VOIX: du diodénum
-[00:40:59] VOIX: ou des courses
-[00:41:00] VOIX: en lithiase
-[00:41:01] VOIX: résiduelle
-[00:41:01] VOIX: de la VBP,
-[00:41:04] VOIX: c'est la voie
-[00:41:04] VOIX: biliaire principale.
-[00:41:05] VOIX: Ça,
-[00:41:06] VOIX: c'est des petites
-[00:41:08] VOIX: abréviations
-[00:41:08] VOIX: que petit à petit
-[00:41:09] VOIX: on pourra,
-[00:41:11] VOIX: tu vois,
-[00:41:12] VOIX: à un moment donné,
-[00:41:12] VOIX: tu peux sortir
-[00:41:13] VOIX: un peu toute la liste
-[00:41:14] VOIX: des abréviations
-[00:41:15] VOIX: et on essaiera
-[00:41:16] VOIX: de leur donner
-[00:41:18] VOIX: du sens.
-[00:41:19] VOIX: Et pas de gramme
-[00:41:20] VOIX: complet,
-[00:41:20] VOIX: pas de cathéter
-[00:41:21] VOIX: nécessaire donc,
-[00:41:23] VOIX: puisque en fait,
-[00:41:24] VOIX: avec le produit
-[00:41:25] VOIX: de contraste,
-[00:41:26] VOIX: ils ont réussi
-[00:41:26] VOIX: à pousser
-[00:41:28] VOIX: la lithiase
-[00:41:29] VOIX: et donc du coup,
-[00:41:30] VOIX: ils ont libéré
-[00:41:31] VOIX: la voie.
-[00:41:31] VOIX: C'est comme si la lithiase
-[00:41:33] VOIX: qui a bougé
-[00:41:33] VOIX: au fait de sa place
-[00:41:34] VOIX: et c'est ça
-[00:41:35] VOIX: qui a provoqué
-[00:41:36] VOIX: tout ça.
-[00:41:37] VOIX: Donc,
-[00:41:39] VOIX: fermeture,
-[00:41:40] VOIX: colle,
-[00:41:41] VOIX: ça c'est,
-[00:41:41] VOIX: tu vois,
-[00:41:42] VOIX: c'est un peu,
-[00:41:42] VOIX: ça c'est un peu
-[00:41:44] VOIX: un semblant
-[00:41:45] VOIX: de cro,
-[00:41:47] VOIX: de courrier opératoire.
-[00:41:49] VOIX: Et en fait,
-[00:41:50] VOIX: ça,
-[00:41:51] VOIX: quand il te dit
-[00:41:52] VOIX: cholecystectomie
-[00:41:53] VOIX: par cellio,
-[00:41:54] VOIX: cholangiographie,
-[00:41:55] VOIX: etc.,
-[00:41:55] VOIX: tu sais,
-[00:41:56] VOIX: on a touché
-[00:41:57] VOIX: un petit peu
-[00:41:57] VOIX: aux actes CCAM
-[00:41:58] VOIX: parce qu'on parle
-[00:42:00] VOIX: de l'allemand
-[00:42:01] VOIX: que la tire C1-10
-[00:42:02] VOIX: des diagnostics
-[00:42:02] VOIX: mais là,
-[00:42:06] VOIX: on peut raisonner
-[00:42:09] VOIX: d'emblée
-[00:42:10] VOIX: sur les actes
-[00:42:11] VOIX: si tu veux.
-[00:42:11] VOIX: C'est-à-dire,
-[00:42:12] VOIX: tu vérifies
-[00:42:13] VOIX: que le codage
-[00:42:13] VOIX: du scanner
-[00:42:14] VOIX: il est fait,
-[00:42:15] VOIX: que la colle
-[00:42:16] VOIX: citectomie
-[00:42:16] VOIX: est faite,
-[00:42:17] VOIX: la collangiologie
-[00:42:18] VOIX: est faite,
-[00:42:19] VOIX: etc.
-[00:42:19] VOIX: Tu vois,
-[00:42:20] VOIX: parce que tu introduis
-[00:42:21] VOIX: la logique globale
-[00:42:22] VOIX: du dossier
-[00:42:23] VOIX: avec les diagnostics
-[00:42:24] VOIX: et les actes.
-[00:42:26] VOIX: Donc ça,
-[00:42:27] VOIX: ça peut être intéressant.
-[00:42:29] VOIX: Donc,
-[00:42:30] VOIX: il fait la colle,
-[00:42:31] VOIX: OK,
-[00:42:34] VOIX: alimentaire bien toléré,
-[00:42:35] VOIX: bilan hépatique
-[00:42:36] VOIX: à la sortie.
-[00:42:37] VOIX: Donc,
-[00:42:37] VOIX: tu vois le bilan hépatique,
-[00:42:39] VOIX: tous les bilans
-[00:42:39] VOIX: qu'ils ont fait
-[00:42:40] VOIX: à l'entrée,
-[00:42:40] VOIX: comment ça s'est amélioré.
-[00:42:42] VOIX: Éruption,
-[00:42:43] VOIX: elle fait une complication.
-[00:42:47] VOIX: Éruption cutanée,
-[00:42:48] VOIX: érythémateuse,
-[00:42:49] VOIX: prédominant sur le tronc,
-[00:42:50] VOIX: apparu le matin
-[00:42:51] VOIX: de la sortie.
-[00:42:53] VOIX: Nantes origineuses,
-[00:42:54] VOIX: c'est-à-dire,
-[00:42:54] VOIX: elle ne gratte pas
-[00:42:55] VOIX: dans un contexte
-[00:42:56] VOIX: de prise de contre-armale
-[00:42:58] VOIX: la veille.
-[00:42:58] VOIX: Donc,
-[00:42:58] VOIX: elle a fait probablement
-[00:43:00] VOIX: une réaction médicamenteuse.
-[00:43:03] VOIX: Introduction de cétérisine
-[00:43:05] VOIX: qui est un antihistaminique.
-[00:43:07] VOIX: Au total,
-[00:43:09] VOIX: pancréasite aiguille,
-[00:43:10] VOIX: thiasique,
-[00:43:10] VOIX: sensible de gravité,
-[00:43:11] VOIX: polycystectomisé.
-[00:43:13] VOIX: Donc,
-[00:43:13] VOIX: avec ces éléments-là,
-[00:43:16] VOIX: on peut déjà proposer
-[00:43:19] VOIX: un codage
-[00:43:20] VOIX: et je t'explique
-[00:43:20] VOIX: sur quoi
-[00:43:21] VOIX: je vais me baser
-[00:43:23] VOIX: pour coder ce dossier.
-[00:43:25] VOIX: En fait,
-[00:43:28] VOIX: on peut partager mon écran.
-[00:43:31] VOIX: Attends,
-[00:43:31] VOIX: je vais ouvrir le fichier
-[00:43:32] VOIX: que je voulais ouvrir.
-[00:43:36] VOIX: Le guide méthodologique.
-[00:43:40] VOIX: Attends,
-[00:43:40] VOIX: je te reprends.
-[00:43:42] VOIX: C'est moi qui partage.
-[00:43:44] VOIX: Vas-y, vas-y.
-[00:43:44] VOIX: Je fais une pause
-[00:43:45] VOIX: et on revient
-[00:43:46] VOIX: à ton dossier après.
-[00:43:48] VOIX: D'accord ?
-[00:43:48] VOIX: Vas-y, vas-y.
-[00:43:50] VOIX: OK.
-[00:43:50] VOIX: Alors,
-[00:43:52] VOIX: le guide méthodologique,
-[00:43:54] VOIX: ce document,
-[00:43:54] VOIX: je pense que je te l'avais envoyé.
-[00:43:56] VOIX: Je l'ai.
-[00:43:57] VOIX: Il est mis à jour
-[00:43:58] VOIX: tous les ans,
-[00:43:59] VOIX: tu vois,
-[00:44:00] VOIX: décembre 2025,
-[00:44:01] VOIX: parce que c'est celui
-[00:44:01] VOIX: qui s'applique en 2026.
-[00:44:03] VOIX: Ce qui nous intéresse
-[00:44:05] VOIX: dans
-[00:44:08] VOIX: ce fichier,
-[00:44:10] VOIX: c'était
-[00:44:12] VOIX: à partir
-[00:44:13] VOIX: du chapitre 5.
-[00:44:16] VOIX: Non,
-[00:44:16] VOIX: hiérarchisation,
-[00:44:17] VOIX: à partir de là.
-[00:44:19] VOIX: Hiérarchisation,
-[00:44:21] VOIX: et codage
-[00:44:22] VOIX: des informations médicales
-[00:44:23] VOIX: du résumé
-[00:44:23] VOIX: d'unité médicale.
-[00:44:25] VOIX: On va aller là.
-[00:44:28] VOIX: Et en fait,
-[00:44:29] VOIX: attends,
-[00:44:31] VOIX: attends,
-[00:44:33] VOIX: non,
-[00:44:33] VOIX: je reviens avant
-[00:44:35] VOIX: parce que ça,
-[00:44:36] VOIX: ça va être
-[00:44:36] VOIX: beaucoup de blabla.
-[00:44:38] VOIX: Je vais te montrer.
-[00:44:39] VOIX: Le sommaire,
-[00:44:39] VOIX: tu vois déjà,
-[00:44:40] VOIX: on va réussir
-[00:44:41] VOIX: à expliquer un peu
-[00:44:44] VOIX: parce que tu as
-[00:44:44] VOIX: les actes
-[00:44:46] VOIX: quand on dit
-[00:44:46] VOIX: que les cystectomies,
-[00:44:47] VOIX: c'est les actes.
-[00:44:49] VOIX: Et
-[00:44:51] VOIX: consigne de codage,
-[00:44:52] VOIX: les situations cliniques.
-[00:44:54] VOIX: Voilà.
-[00:44:56] VOIX: En fait,
-[00:44:57] VOIX: comme tout à l'heure,
-[00:44:58] VOIX: je t'ai dit,
-[00:44:59] VOIX: le patient,
-[00:44:59] VOIX: il est rentré
-[00:45:00] VOIX: pour un symptôme
-[00:45:01] VOIX: qui est la douleur.
-[00:45:03] VOIX: Et le séjour,
-[00:45:05] VOIX: je vais,
-[00:45:05] VOIX: soit son objectif,
-[00:45:07] VOIX: ça sera faire un bilan
-[00:45:08] VOIX: pour poser le diagnostic,
-[00:45:12] VOIX: soit traiter
-[00:45:13] VOIX: et ou traiter.
-[00:45:16] VOIX: Là,
-[00:45:16] VOIX: il y a eu les deux.
-[00:45:17] VOIX: D'accord ?
-[00:45:18] VOIX: Puisqu'ils ont posé
-[00:45:21] VOIX: le diagnostic
-[00:45:21] VOIX: de pancréatite
-[00:45:22] VOIX: et non seulement
-[00:45:23] VOIX: ils ont posé
-[00:45:24] VOIX: le diagnostic,
-[00:45:25] VOIX: mais ils ont mis
-[00:45:26] VOIX: en place
-[00:45:26] VOIX: un traitement chirurgical.
-[00:45:28] VOIX: Et le traitement
-[00:45:30] VOIX: chirurgical,
-[00:45:31] VOIX: c'est ce qu'on appelle
-[00:45:32] VOIX: le traitement unique.
-[00:45:35] VOIX: Et si je clique
-[00:45:36] VOIX: ici,
-[00:45:38] VOIX: la règle
-[00:45:38] VOIX: que je vais utiliser,
-[00:45:40] VOIX: c'est celle-ci.
-[00:45:41] VOIX: Le traitement unique
-[00:45:42] VOIX: chirurgical.
-[00:45:43] VOIX: Dans la situation
-[00:45:44] VOIX: de traitement unique
-[00:45:45] VOIX: chirurgical,
-[00:45:46] VOIX: le DP
-[00:45:47] VOIX: et en général
-[00:45:48] VOIX: la maladie opérée.
-[00:45:49] VOIX: Le DP,
-[00:45:50] VOIX: c'est le diagnostic
-[00:45:50] VOIX: principal.
-[00:45:51] VOIX: Donc là,
-[00:45:52] VOIX: cette règle,
-[00:45:53] VOIX: il se dit,
-[00:45:55] VOIX: si le patient
-[00:45:57] VOIX: a été opéré,
-[00:45:58] VOIX: le diagnostic
-[00:45:59] VOIX: principal
-[00:46:00] VOIX: et la maladie
-[00:46:01] VOIX: opérée.
-[00:46:02] VOIX: D'accord.
-[00:46:05] VOIX: Donc,
-[00:46:06] VOIX: le patient,
-[00:46:07] VOIX: d'ailleurs ça,
-[00:46:08] VOIX: ce type de dossier,
-[00:46:09] VOIX: il nous a posé,
-[00:46:10] VOIX: en fait,
-[00:46:10] VOIX: il est censé
-[00:46:11] VOIX: être simple,
-[00:46:12] VOIX: mais il peut
-[00:46:13] VOIX: poser certains problèmes.
-[00:46:15] VOIX: Je vais
-[00:46:16] VOIX: t'expliquer
-[00:46:17] VOIX: le problème.
-[00:46:19] VOIX: En fait,
-[00:46:20] VOIX: l'acte chirurgical,
-[00:46:21] VOIX: c'était quoi ?
-[00:46:25] VOIX: Ils ont enlevé
-[00:46:26] VOIX: la cholecystite.
-[00:46:28] VOIX: Alors,
-[00:46:28] VOIX: attends,
-[00:46:28] VOIX: tu veux que je le reprenne
-[00:46:29] VOIX: le truc que je te dise ?
-[00:46:32] VOIX: Oui,
-[00:46:32] VOIX: moi je sais,
-[00:46:33] VOIX: je veux juste
-[00:46:35] VOIX: vérifier que tu suis
-[00:46:36] VOIX: la logique.
-[00:46:37] VOIX: Non,
-[00:46:37] VOIX: mais je suis la logique,
-[00:46:39] VOIX: mais les termes
-[00:46:40] VOIX: et tout ça,
-[00:46:40] VOIX: en fait,
-[00:46:41] VOIX: oui,
-[00:46:41] VOIX: c'est pour enlever
-[00:46:41] VOIX: la cholecystécomie.
-[00:46:43] VOIX: C'est
-[00:46:43] VOIX: la cholecystécomie.
-[00:46:46] VOIX: La cholecystécomie.
-[00:46:46] VOIX: Voilà.
-[00:46:47] VOIX: tout ce qui est cholecyste,
-[00:46:48] VOIX: c'est en lien
-[00:46:49] VOIX: avec la vésicule biliaire.
-[00:46:51] VOIX: Cholecystéctomie,
-[00:46:52] VOIX: ça veut dire
-[00:46:52] VOIX: qu'ils ont enlevé
-[00:46:54] VOIX: la vésicule biliaire.
-[00:46:55] VOIX: Et donc là,
-[00:46:56] VOIX: ça va être
-[00:46:57] VOIX: la cholecystéctomie,
-[00:46:58] VOIX: parcellioscopie,
-[00:46:59] VOIX: etc.
-[00:47:01] VOIX: Et pourquoi
-[00:47:01] VOIX: ils ont enlevé
-[00:47:03] VOIX: la vésicule biliaire ?
-[00:47:05] VOIX: En fait,
-[00:47:07] VOIX: soit on va dire
-[00:47:08] VOIX: qu'ils l'ont enlevé
-[00:47:09] VOIX: parce qu'il y avait
-[00:47:10] VOIX: une lithiase
-[00:47:11] VOIX: au niveau
-[00:47:12] VOIX: de la vésicule biliaire,
-[00:47:14] VOIX: soit
-[00:47:14] VOIX: tu te fies
-[00:47:16] VOIX: à ce qu'il a écrit
-[00:47:17] VOIX: le chirurgien
-[00:47:18] VOIX: et tu dis
-[00:47:18] VOIX: pour pancréatite.
-[00:47:20] VOIX: D'accord ?
-[00:47:22] VOIX: Je vais chercher
-[00:47:24] VOIX: quelque chose
-[00:47:25] VOIX: parce que je pense
-[00:47:25] VOIX: que ce type
-[00:47:26] VOIX: de dossier
-[00:47:27] VOIX: nous a posé
-[00:47:28] VOIX: d'énormes problèmes.
-[00:47:31] VOIX: Initial pendant
-[00:47:33] VOIX: le contrôle
-[00:47:34] VOIX: et après
-[00:47:35] VOIX: le contrôle,
-[00:47:38] VOIX: rapport
-[00:47:38] VOIX: du contrôle
-[00:47:39] VOIX: et document
-[00:47:40] VOIX: transmis
-[00:47:41] VOIX: à l'UE 1,
-[00:47:43] VOIX: fiche de concertation.
-[00:47:45] VOIX: Je ne sais plus
-[00:47:46] VOIX: comment il était codé
-[00:47:47] VOIX: le dossier 1.
-[00:47:49] VOIX: Parce que c'est ça
-[00:47:49] VOIX: qui est intéressant.
-[00:47:51] VOIX: Parce que,
-[00:47:51] VOIX: en fait,
-[00:47:52] VOIX: si je code
-[00:47:53] VOIX: pour créatite,
-[00:47:56] VOIX: mon dossier
-[00:47:57] VOIX: va être mieux
-[00:47:57] VOIX: valorisé
-[00:47:58] VOIX: que si je code
-[00:48:02] VOIX: lithiase
-[00:48:02] VOIX: de la vésicule
-[00:48:04] VOIX: biliaire.
-[00:48:06] VOIX: On a eu
-[00:48:06] VOIX: des dossiers
-[00:48:07] VOIX: pour lesquels
-[00:48:09] VOIX: à l'UE 1.
-[00:48:10] VOIX: pourquoi je...
-[00:48:11] VOIX: Pardon,
-[00:48:11] VOIX: j'arrête
-[00:48:11] VOIX: le partage
-[00:48:12] VOIX: juste pour
-[00:48:14] VOIX: ouvrir...
-[00:48:14] VOIX: J'arrête pas
-[00:48:15] VOIX: à ouvrir
-[00:48:15] VOIX: mon répertoire,
-[00:48:16] VOIX: ça m'énerve.
-[00:48:19] VOIX: Ah !
-[00:48:20] VOIX: C'est pas cool.
-[00:48:22] VOIX: Document
-[00:48:23] VOIX: transmis
-[00:48:23] VOIX: à l'avocat.
-[00:48:25] VOIX: Le répertoire
-[00:48:26] VOIX: évite.
-[00:48:29] VOIX: Fiche médicale.
-[00:48:31] VOIX: Ah oui,
-[00:48:31] VOIX: ça tourne.
-[00:48:32] VOIX: Encore,
-[00:48:32] VOIX: c'est pour ça.
-[00:48:33] VOIX: C'est long.
-[00:48:37] VOIX: En fait,
-[00:48:38] VOIX: ce que je ne sais pas,
-[00:48:39] VOIX: mais je vais essayer
-[00:48:40] VOIX: peut-être de me connecter
-[00:48:41] VOIX: sur l'appli.
-[00:48:42] VOIX: En fait,
-[00:48:43] VOIX: avant de dire
-[00:48:44] VOIX: le codage,
-[00:48:45] VOIX: je t'explique déjà
-[00:48:46] VOIX: la règle que j'utilise.
-[00:48:47] VOIX: Le patient,
-[00:48:48] VOIX: il est opéré.
-[00:48:50] VOIX: La pathologie opérée,
-[00:48:51] VOIX: c'est elle
-[00:48:52] VOIX: qui va être
-[00:48:53] VOIX: le diagnostic principal.
-[00:48:54] VOIX: Déjà,
-[00:48:54] VOIX: ça,
-[00:48:55] VOIX: c'est la règle.
-[00:48:55] VOIX: Et tu as vu
-[00:48:56] VOIX: qu'il y avait
-[00:48:56] VOIX: plusieurs règles
-[00:48:57] VOIX: en fonction
-[00:48:57] VOIX: de pourquoi
-[00:48:58] VOIX: le patient
-[00:48:58] VOIX: est rentré.
-[00:48:59] VOIX: Et en fonction
-[00:49:00] VOIX: des dossiers,
-[00:49:01] VOIX: je pourrais
-[00:49:01] VOIX: t'expliquer
-[00:49:02] VOIX: la règle
-[00:49:03] VOIX: et te dire,
-[00:49:04] VOIX: et c'est ça
-[00:49:04] VOIX: qu'il faut mettre,
-[00:49:05] VOIX: c'est-à-dire,
-[00:49:06] VOIX: dans les argumentaires,
-[00:49:07] VOIX: on peut dire
-[00:49:08] VOIX: selon la règle T3
-[00:49:09] VOIX: de traitement unique
-[00:49:10] VOIX: chirurgical,
-[00:49:11] VOIX: le patient opéré,
-[00:49:13] VOIX: c'est la phrase
-[00:49:13] VOIX: que je t'ai montrée
-[00:49:14] VOIX: dans le guide,
-[00:49:15] VOIX: et je dis,
-[00:49:16] VOIX: c'est comme ça
-[00:49:17] VOIX: qu'on écrit
-[00:49:17] VOIX: nos argumentaires,
-[00:49:19] VOIX: et dire,
-[00:49:19] VOIX: dans le CROS
-[00:49:20] VOIX: ou dans le CRH,
-[00:49:22] VOIX: dans le document
-[00:49:23] VOIX: en tout cas écrit
-[00:49:23] VOIX: par le médecin
-[00:49:24] VOIX: à telle date,
-[00:49:25] VOIX: etc.,
-[00:49:25] VOIX: c'est pour ça
-[00:49:26] VOIX: que c'est bien
-[00:49:26] VOIX: d'identifier
-[00:49:27] VOIX: la nature
-[00:49:28] VOIX: des documents.
-[00:49:30] VOIX: Le médecin
-[00:49:31] VOIX: a écrit
-[00:49:32] VOIX: que le chirurgien
-[00:49:34] VOIX: a écrit
-[00:49:34] VOIX: que le patient
-[00:49:35] VOIX: a été opéré
-[00:49:36] VOIX: par colibistectomie
-[00:49:37] VOIX: pour pancréatite.
-[00:49:39] VOIX: Si je veux défendre
-[00:49:40] VOIX: la pancréatite,
-[00:49:42] VOIX: j'essaie de trouver
-[00:49:44] VOIX: un peu
-[00:49:45] VOIX: parce que
-[00:49:47] VOIX: c'est important
-[00:49:49] VOIX: que je puisse
-[00:49:51] VOIX: te montrer
-[00:49:52] VOIX: et en même temps
-[00:49:54] VOIX: je vais partager
-[00:49:55] VOIX: quand même mon écran
-[00:49:56] VOIX: parce que
-[00:49:56] VOIX: tu vas voir
-[00:49:59] VOIX: et bien
-[00:50:02] VOIX: je vais partager
-[00:50:04] VOIX: parce que
-[00:50:04] VOIX: je suis en train
-[00:50:06] VOIX: de me connecter
-[00:50:08] VOIX: à l'appli
-[00:50:09] VOIX: et
-[00:50:15] VOIX: on va voir
-[00:50:18] VOIX: est-ce que ça va marcher ?
-[00:50:20] VOIX: mon téléphone
-[00:50:22] VOIX: mais sécurisé
-[00:50:24] VOIX: ça fait plaisir
-[00:50:34] VOIX: 742
-[00:50:36] VOIX: 743
-[00:50:38] VOIX: oui c'est
-[00:50:39] VOIX: authenticateur
-[00:50:42] VOIX: donc ça c'est
-[00:50:44] VOIX: le contrôle
-[00:50:44] VOIX: duquel on parle
-[00:50:45] VOIX: et on va voir
-[00:50:50] VOIX: suivre
-[00:50:51] VOIX: ce serait plus simple
-[00:50:52] VOIX: et soit
-[00:50:54] VOIX: je mets
-[00:50:54] VOIX: 1
-[00:50:55] VOIX: c'était
-[00:50:56] VOIX: l'OGC1
-[00:50:57] VOIX: oui c'est ça
-[00:50:57] VOIX: c'est ça
-[00:51:14] VOIX: comment j'ai pile poil
-[00:51:15] VOIX: en face
-[00:51:16] VOIX: qu'il me faut
-[00:51:16] VOIX: attends je prends
-[00:51:17] VOIX: une photo
-[00:51:17] VOIX: d'accord
-[00:51:19] VOIX: oui mais il y a
-[00:51:20] VOIX: une erreur
-[00:51:21] VOIX: que je vois
-[00:51:23] VOIX: et c'est
-[00:51:24] VOIX: c'est quoi
-[00:51:24] VOIX: l'erreur
-[00:51:27] VOIX: en haut
-[00:51:28] VOIX: c'est marqué
-[00:51:30] VOIX: K802
-[00:51:30] VOIX: ouais
-[00:51:32] VOIX: en bas
-[00:51:33] VOIX: c'est marqué
-[00:51:34] VOIX: K851
-[00:51:36] VOIX: donc du coup
-[00:51:37] VOIX: moi je vais
-[00:51:37] VOIX: faire une copie
-[00:51:38] VOIX: d'écran
-[00:51:41] VOIX: et je vais
-[00:51:41] VOIX: l'envoyer
-[00:51:42] VOIX: à Jordan
-[00:51:50] VOIX: est-ce que
-[00:51:50] VOIX: je sais
-[00:51:51] VOIX: qu'on a eu
-[00:51:51] VOIX: un problème
-[00:51:52] VOIX: sur
-[00:51:52] VOIX: donc du coup
-[00:51:53] VOIX: je ne sais plus
-[00:51:54] VOIX: quel est le codage
-[00:51:55] VOIX: initial
-[00:51:55] VOIX: de l'établissement
-[00:52:02] VOIX: c'est ballon
-[00:52:08] VOIX: tu vois
-[00:52:08] VOIX: on a
-[00:52:09] VOIX: ce type
-[00:52:09] VOIX: de problème
-[00:52:10] VOIX: ouais
-[00:52:19] VOIX: donc
-[00:52:23] VOIX: K802
-[00:52:25] VOIX: en haut
-[00:52:26] VOIX: et K
-[00:52:29] VOIX: 851
-[00:52:30] VOIX: en bas
-[00:52:47] VOIX: et tous
-[00:52:48] VOIX: nos argumentaires
-[00:52:49] VOIX: en fait
-[00:52:50] VOIX: on les fait
-[00:52:51] VOIX: avec l'OGC
-[00:52:52] VOIX: parce que c'est
-[00:52:53] VOIX: ce qui est censé
-[00:52:54] VOIX: être anonyme
-[00:52:55] VOIX: voilà
-[00:52:57] VOIX: donc du coup
-[00:52:58] VOIX: si je reviens
-[00:52:59] VOIX: à l'appli
-[00:53:00] VOIX: la fiche
-[00:53:01] VOIX: où j'essaye
-[00:53:02] VOIX: c'est ça
-[00:53:03] VOIX: et tu vois
-[00:53:04] VOIX: quand je t'ai dit
-[00:53:04] VOIX: tout à l'heure
-[00:53:05] VOIX: la préparation
-[00:53:10] VOIX: ça c'est
-[00:53:15] VOIX: tu m'entends
-[00:53:16] VOIX: oui
-[00:53:16] VOIX: oui voilà
-[00:53:18] VOIX: en fait
-[00:53:19] VOIX: là ça c'est
-[00:53:20] VOIX: la préparation
-[00:53:21] VOIX: ça c'est le travail
-[00:53:22] VOIX: du team
-[00:53:23] VOIX: oui
-[00:53:24] VOIX: d'accord
-[00:53:25] VOIX: team
-[00:53:25] VOIX: du team
-[00:53:26] VOIX: c'est tout le monde
-[00:53:27] VOIX: il passe
-[00:53:28] VOIX: d'accord
-[00:53:30] VOIX: ici sur ton programme
-[00:53:31] VOIX: en fait
-[00:53:31] VOIX: le contrôleur
-[00:53:32] VOIX: en fait
-[00:53:32] VOIX: c'est quand tu as
-[00:53:34] VOIX: un contrôle
-[00:53:35] VOIX: tes deux ans
-[00:53:36] VOIX: ou c'est quelque chose
-[00:53:37] VOIX: de plus
-[00:53:37] VOIX: ça c'est le retour
-[00:53:39] VOIX: ça c'est les fiches
-[00:53:40] VOIX: du médecin contrôleur
-[00:53:41] VOIX: c'est le retour
-[00:53:43] VOIX: du médecin contrôleur
-[00:53:44] VOIX: ouais
-[00:53:44] VOIX: il te laisse un peu
-[00:53:46] VOIX: de temps
-[00:53:47] VOIX: pour revoir
-[00:53:49] VOIX: c'est là
-[00:53:49] VOIX: où l'argumentaire
-[00:53:51] VOIX: de
-[00:53:51] VOIX: parce que là
-[00:53:53] VOIX: tu fais tes argumentaires
-[00:53:54] VOIX: tu prépares ton podage
-[00:53:57] VOIX: je vais te montrer
-[00:53:58] VOIX: un autre dossier
-[00:53:59] VOIX: le dossier
-[00:54:00] VOIX: qu'on met
-[00:54:00] VOIX: dans
-[00:54:02] VOIX: dans la démo
-[00:54:03] VOIX: c'est le
-[00:54:03] VOIX: 399
-[00:54:04] VOIX: il me semble
-[00:54:05] VOIX: parce qu'il est plus
-[00:54:08] VOIX: par exemple
-[00:54:08] VOIX: l'établissement
-[00:54:09] VOIX: il a préparé
-[00:54:11] VOIX: il a dit
-[00:54:12] VOIX: ce code là
-[00:54:13] VOIX: je ne vais pas
-[00:54:15] VOIX: pouvoir le défendre
-[00:54:16] VOIX: d'accord
-[00:54:18] VOIX: et
-[00:54:20] VOIX: le médecin contrôleur
-[00:54:22] VOIX: quand il a fait
-[00:54:23] VOIX: sa fiche
-[00:54:24] VOIX: il a dit
-[00:54:25] VOIX: non seulement
-[00:54:26] VOIX: je ne retiens pas
-[00:54:27] VOIX: ce code
-[00:54:28] VOIX: mais en plus
-[00:54:29] VOIX: je vous modifie
-[00:54:30] VOIX: le DP
-[00:54:31] VOIX: ok
-[00:54:32] VOIX: tu vois
-[00:54:33] VOIX: et donc là
-[00:54:34] VOIX: nous
-[00:54:35] VOIX: on avait préparé
-[00:54:36] VOIX: à
-[00:54:37] VOIX: à perdre
-[00:54:38] VOIX: 1500
-[00:54:39] VOIX: lui
-[00:54:40] VOIX: il nous a fait
-[00:54:41] VOIX: un codage
-[00:54:42] VOIX: qui nous fait perdre
-[00:54:42] VOIX: presque 4000 euros
-[00:54:44] VOIX: quand j'ai vu
-[00:54:46] VOIX: son truc
-[00:54:47] VOIX: parce que
-[00:54:48] VOIX: dans la prépa
-[00:54:49] VOIX: on ne s'est pas
-[00:54:50] VOIX: attendu
-[00:54:50] VOIX: à ce qu'il nous
-[00:54:51] VOIX: change ça
-[00:54:53] VOIX: donc on n'a pas
-[00:54:54] VOIX: argumenté
-[00:54:55] VOIX: sur ce changement
-[00:54:56] VOIX: de DP
-[00:54:58] VOIX: on a argumenté
-[00:55:00] VOIX: sur ça
-[00:55:00] VOIX: peut-être
-[00:55:01] VOIX: on l'avait préparé
-[00:55:02] VOIX: mais donc
-[00:55:02] VOIX: je suis obligé
-[00:55:03] VOIX: de avant
-[00:55:04] VOIX: la concertation
-[00:55:05] VOIX: la concertation
-[00:55:06] VOIX: c'est la clôture
-[00:55:08] VOIX: finale
-[00:55:09] VOIX: du dossier
-[00:55:10] VOIX: avant la clôture
-[00:55:11] VOIX: finale
-[00:55:12] VOIX: du dossier
-[00:55:12] VOIX: il
-[00:55:15] VOIX: c'est comme
-[00:55:16] VOIX: si là
-[00:55:16] VOIX: moi
-[00:55:16] VOIX: le dim
-[00:55:17] VOIX: je fais
-[00:55:17] VOIX: une contre-proposition
-[00:55:19] VOIX: je ne me rappelle
-[00:55:20] VOIX: plus
-[00:55:20] VOIX: de l'argument
-[00:55:21] VOIX: mais tu vois
-[00:55:23] VOIX: par exemple
-[00:55:23] VOIX: au début
-[00:55:25] VOIX: ça c'est un bug
-[00:55:28] VOIX: qu'on a
-[00:55:29] VOIX: tu vois
-[00:55:29] VOIX: tous les arguments
-[00:55:31] VOIX: qu'on écrit
-[00:55:31] VOIX: le médecin
-[00:55:32] VOIX: il a fait ça
-[00:55:33] VOIX: il a écrit ça
-[00:55:34] VOIX: bla bla bla
-[00:55:39] VOIX: ah j'ai un bug
-[00:55:41] VOIX: j'ai un bug
-[00:55:42] VOIX: dans l'appli
-[00:55:43] VOIX: depuis hier
-[00:55:45] VOIX: ça je l'ai dit
-[00:55:46] VOIX: à Jorban
-[00:55:46] VOIX: je ne pourrais pas
-[00:55:47] VOIX: te montrer correctement
-[00:55:48] VOIX: c'est pas grave
-[00:55:49] VOIX: ça s'affiche
-[00:55:50] VOIX: tu vois
-[00:55:51] VOIX: en fait
-[00:55:53] VOIX: nous
-[00:55:53] VOIX: on avait
-[00:55:54] VOIX: on avait
-[00:55:55] VOIX: enlevé
-[00:55:56] VOIX: le diagnostic associé
-[00:55:57] VOIX: mais on n'avait
-[00:55:58] VOIX: pas préparé ça
-[00:55:59] VOIX: ça
-[00:56:00] VOIX: je l'ai fait
-[00:56:00] VOIX: a posteriori
-[00:56:01] VOIX: pour dire
-[00:56:02] VOIX: parce que je ne savais pas
-[00:56:04] VOIX: qu'il allait modifier
-[00:56:04] VOIX: le DP
-[00:56:06] VOIX: mais si mon argumentaire
-[00:56:08] VOIX: au début
-[00:56:08] VOIX: initial
-[00:56:09] VOIX: il est fait par l'IA
-[00:56:11] VOIX: il évoque aussi
-[00:56:13] VOIX: ce risque
-[00:56:13] VOIX: il pourra me préparer
-[00:56:15] VOIX: à dire
-[00:56:16] VOIX: attention
-[00:56:16] VOIX: le DP peut être
-[00:56:18] VOIX: modifié
-[00:56:19] VOIX: pourquoi
-[00:56:19] VOIX: etc
-[00:56:20] VOIX: voici votre argumentaire
-[00:56:22] VOIX: pour le défendre
-[00:56:23] VOIX: par exemple
-[00:56:24] VOIX: tu vois
-[00:56:25] VOIX: et lui
-[00:56:26] VOIX: le contrôleur
-[00:56:27] VOIX: il m'a fait sauter
-[00:56:30] VOIX: il fait son argumentaire
-[00:56:32] VOIX: pour dire
-[00:56:32] VOIX: non mais moi
-[00:56:33] VOIX: je garde
-[00:56:34] VOIX: la règle
-[00:56:34] VOIX: SA
-[00:56:35] VOIX: le DP
-[00:56:36] VOIX: c'est ça
-[00:56:37] VOIX: ça c'est un cas
-[00:56:38] VOIX: un petit peu plus
-[00:56:39] VOIX: complexe
-[00:56:39] VOIX: que le cas
-[00:56:40] VOIX: de la chirurgie
-[00:56:41] VOIX: donc
-[00:56:42] VOIX: mais ça
-[00:56:42] VOIX: on le travaillera aussi
-[00:56:43] VOIX: parce que
-[00:56:44] VOIX: c'est des cas
-[00:56:45] VOIX: qui existent en vrai
-[00:56:47] VOIX: c'est un jeu
-[00:56:49] VOIX: avec ça
-[00:56:49] VOIX: c'est à dire que
-[00:56:51] VOIX: lui il peut dire
-[00:56:52] VOIX: j'utilise
-[00:56:53] VOIX: la règle
-[00:56:53] VOIX: T3
-[00:56:54] VOIX: et moi
-[00:56:55] VOIX: je lui dis
-[00:56:55] VOIX: ben non
-[00:56:56] VOIX: j'utilise la règle
-[00:56:57] VOIX: SA
-[00:56:58] VOIX: oui oui
-[00:56:59] VOIX: j'ai saisi
-[00:57:01] VOIX: tu vois
-[00:57:01] VOIX: c'est ça
-[00:57:02] VOIX: les contre-argumentaires
-[00:57:04] VOIX: c'est à dire que
-[00:57:04] VOIX: parfois nous
-[00:57:05] VOIX: dans les formations
-[00:57:06] VOIX: on fait des synthèses
-[00:57:07] VOIX: de ces règles
-[00:57:08] VOIX: tu vois
-[00:57:09] VOIX: t'as des exemples
-[00:57:10] VOIX: t'as etc
-[00:57:11] VOIX: et c'est vrai
-[00:57:12] VOIX: qu'à la lecture
-[00:57:13] VOIX: c'est pas l'impi
-[00:57:14] VOIX: t'as des dossiers
-[00:57:16] VOIX: qui sont clairs
-[00:57:17] VOIX: mais typiquement
-[00:57:18] VOIX: celui-là
-[00:57:20] VOIX: il me dit
-[00:57:21] VOIX: le patient
-[00:57:21] VOIX: il vient pour
-[00:57:22] VOIX: semaine d'éducation
-[00:57:25] VOIX: anti-diabétique
-[00:57:26] VOIX: au reau
-[00:57:27] VOIX: il est traité
-[00:57:27] VOIX: machin
-[00:57:28] VOIX: bon on va pas
-[00:57:28] VOIX: tout lire
-[00:57:29] VOIX: mais il me dit
-[00:57:30] VOIX: pas de modification
-[00:57:31] VOIX: erreur du DP
-[00:57:32] VOIX: tu vois
-[00:57:33] VOIX: pas de modification
-[00:57:34] VOIX: du schéma thérapeutique
-[00:57:36] VOIX: non justifié
-[00:57:37] VOIX: il s'agit d'un cas
-[00:57:38] VOIX: de surveillance négative
-[00:57:39] VOIX: le séjour
-[00:57:40] VOIX: on n'a pas mis
-[00:57:41] VOIX: en évidence
-[00:57:42] VOIX: l'affection nouvelle
-[00:57:43] VOIX: donc je code
-[00:57:44] VOIX: le Z92
-[00:57:45] VOIX: et t'as vu
-[00:57:46] VOIX: le Z092
-[00:57:49] VOIX: il me dévalorise
-[00:57:50] VOIX: de 4000 euros
-[00:57:51] VOIX: et moi
-[00:57:52] VOIX: j'ai essayé
-[00:57:53] VOIX: à défaut
-[00:57:54] VOIX: de pouvoir
-[00:57:55] VOIX: garder
-[00:57:55] VOIX: mon premier DP
-[00:57:57] VOIX: de proposer
-[00:57:58] VOIX: un autre DP
-[00:57:59] VOIX: qui casse
-[00:58:00] VOIX: moins le dossier
-[00:58:01] VOIX: tu vois
-[00:58:02] VOIX: au moins
-[00:58:02] VOIX: ça casse moi
-[00:58:03] VOIX: au lieu de 4000
-[00:58:04] VOIX: bah t'es à moins
-[00:58:05] VOIX: 1600
-[00:58:05] VOIX: et j'essaie
-[00:58:07] VOIX: essentiellement
-[00:58:08] VOIX: autour de l'éducation
-[00:58:09] VOIX: la réponse
-[00:58:10] VOIX: de docteur
-[00:58:10] VOIX: machin
-[00:58:11] VOIX: sur Agora
-[00:58:12] VOIX: c'est le forum
-[00:58:12] VOIX: de la TIH
-[00:58:13] VOIX: c'était marqué
-[00:58:15] VOIX: qu'on vient bien
-[00:58:15] VOIX: en effet à l'éducation
-[00:58:16] VOIX: thérapeutique
-[00:58:17] VOIX: s'il en tient compte
-[00:58:18] VOIX: de la notion
-[00:58:19] VOIX: de conseil diététique
-[00:58:20] VOIX: de son intitulé
-[00:58:21] VOIX: et je lui casse
-[00:58:22] VOIX: son code
-[00:58:22] VOIX: en disant
-[00:58:23] VOIX: le Z092
-[00:58:24] VOIX: ou situation
-[00:58:25] VOIX: de surveillance négative
-[00:58:27] VOIX: en lui incluant
-[00:58:28] VOIX: les examens médicaux
-[00:58:30] VOIX: de recherche
-[00:58:30] VOIX: notamment
-[00:58:31] VOIX: les complications
-[00:58:31] VOIX: machin
-[00:58:32] VOIX: donc
-[00:58:33] VOIX: l'éducation
-[00:58:34] VOIX: thérapeutique
-[00:58:35] VOIX: est bien été
-[00:58:36] VOIX: dans le dossier
-[00:58:36] VOIX: patient
-[00:58:37] VOIX: avec les différents
-[00:58:37] VOIX: intervenants
-[00:58:39] VOIX: voilà
-[00:58:40] VOIX: donc je défends
-[00:58:41] VOIX: j'essaie de
-[00:58:42] VOIX: de défendre
-[00:58:43] VOIX: un nouveau code
-[00:58:44] VOIX: pour moi
-[00:58:45] VOIX: parce que je n'ai pas
-[00:58:47] VOIX: d'argument
-[00:58:47] VOIX: pour maintenir
-[00:58:48] VOIX: le codage initial
-[00:58:50] VOIX: donc ça n'a pas été
-[00:58:52] VOIX: en tout cas nous
-[00:58:53] VOIX: on n'a pas repéré
-[00:58:54] VOIX: lors de la préparation
-[00:58:56] VOIX: mais comme le médecin contrôleur
-[00:58:58] VOIX: propose quelque chose
-[00:58:59] VOIX: qui cache beaucoup
-[00:59:00] VOIX: le dossier
-[00:59:00] VOIX: nous on va s'adapter
-[00:59:02] VOIX: dans le nouveau argumentaire
-[00:59:04] VOIX: à la décision du contrôleur
-[00:59:06] VOIX: et la concertation
-[00:59:08] VOIX: c'est la décision finale
-[00:59:09] VOIX: est-ce que j'arrive
-[00:59:10] VOIX: à le convaincre
-[00:59:11] VOIX: ou est-ce qu'il garde
-[00:59:12] VOIX: son avis initial
-[00:59:13] VOIX: d'accord
-[00:59:14] VOIX: bon j'ai parfaitement
-[00:59:15] VOIX: en fait
-[00:59:15] VOIX: saisi
-[00:59:16] VOIX: ok
-[00:59:16] VOIX: voilà
-[00:59:18] VOIX: et donc
-[00:59:19] VOIX: pour le dossier 1
-[00:59:24] VOIX: je ne sais pas
-[00:59:25] VOIX: comment
-[00:59:26] VOIX: je vais savoir
-[00:59:28] VOIX: comment il était
-[00:59:30] VOIX: il était codé
-[00:59:32] VOIX: en fait
-[00:59:34] VOIX: mais après
-[00:59:35] VOIX: peu importe
-[00:59:36] VOIX: tu vois
-[00:59:37] VOIX: c'est la logique
-[00:59:37] VOIX: qui compte
-[00:59:38] VOIX: ouais
-[00:59:39] VOIX: ouais
-[00:59:44] VOIX: donc
-[00:59:45] VOIX: justificatif
-[00:59:47] VOIX: retour
-[00:59:48] VOIX: je t'en as
-[00:59:52] VOIX: donc
-[00:59:53] VOIX: est-ce que
-[00:59:54] VOIX: je vais
-[00:59:55] VOIX: donc
-[00:59:59] VOIX: tu vois
-[00:59:59] VOIX: je n'ai pas de fiche
-[01:00:01] VOIX: pour le 8
-[01:00:01] VOIX: donc
-[01:00:02] VOIX: le médecin contrôleur
-[01:00:03] VOIX: il a validé
-[01:00:04] VOIX: le codage initial
-[01:00:05] VOIX: de l'établissement
-[01:00:06] VOIX: mais par exemple
-[01:00:07] VOIX: pour le dossier 8
-[01:00:11] VOIX: nous on avait codé
-[01:00:13] VOIX: on pourra regarder
-[01:00:14] VOIX: le dossier 8
-[01:00:16] VOIX: on a codé
-[01:00:17] VOIX: calcul de la vésicule
-[01:00:18] VOIX: biliaire
-[01:00:18] VOIX: et lui
-[01:00:18] VOIX: il dit
-[01:00:19] VOIX: non
-[01:00:19] VOIX: c'est pas le K800
-[01:00:20] VOIX: c'est le K
-[01:00:23] VOIX: le K805
-[01:00:25] VOIX: et tu vois
-[01:00:26] VOIX: il n'a pas changé
-[01:00:27] VOIX: les actes
-[01:00:28] VOIX: mais le groupage
-[01:00:29] VOIX: tu vois
-[01:00:29] VOIX: ça passe
-[01:00:30] VOIX: de 07713
-[01:00:31] VOIX: A 07714
-[01:00:33] VOIX: et si on regarde
-[01:00:35] VOIX: l'appli
-[01:00:35] VOIX: le dossier 8
-[01:00:40] VOIX: on va voir
-[01:00:42] VOIX: nous
-[01:00:44] VOIX: on avait déjà repéré
-[01:00:45] VOIX: dans la prépa
-[01:00:46] VOIX: que le patient
-[01:00:48] VOIX: que le codage
-[01:00:49] VOIX: n'était pas bon
-[01:00:50] VOIX: parce qu'il n'y avait pas
-[01:00:51] VOIX: de cholycystite aiguë
-[01:00:53] VOIX: donc on avait proposé
-[01:00:54] VOIX: un code
-[01:00:55] VOIX: qui est le 801
-[01:00:56] VOIX: différent de 800
-[01:00:57] VOIX: mais qui faisait perdre
-[01:00:58] VOIX: déjà 800 euros
-[01:01:00] VOIX: le contrôleur
-[01:01:01] VOIX: il a dit
-[01:01:01] VOIX: le 805
-[01:01:02] VOIX: mais c'est pareil
-[01:01:04] VOIX: sa groupe
-[01:01:05] VOIX: pareil
-[01:01:06] VOIX: il dit
-[01:01:07] VOIX: bon moi
-[01:01:07] VOIX: je ne sais pas
-[01:01:07] VOIX: pourquoi
-[01:01:08] VOIX: j'ai essayé de changer
-[01:01:09] VOIX: ça ne change rien du tout
-[01:01:10] VOIX: et que finalement
-[01:01:11] VOIX: on était tous d'accord
-[01:01:12] VOIX: à mon avis
-[01:01:13] VOIX: lui
-[01:01:13] VOIX: il est en accord
-[01:01:14] VOIX: parce que
-[01:01:15] VOIX: on ne peut rien tirer
-[01:01:16] VOIX: nous
-[01:01:17] VOIX: d'emblée
-[01:01:18] VOIX: dans la prépa
-[01:01:19] VOIX: on a dit
-[01:01:20] VOIX: il n'y a pas de choix
-[01:01:21] VOIX: par contre
-[01:01:22] VOIX: on va le lire ensemble
-[01:01:23] VOIX: le dossier 8
-[01:01:24] VOIX: tu vois
-[01:01:25] VOIX: je vais sortir
-[01:01:26] VOIX: pour que tu vois
-[01:01:28] VOIX: parce qu'il ressemble
-[01:01:28] VOIX: un peu
-[01:01:29] VOIX: à ce dossier là
-[01:01:31] VOIX: et on va voir
-[01:01:35] VOIX: ça va ?
-[01:01:36] VOIX: alors ça fait
-[01:01:37] VOIX: ça fait beaucoup d'informations
-[01:01:40] VOIX: mais bon
-[01:01:41] VOIX: tu as répondu
-[01:01:42] VOIX: à une partie de mes questions
-[01:01:43] VOIX: il va falloir
-[01:01:43] VOIX: que je t'en pose
-[01:01:46] VOIX: quelques-unes
-[01:01:46] VOIX: mais pour moi
-[01:01:48] VOIX: en fait
-[01:01:48] VOIX: c'est bon
-[01:01:49] VOIX: en fait
-[01:01:49] VOIX: ce qu'il me faut
-[01:01:50] VOIX: ce qu'il me faudrait
-[01:01:51] VOIX: en fait
-[01:01:51] VOIX: pour faire un truc carré
-[01:01:53] VOIX: pour moi
-[01:01:54] VOIX: je parle en fait
-[01:01:54] VOIX: pour que je puisse avoir
-[01:01:55] VOIX: ma vision
-[01:01:56] VOIX: il me faudrait en fait
-[01:01:57] VOIX: discuter
-[01:01:58] VOIX: en fait
-[01:01:58] VOIX: on prend
-[01:01:59] VOIX: un ou deux
-[01:02:00] VOIX: je t'écoute
-[01:02:01] VOIX: je prends juste mon ordi
-[01:02:02] VOIX: ouais
-[01:02:03] VOIX: c'est de prendre
-[01:02:03] VOIX: c'est de prendre
-[01:02:04] VOIX: un ou deux
-[01:02:06] VOIX: compte rendu
-[01:02:08] VOIX: et de faire
-[01:02:09] VOIX: toute la chaîne
-[01:02:10] VOIX: en fait
-[01:02:11] VOIX: c'est
-[01:02:11] VOIX: voilà
-[01:02:12] VOIX: on l'a codé comme ça
-[01:02:13] VOIX: et puis après
-[01:02:14] VOIX: en fait
-[01:02:15] VOIX: on a un contrôle
-[01:02:16] VOIX: le contrôleur
-[01:02:17] VOIX: il dit ça
-[01:02:18] VOIX: ou ça
-[01:02:18] VOIX: et tu vois
-[01:02:19] VOIX: en fait
-[01:02:19] VOIX: faire toute la chaîne
-[01:02:20] VOIX: sur un ou deux dossiers
-[01:02:22] VOIX: de telle manière
-[01:02:22] VOIX: avoir quelque chose
-[01:02:23] VOIX: en fait
-[01:02:23] VOIX: de probant
-[01:02:25] VOIX: voilà
-[01:02:26] VOIX: alors regarde
-[01:02:27] VOIX: le dossier
-[01:02:28] VOIX: pas le 8
-[01:02:29] VOIX: parce qu'on était d'accord
-[01:02:30] VOIX: de A à Z
-[01:02:33] VOIX: on va
-[01:02:33] VOIX: parce que j'ai mis
-[01:02:35] VOIX: 399
-[01:02:35] VOIX: mais en fait
-[01:02:36] VOIX: c'est le dossier
-[01:02:38] VOIX: 339
-[01:02:38] VOIX: alors je vais pas
-[01:02:40] VOIX: l'avoir
-[01:02:40] VOIX: figure toi
-[01:02:40] VOIX: parce qu'en fait
-[01:02:42] VOIX: je travaille que sur
-[01:02:43] VOIX: 250 dossiers
-[01:02:45] VOIX: et à mon avis
-[01:02:45] VOIX: le 300
-[01:02:47] VOIX: 39
-[01:02:48] VOIX: OGC
-[01:02:49] VOIX: 339
-[01:02:50] VOIX: non
-[01:02:50] VOIX: je l'ai pas
-[01:02:53] VOIX: je l'ai pas
-[01:02:54] VOIX: à la limite
-[01:02:55] VOIX: en fait
-[01:02:55] VOIX: si tu l'as sous la main
-[01:02:56] VOIX: envoie-moi-le
-[01:02:56] VOIX: comme ça
-[01:02:57] VOIX: en fait
-[01:02:57] VOIX: je l'aurai
-[01:02:57] VOIX: ou alors
-[01:02:58] VOIX: je vais le chercher
-[01:02:59] VOIX: parce que
-[01:02:59] VOIX: j'ai été obligé
-[01:03:00] VOIX: d'arrêter
-[01:03:00] VOIX: mon navigateur
-[01:03:01] VOIX: et je peux pas
-[01:03:02] VOIX: aller le chercher
-[01:03:03] VOIX: de suite
-[01:03:04] VOIX: quoi que
-[01:03:05] VOIX: allez
-[01:03:05] VOIX: 50
-[01:03:06] VOIX: en anglais
-[01:03:12] VOIX: où est-ce que
-[01:03:13] VOIX: je t'avais déposé
-[01:03:14] VOIX: tout ça
-[01:03:15] VOIX: dans
-[01:03:17] VOIX: chère fille
-[01:03:17] VOIX: attends
-[01:03:18] VOIX: je vais peut-être
-[01:03:19] VOIX: je vais le trouver
-[01:03:20] VOIX: c'est bon
-[01:03:24] VOIX: ok
-[01:03:28] VOIX: et moi
-[01:03:29] VOIX: j'ai le codage
-[01:03:30] VOIX: dans l'appli
-[01:03:30] VOIX: les argumentaires
-[01:03:31] VOIX: dans l'appli
-[01:03:32] VOIX: donc on le lit
-[01:03:33] VOIX: ensemble
-[01:03:34] VOIX: et on
-[01:03:37] VOIX: regarde l'appli
-[01:03:38] VOIX: après
-[01:03:39] VOIX: ensemble
-[01:03:39] VOIX: allez
-[01:03:40] VOIX: le partage
-[01:03:42] VOIX: document post
-[01:03:44] VOIX: contrôle
-[01:03:45] VOIX: je pense que
-[01:03:46] VOIX: c'était celui
-[01:03:46] VOIX: non
-[01:03:47] VOIX: attends
-[01:03:47] VOIX: ça c'est
-[01:03:47] VOIX: guillière
-[01:03:49] VOIX: hop là
-[01:03:53] VOIX: voilà
-[01:03:53] VOIX: je l'ai
-[01:03:54] VOIX: voilà
-[01:03:55] VOIX: et tu m'as dit
-[01:03:56] VOIX: alors attends
-[01:03:56] VOIX: parce que
-[01:03:57] VOIX: c'est
-[01:03:58] VOIX: j'avais
-[01:03:58] VOIX: trouvé que
-[01:03:59] VOIX: 250 dossiers
-[01:04:01] VOIX: et en fait
-[01:04:01] VOIX: il faut cliquer
-[01:04:02] VOIX: 10 fois
-[01:04:03] VOIX: pour voir
-[01:04:03] VOIX: en fait
-[01:04:03] VOIX: alors tu m'as dit
-[01:04:04] VOIX: 339
-[01:04:06] VOIX: oui
-[01:04:09] VOIX: 339
-[01:04:12] VOIX: voilà
-[01:04:13] VOIX: c'est
-[01:04:13] VOIX: 339
-[01:04:14] VOIX: 23 07
-[01:04:16] VOIX: 2740
-[01:04:17] VOIX: c'est ça
-[01:04:18] VOIX: je vérifie
-[01:04:19] VOIX: oui
-[01:04:22] VOIX: 2740
-[01:04:22] VOIX: oui
-[01:04:23] VOIX: allez
-[01:04:23] VOIX: je le télécharge
-[01:04:25] VOIX: on est sur
-[01:04:25] VOIX: le bon dossier
-[01:04:27] VOIX: hop là
-[01:04:28] VOIX: alors
-[01:04:28] VOIX: ah oui
-[01:04:29] VOIX: mais ça c'est
-[01:04:29] VOIX: le tracker
-[01:04:31] VOIX: il me faut tout
-[01:04:34] VOIX: mais t'auras
-[01:04:35] VOIX: que le tracker
-[01:04:36] VOIX: mais on va voir
-[01:04:36] VOIX: que dans le tracker
-[01:04:37] VOIX: il peut y avoir
-[01:04:38] VOIX: beaucoup de choses
-[01:04:39] VOIX: d'accord
-[01:04:39] VOIX: peut-être que
-[01:04:40] VOIX: t'as qu'un tracker
-[01:04:41] VOIX: mais le tracker
-[01:04:42] VOIX: c'est
-[01:04:43] VOIX: il concatène
-[01:04:45] VOIX: un peu
-[01:04:45] VOIX: les documents
-[01:04:46] VOIX: qui sort
-[01:04:46] VOIX: ensemble
-[01:04:48] VOIX: hop là
-[01:04:49] VOIX: donc j'ai le document
-[01:04:51] VOIX: je te partage
-[01:04:52] VOIX: mon écran
-[01:04:52] VOIX: oui
-[01:04:53] VOIX: en premier
-[01:04:54] VOIX: oui
-[01:04:56] VOIX: voilà
-[01:04:57] VOIX: tu te vois
-[01:05:00] VOIX: oui
-[01:05:02] VOIX: allez
-[01:05:03] VOIX: voilà
-[01:05:04] VOIX: donc on va tout en bout
-[01:05:08] VOIX: exactement
-[01:05:09] VOIX: donc on refait
-[01:05:10] VOIX: la même
-[01:05:12] VOIX: mais j'ai l'impression
-[01:05:13] VOIX: que c'est le même
-[01:05:14] VOIX: non c'est le même
-[01:05:14] VOIX: celui-là c'est le même
-[01:05:15] VOIX: alors attends
-[01:05:16] VOIX: ouais non c'est pas le bon
-[01:05:18] VOIX: je vais le fermer
-[01:05:20] VOIX: je vais fermer ça
-[01:05:22] VOIX: et je vais recommencer
-[01:05:24] VOIX: à ouvrir en fait
-[01:05:25] VOIX: le document
-[01:05:26] VOIX: que je viens de télécharger
-[01:05:29] VOIX: voilà c'est ça
-[01:05:31] VOIX: il était ouvert
-[01:05:32] VOIX: mais dans un autre logiciel
-[01:05:33] VOIX: oui
-[01:05:34] VOIX: augmente
-[01:05:34] VOIX: augmente un peu
-[01:05:35] VOIX: s'il te plaît
-[01:05:39] VOIX: ça te va comme ça
-[01:05:40] VOIX: voilà
-[01:05:40] VOIX: en fait
-[01:05:40] VOIX: tracker
-[01:05:42] VOIX: d'abord
-[01:05:43] VOIX: c'est
-[01:05:45] VOIX: la solution
-[01:05:46] VOIX: de tracker
-[01:05:47] VOIX: de sortir les documents
-[01:05:48] VOIX: il faut vraiment pas
-[01:05:49] VOIX: que tu penses
-[01:05:50] VOIX: que tous les documents
-[01:05:51] VOIX: on va les avoir comme ça
-[01:05:52] VOIX: ça sera complètement différent
-[01:05:54] VOIX: ça il faut l'avoir en tête
-[01:05:56] VOIX: donc déjà
-[01:05:58] VOIX: tu regardes toujours
-[01:06:01] VOIX: les noms
-[01:06:02] VOIX: tu vois le dossier
-[01:06:05] VOIX: les dates
-[01:06:06] VOIX: là je vois que la date d'admission
-[01:06:08] VOIX: c'est le 13 avril
-[01:06:10] VOIX: date de sortie 21 avril
-[01:06:11] VOIX: donc tout ça
-[01:06:13] VOIX: parce que je vais le croiser
-[01:06:14] VOIX: aussi bien
-[01:06:15] VOIX: avec le dossier administratif
-[01:06:18] VOIX: avec le dossier du codage
-[01:06:19] VOIX: etc
-[01:06:21] VOIX: si tu descends
-[01:06:22] VOIX: j'ai pas la main
-[01:06:24] VOIX: donc c'est
-[01:06:26] VOIX: tu vois
-[01:06:27] VOIX: il vient
-[01:06:28] VOIX: le 13
-[01:06:31] VOIX: il passe aux urgences
-[01:06:32] VOIX: très bien
-[01:06:32] VOIX: bien vu
-[01:06:33] VOIX: tu vois
-[01:06:33] VOIX: là tu vois que
-[01:06:35] VOIX: t'as la trace
-[01:06:36] VOIX: du passage aux urgences
-[01:06:37] VOIX: par véhicule personnel
-[01:06:40] VOIX: les horaires
-[01:06:42] VOIX: circulant
-[01:06:43] VOIX: douleur à dos
-[01:06:44] VOIX: tu vois
-[01:06:44] VOIX: ça c'est quoi
-[01:06:45] VOIX: priorité 4
-[01:06:46] VOIX: c'est une classification
-[01:06:47] VOIX: des urgences
-[01:06:48] VOIX: nous on l'utilise pas
-[01:06:50] VOIX: mais après
-[01:06:51] VOIX: tu sais
-[01:06:53] VOIX: les différentes échelles
-[01:06:56] VOIX: etc
-[01:06:57] VOIX: on pourra
-[01:06:58] VOIX: peut-être
-[01:06:59] VOIX: petit à petit
-[01:07:00] VOIX: les intégrer
-[01:07:00] VOIX: mais c'est pas
-[01:07:01] VOIX: quelque chose
-[01:07:01] VOIX: que nous
-[01:07:03] VOIX: priorité 4
-[01:07:04] VOIX: peut-être eux
-[01:07:04] VOIX: tu sais
-[01:07:05] VOIX: ils classent
-[01:07:06] VOIX: est-ce que
-[01:07:07] VOIX: d'accord
-[01:07:07] VOIX: ça c'est dans
-[01:07:08] VOIX: est-ce que
-[01:07:10] VOIX: je le fais
-[01:07:10] VOIX: rentrer en premier
-[01:07:11] VOIX: ou est-ce qu'il peut
-[01:07:12] VOIX: attendre un peu
-[01:07:13] VOIX: ok
-[01:07:13] VOIX: je connais pas
-[01:07:15] VOIX: est-ce que c'est une
-[01:07:17] VOIX: classification nationale
-[01:07:18] VOIX: ou juste interne
-[01:07:20] VOIX: ça je peux pas te dire
-[01:07:20] VOIX: d'accord
-[01:07:21] VOIX: ça c'est un truc
-[01:07:21] VOIX: oui c'est un truc
-[01:07:23] VOIX: interne
-[01:07:23] VOIX: bon ok
-[01:07:24] VOIX: ok
-[01:07:24] VOIX: et après en fait
-[01:07:25] VOIX: on va partir du motif
-[01:07:27] VOIX: de prise en charge
-[01:07:28] VOIX: et là on sait que
-[01:07:29] VOIX: c'est le motif
-[01:07:30] VOIX: au niveau des urgences
-[01:07:31] VOIX: d'accord
-[01:07:32] VOIX: on voit douleur abdo
-[01:07:34] VOIX: on voit les observations
-[01:07:36] VOIX: de l'IDE
-[01:07:36] VOIX: des urgences
-[01:07:37] VOIX: il dit
-[01:07:38] VOIX: abrécé par médecin
-[01:07:39] VOIX: pour douleur abdo
-[01:07:40] VOIX: avec diarrhée persistante
-[01:07:41] VOIX: personne suivie
-[01:07:43] VOIX: par docteur machin
-[01:07:44] VOIX: pour anémie
-[01:07:44] VOIX: fériprive
-[01:07:45] VOIX: dans un contexte
-[01:07:47] VOIX: de maladie
-[01:07:48] VOIX: de bière mère
-[01:07:49] VOIX: donc là
-[01:07:50] VOIX: c'est un
-[01:07:53] VOIX: on va dire
-[01:07:54] VOIX: c'est un antécédent
-[01:07:55] VOIX: on le note
-[01:07:57] VOIX: dans un coin
-[01:07:58] VOIX: du cerveau
-[01:07:59] VOIX: pour vérifier
-[01:08:00] VOIX: si pendant
-[01:08:01] VOIX: le séjour
-[01:08:03] VOIX: il y a une prise en charge
-[01:08:05] VOIX: diagnostique
-[01:08:06] VOIX: thérapeutique
-[01:08:06] VOIX: ou surveillance
-[01:08:08] VOIX: de sa maladie
-[01:08:09] VOIX: de bière mère
-[01:08:09] VOIX: et de son anémie
-[01:08:10] VOIX: fériprive
-[01:08:11] VOIX: si elle prend
-[01:08:12] VOIX: un traitement
-[01:08:13] VOIX: si ça a été surveillé
-[01:08:14] VOIX: si ça a été évoqué
-[01:08:16] VOIX: on va le coder
-[01:08:17] VOIX: en diagnostic associé
-[01:08:18] VOIX: mais si personne
-[01:08:20] VOIX: n'en parle
-[01:08:20] VOIX: ou ne s'en préoccupe
-[01:08:22] VOIX: ça va être difficile
-[01:08:23] VOIX: de le coder
-[01:08:24] VOIX: parce qu'on va considérer
-[01:08:25] VOIX: que c'est juste
-[01:08:26] VOIX: un antécédent
-[01:08:28] VOIX: qui n'a pas été
-[01:08:29] VOIX: prise en charge
-[01:08:29] VOIX: pendant
-[01:08:30] VOIX: l'hospite
-[01:08:31] VOIX: il n'y a pas
-[01:08:32] VOIX: d'histoire de maladie
-[01:08:33] VOIX: là
-[01:08:33] VOIX: remonte
-[01:08:33] VOIX: remonte
-[01:08:34] VOIX: t'es allé un peu vite
-[01:08:34] VOIX: non non mais je sais
-[01:08:35] VOIX: pourquoi j'ai fait ça
-[01:08:36] VOIX: mais vas-y
-[01:08:37] VOIX: oui
-[01:08:39] VOIX: donc tu vois
-[01:08:40] VOIX: là
-[01:08:40] VOIX: les urgences
-[01:08:41] VOIX: machin
-[01:08:42] VOIX: il l'hospitalise
-[01:08:43] VOIX: en médecine
-[01:08:45] VOIX: donc on voit
-[01:08:46] VOIX: qu'elle est remontée
-[01:08:47] VOIX: l'avant-dernière ligne
-[01:08:48] VOIX: les deux dernières lignes
-[01:08:49] VOIX: tu vois
-[01:08:50] VOIX: qu'elle est remontée
-[01:08:50] VOIX: en gastro-entérologie
-[01:08:52] VOIX: là en tout cas
-[01:08:53] VOIX: ce document
-[01:08:54] VOIX: des urgences
-[01:08:55] VOIX: il trace
-[01:08:55] VOIX: un petit peu
-[01:08:56] VOIX: ce qui s'est passé
-[01:08:57] VOIX: donc c'est plutôt
-[01:08:58] VOIX: pas mal
-[01:08:59] VOIX: c'est pas le cas partout
-[01:09:00] VOIX: dans tous les établissements
-[01:09:02] VOIX: et avec tous les
-[01:09:03] VOIX: des pays
-[01:09:03] VOIX: encore une fois
-[01:09:05] VOIX: c'est pour ça que
-[01:09:06] VOIX: c'est des adaptations
-[01:09:08] VOIX: t'as des informations
-[01:09:10] VOIX: aussi en dehors du tableau
-[01:09:12] VOIX: dans taille
-[01:09:13] VOIX: poids
-[01:09:13] VOIX: tu vois
-[01:09:18] VOIX: il faut les garder
-[01:09:21] VOIX: dans un coin
-[01:09:22] VOIX: ça fait partie des infos
-[01:09:24] VOIX: et
-[01:09:25] VOIX: on descend
-[01:09:26] VOIX: on voit
-[01:09:27] VOIX: ce qui s'est passé
-[01:09:28] VOIX: après
-[01:09:30] VOIX: qu'est-ce qu'il y a
-[01:09:30] VOIX: ah oui
-[01:09:31] VOIX: là
-[01:09:31] VOIX: tu as
-[01:09:32] VOIX: ce qu'on appelle
-[01:09:33] VOIX: les constantes
-[01:09:35] VOIX: la température
-[01:09:36] VOIX: le pouls
-[01:09:37] VOIX: la fréquence cardiaque
-[01:09:38] VOIX: la pression arternelle
-[01:09:39] VOIX: etc
-[01:09:40] VOIX: jour par jour
-[01:09:41] VOIX: et heure par heure
-[01:09:42] VOIX: en tout cas
-[01:09:43] VOIX: quand ils l'ont prise
-[01:09:44] VOIX: et ça parfois
-[01:09:45] VOIX: on l'exploite
-[01:09:48] VOIX: pour un certain codage
-[01:09:51] VOIX: là typiquement
-[01:09:53] VOIX: moi d'emblée
-[01:09:54] VOIX: à la
-[01:09:54] VOIX: 1, 2, 3, 4, 5, 6
-[01:09:57] VOIX: la 6ème colonne
-[01:09:58] VOIX: c'est-à-dire
-[01:09:59] VOIX: le 20 avril
-[01:10:00] VOIX: à 15h51
-[01:10:01] VOIX: je repère
-[01:10:03] VOIX: que la patiente
-[01:10:05] VOIX: elle avait
-[01:10:05] VOIX: 38 de fièvre
-[01:10:07] VOIX: et 95 de
-[01:10:09] VOIX: battements
-[01:10:10] VOIX: par l'équipe
-[01:10:11] VOIX: ce qui fait beaucoup
-[01:10:12] VOIX: et le 20 avril
-[01:10:14] VOIX: à 8h
-[01:10:15] VOIX: c'est-à-dire
-[01:10:15] VOIX: 3 colonnes après
-[01:10:16] VOIX: c'est pareil
-[01:10:18] VOIX: il y a 38
-[01:10:19] VOIX: et 93
-[01:10:20] VOIX: et on a
-[01:10:21] VOIX: une règle
-[01:10:22] VOIX: du codage
-[01:10:22] VOIX: qui dit
-[01:10:24] VOIX: qu'on peut
-[01:10:24] VOIX: coder
-[01:10:25] VOIX: un niveau 2
-[01:10:26] VOIX: un R65
-[01:10:27] VOIX: en chante fou
-[01:10:28] VOIX: quand on a
-[01:10:29] VOIX: la fièvre
-[01:10:30] VOIX: supérieure à 38
-[01:10:31] VOIX: ou la fréquence cardiaque
-[01:10:32] VOIX: supérieure à 90
-[01:10:33] VOIX: donc quand on est
-[01:10:35] VOIX: à court de diagnostic
-[01:10:36] VOIX: ça on l'analyse aussi
-[01:10:38] VOIX: tu vois
-[01:10:40] VOIX: parfois dans
-[01:10:41] VOIX: certains services
-[01:10:42] VOIX: pour certains suppléments
-[01:10:43] VOIX: etc
-[01:10:44] VOIX: on va regarder
-[01:10:45] VOIX: la pression artérielle
-[01:10:46] VOIX: on va regarder
-[01:10:47] VOIX: la diurèse
-[01:10:48] VOIX: on va regarder
-[01:10:49] VOIX: donc
-[01:10:49] VOIX: on l'analyse pas
-[01:10:51] VOIX: systématiquement
-[01:10:52] VOIX: mais en fonction
-[01:10:53] VOIX: des besoins
-[01:10:53] VOIX: on peut avoir
-[01:10:55] VOIX: on peut avoir besoin
-[01:10:57] VOIX: d'analyser
-[01:10:58] VOIX: ces éléments là
-[01:10:59] VOIX: ok
-[01:11:01] VOIX: mais pas
-[01:11:02] VOIX: mais pas a priori
-[01:11:03] VOIX: tu vois
-[01:11:04] VOIX: ça j'y vais
-[01:11:05] VOIX: après
-[01:11:05] VOIX: là
-[01:11:06] VOIX: c'est parce qu'on passe
-[01:11:07] VOIX: comme ça
-[01:11:07] VOIX: et pour que j'oublie pas
-[01:11:09] VOIX: pour te dire
-[01:11:10] VOIX: qu'on l'utilise
-[01:11:11] VOIX: mais je vais pas
-[01:11:12] VOIX: commencer par analyser ça
-[01:11:13] VOIX: mais là je vois
-[01:11:14] VOIX: que j'ai un R65
-[01:11:15] VOIX: en niveau 2
-[01:11:16] VOIX: que je peux aller chercher
-[01:11:18] VOIX: si vraiment
-[01:11:18] VOIX: je suis à court
-[01:11:19] VOIX: de niveau de sévérité
-[01:11:21] VOIX: ok
-[01:11:23] VOIX: si tu descends
-[01:11:26] VOIX: parce qu'après
-[01:11:27] VOIX: on les concentre
-[01:11:28] VOIX: là
-[01:11:28] VOIX: tu vois le type
-[01:11:30] VOIX: d'observation
-[01:11:31] VOIX: on va parcourir
-[01:11:32] VOIX: là tu es dans
-[01:11:34] VOIX: les observations médicales
-[01:11:35] VOIX: descends
-[01:11:36] VOIX: tu as l'histoire
-[01:11:37] VOIX: de la maladie
-[01:11:38] VOIX: on va regarder
-[01:11:38] VOIX: que les gros titres
-[01:11:39] VOIX: tu as la note d'évolution
-[01:11:42] VOIX: continue
-[01:11:44] VOIX: continue en bas
-[01:11:45] VOIX: note d'évolution
-[01:11:46] VOIX: donc ça
-[01:11:46] VOIX: on est dans
-[01:11:47] VOIX: le médical
-[01:11:48] VOIX: les observations médicales
-[01:11:49] VOIX: tout ça
-[01:11:51] VOIX: c'est les observations médicales
-[01:11:52] VOIX: on va voir
-[01:11:53] VOIX: est-ce qu'on a
-[01:11:54] VOIX: d'autres types
-[01:11:55] VOIX: de documents
-[01:11:56] VOIX: on a
-[01:11:57] VOIX: attends stop
-[01:11:58] VOIX: on a la sur
-[01:11:59] VOIX: remonte un tout petit peu
-[01:12:00] VOIX: on a la surveillance
-[01:12:03] VOIX: psychiatrie
-[01:12:03] VOIX: mais ça c'est bizarre
-[01:12:05] VOIX: oui
-[01:12:05] VOIX: c'est des constantes
-[01:12:06] VOIX: je te répète
-[01:12:07] VOIX: tu vois ces pièges
-[01:12:09] VOIX: on a les notes paramédicales
-[01:12:11] VOIX: ça c'est les infirmières
-[01:12:12] VOIX: notes idéales
-[01:12:14] VOIX: attend parce qu'en fait
-[01:12:14] VOIX: la personne elle est rentrée
-[01:12:15] VOIX: attends attends
-[01:12:16] VOIX: tu permets une seconde
-[01:12:18] VOIX: bien sûr
-[01:12:19] VOIX: pourquoi en fait
-[01:12:20] VOIX: elle a mal
-[01:12:28] VOIX: pourquoi ta question
-[01:12:29] VOIX: non non non
-[01:12:30] VOIX: c'est parce que
-[01:12:31] VOIX: j'ai essayé de comprendre
-[01:12:32] VOIX: pourquoi en fait
-[01:12:32] VOIX: on part en fait
-[01:12:33] VOIX: la prise en charge
-[01:12:34] VOIX: en fait
-[01:12:34] VOIX: avec une douleur abdominale
-[01:12:36] VOIX: et elle finit en psychiatrie
-[01:12:38] VOIX: enfin pourquoi en fait
-[01:12:38] VOIX: il y a une note
-[01:12:39] VOIX: je t'ai dit
-[01:12:40] VOIX: c'est un piège
-[01:12:41] VOIX: c'est une erreur
-[01:12:42] VOIX: parce que tu vois
-[01:12:43] VOIX: psychiatrie
-[01:12:43] VOIX: et t'as que des constantes
-[01:12:45] VOIX: ouais d'accord
-[01:12:45] VOIX: donc c'est pas normal
-[01:12:47] VOIX: que la surveillance psychiatrique
-[01:12:49] VOIX: se fait avec des constantes
-[01:12:51] VOIX: ça n'a pas de sens
-[01:12:52] VOIX: oui ça n'a pas de sens
-[01:12:53] VOIX: d'accord ok
-[01:12:54] VOIX: mais tu vois
-[01:12:54] VOIX: il faut faire gaffe
-[01:12:55] VOIX: aux termes
-[01:12:56] VOIX: parce que là
-[01:12:57] VOIX: évidemment
-[01:12:58] VOIX: je ne sais pas pourquoi
-[01:12:59] VOIX: t'as des notes
-[01:13:01] VOIX: paramédicales
-[01:13:02] VOIX: tu vois
-[01:13:02] VOIX: c'est des notes
-[01:13:03] VOIX: de l'infirmière
-[01:13:03] VOIX: tu vois
-[01:13:04] VOIX: que t'as la note
-[01:13:05] VOIX: du kinésithérapeute
-[01:13:07] VOIX: tu vois tout ça
-[01:13:08] VOIX: tu sais qui l'a vu
-[01:13:09] VOIX: il y a le médecin
-[01:13:11] VOIX: qui l'a vu
-[01:13:11] VOIX: il y a l'IDE
-[01:13:12] VOIX: il y a le kiné
-[01:13:14] VOIX: descend
-[01:13:15] VOIX: tu vois
-[01:13:16] VOIX: descend parce que
-[01:13:16] VOIX: c'est important
-[01:13:17] VOIX: qu'on voit
-[01:13:20] VOIX: combien un dossier
-[01:13:21] VOIX: il peut être lent
-[01:13:22] VOIX: et tu peux avoir
-[01:13:23] VOIX: beaucoup de notes
-[01:13:23] VOIX: et tout ça
-[01:13:24] VOIX: on le lit
-[01:13:24] VOIX: pour info
-[01:13:26] VOIX: t'as la note
-[01:13:27] VOIX: de la aide soignante
-[01:13:28] VOIX: note IDE
-[01:13:28] VOIX: stop
-[01:13:29] VOIX: t'as
-[01:13:30] VOIX: après t'as
-[01:13:31] VOIX: les administrations
-[01:13:33] VOIX: médicamenteuses
-[01:13:34] VOIX: continue
-[01:13:35] VOIX: continue
-[01:13:38] VOIX: parfois
-[01:13:39] VOIX: on va aller chercher
-[01:13:41] VOIX: pour vérifier
-[01:13:42] VOIX: est-ce qu'il y a
-[01:13:42] VOIX: de la sueline
-[01:13:43] VOIX: pour un diabète
-[01:13:44] VOIX: est-ce qu'il y a
-[01:13:44] VOIX: des antibiotiques
-[01:13:46] VOIX: est-ce que
-[01:13:46] VOIX: tu vois
-[01:13:47] VOIX: ça dépend après
-[01:13:48] VOIX: du niveau de
-[01:13:50] VOIX: de vigilance
-[01:13:51] VOIX: ou de lecture
-[01:13:52] VOIX: de la personne
-[01:13:52] VOIX: qui code
-[01:13:53] VOIX: mais c'est des choses
-[01:13:54] VOIX: qu'on peut aller chercher
-[01:13:55] VOIX: après t'as la biologie
-[01:13:57] VOIX: t'as la radiologie
-[01:13:59] VOIX: tu vois
-[01:13:59] VOIX: donc au moins
-[01:14:00] VOIX: les prescriptions
-[01:14:01] VOIX: de radio
-[01:14:02] VOIX: tu vas vérifier
-[01:14:03] VOIX: si elle a le scan
-[01:14:04] VOIX: l'écho
-[01:14:05] VOIX: la radio du thorax
-[01:14:06] VOIX: est-ce que t'as
-[01:14:07] VOIX: les comptes rendus
-[01:14:07] VOIX: en fait
-[01:14:08] VOIX: quand je te dis
-[01:14:09] VOIX: parfois
-[01:14:10] VOIX: dans un dossier
-[01:14:11] VOIX: et ça c'est un élément
-[01:14:12] VOIX: vraiment important
-[01:14:13] VOIX: quand on va faire
-[01:14:14] VOIX: l'IA vision
-[01:14:14] VOIX: sur le DPI
-[01:14:16] VOIX: c'est identifier
-[01:14:17] VOIX: ce qui a été fait
-[01:14:18] VOIX: comme document
-[01:14:19] VOIX: et s'assurer
-[01:14:20] VOIX: qu'on a tous
-[01:14:21] VOIX: les comptes rendus
-[01:14:22] VOIX: ça c'est quelque chose
-[01:14:22] VOIX: qui est capital
-[01:14:23] VOIX: là typiquement
-[01:14:24] VOIX: on voit
-[01:14:25] VOIX: qu'il y a eu
-[01:14:26] VOIX: une prescription
-[01:14:26] VOIX: de radio
-[01:14:27] VOIX: de scan
-[01:14:27] VOIX: etc
-[01:14:28] VOIX: on doit vérifier
-[01:14:30] VOIX: est-ce qu'elle a eu
-[01:14:32] VOIX: est-ce qu'elle a eu
-[01:14:33] VOIX: vraiment son scan
-[01:14:34] VOIX: c'est marqué
-[01:14:34] VOIX: réalisé
-[01:14:35] VOIX: réalisé
-[01:14:35] VOIX: il faut normalement
-[01:14:37] VOIX: que je puisse récupérer
-[01:14:38] VOIX: les comptes rendus
-[01:14:40] VOIX: sinon je dois indiquer
-[01:14:42] VOIX: que la patiente
-[01:14:42] VOIX: a un scan
-[01:14:43] VOIX: une écho
-[01:14:43] VOIX: mais finalement
-[01:14:44] VOIX: je n'ai pas
-[01:14:44] VOIX: les comptes rendus
-[01:14:45] VOIX: tu vois
-[01:14:46] VOIX: et Amina
-[01:14:47] VOIX: en fait
-[01:14:47] VOIX: dans ce cas là
-[01:14:48] VOIX: alors
-[01:14:50] VOIX: normalement
-[01:14:50] VOIX: s'il a les comptes rendus
-[01:14:52] VOIX: ils sont en fait
-[01:14:53] VOIX: dans le même dossier
-[01:14:54] VOIX: ou alors
-[01:14:55] VOIX: normalement
-[01:14:55] VOIX: un dossier bien fait
-[01:14:56] VOIX: il y a tout
-[01:14:57] VOIX: tu vas tout retrouver
-[01:14:59] VOIX: tu vas retrouver
-[01:15:00] VOIX: tous les comptes rendus
-[01:15:01] VOIX: d'accord
-[01:15:02] VOIX: ou bien
-[01:15:03] VOIX: ils te disent
-[01:15:03] VOIX: les comptes rendus
-[01:15:04] VOIX: d'imagerie
-[01:15:05] VOIX: sont dans tel logiciel
-[01:15:06] VOIX: donc on peut aller
-[01:15:07] VOIX: les chercher
-[01:15:07] VOIX: dans un autre logiciel
-[01:15:09] VOIX: ok
-[01:15:09] VOIX: tu vois
-[01:15:10] VOIX: mais
-[01:15:11] VOIX: le plus important
-[01:15:13] VOIX: c'est de
-[01:15:13] VOIX: c'est de dire
-[01:15:15] VOIX: là on a vu
-[01:15:16] VOIX: les observations médicales
-[01:15:18] VOIX: on a vu
-[01:15:18] VOIX: les transmissions
-[01:15:19] VOIX: des infirmières
-[01:15:20] VOIX: on a vu
-[01:15:21] VOIX: une note du kiné
-[01:15:22] VOIX: mais je ne sais pas
-[01:15:23] VOIX: est-ce que j'ai
-[01:15:23] VOIX: un compte rendu opératoire
-[01:15:24] VOIX: est-ce que j'ai
-[01:15:25] VOIX: un compte rendu
-[01:15:26] VOIX: est-ce que j'ai
-[01:15:28] VOIX: les comptes rendus
-[01:15:28] VOIX: de l'imagerie
-[01:15:30] VOIX: là la biologie
-[01:15:31] VOIX: il faut que je m'assure
-[01:15:32] VOIX: que j'ai les résultats
-[01:15:33] VOIX: de mes prescriptions
-[01:15:35] VOIX: biologiques
-[01:15:36] VOIX: bactériologiques
-[01:15:36] VOIX: parce que je vois
-[01:15:37] VOIX: qu'il y a eu
-[01:15:37] VOIX: de l'anapate
-[01:15:38] VOIX: il y a eu
-[01:15:39] VOIX: des hémocultures
-[01:15:41] VOIX: typiquement
-[01:15:41] VOIX: dans le premier dossier
-[01:15:42] VOIX: que tu m'as montré
-[01:15:43] VOIX: il fallait qu'on aille
-[01:15:44] VOIX: chercher le résultat
-[01:15:45] VOIX: de l'anapate
-[01:15:45] VOIX: je ne sais pas
-[01:15:46] VOIX: s'il y était
-[01:15:47] VOIX: dans le traquer
-[01:15:47] VOIX: ou pas
-[01:15:48] VOIX: mais le résultat
-[01:15:49] VOIX: de l'anapate
-[01:15:49] VOIX: il est clé
-[01:15:50] VOIX: dans le codage
-[01:15:52] VOIX: il est important
-[01:15:54] VOIX: il y a des dossiers
-[01:15:54] VOIX: qu'on ne code
-[01:15:55] VOIX: que sur l'anapate
-[01:15:58] VOIX: donc en fait
-[01:16:00] VOIX: c'est ça
-[01:16:00] VOIX: en fait
-[01:16:02] VOIX: quand je te dis
-[01:16:03] VOIX: l'exhaustivité
-[01:16:04] VOIX: des documents
-[01:16:05] VOIX: indépendamment
-[01:16:06] VOIX: de qu'est-ce
-[01:16:07] VOIX: qu'il y a dedans
-[01:16:07] VOIX: c'est déjà
-[01:16:08] VOIX: est-ce que j'ai
-[01:16:09] VOIX: mes documents
-[01:16:10] VOIX: et ça remonte
-[01:16:11] VOIX: un tout petit peu
-[01:16:11] VOIX: remonte un tout petit peu
-[01:16:12] VOIX: tu vois là
-[01:16:14] VOIX: régime diététique
-[01:16:15] VOIX: est-ce qu'elle a vu
-[01:16:17] VOIX: une diète
-[01:16:18] VOIX: tu vois
-[01:16:19] VOIX: qu'ils ont fait
-[01:16:19] VOIX: le CG
-[01:16:20] VOIX: ils ont
-[01:16:21] VOIX: tu vois
-[01:16:22] VOIX: à peu près
-[01:16:23] VOIX: ça te donne
-[01:16:23] VOIX: une idée
-[01:16:24] VOIX: des soins
-[01:16:25] VOIX: de ce qui a été fait
-[01:16:26] VOIX: pour cette patiente
-[01:16:27] VOIX: qui est resté
-[01:16:27] VOIX: quand même
-[01:16:27] VOIX: quelques jours
-[01:16:29] VOIX: et si
-[01:16:30] VOIX: tu descends
-[01:16:36] VOIX: ça c'est du
-[01:16:36] VOIX: blablabla
-[01:16:37] VOIX: c'est toutes les
-[01:16:38] VOIX: administrations
-[01:16:39] VOIX: médicamenteuses
-[01:16:40] VOIX: qui vont se répéter
-[01:16:41] VOIX: parce que
-[01:16:42] VOIX: ils ont fait
-[01:16:43] VOIX: comme ils pouvaient
-[01:16:44] VOIX: pour sortir
-[01:16:45] VOIX: des éléments
-[01:16:45] VOIX: pour le contrôle
-[01:16:46] VOIX: nous si on fait
-[01:16:48] VOIX: une extraction
-[01:16:49] VOIX: IA
-[01:16:50] VOIX: on va travailler
-[01:16:51] VOIX: tout ça
-[01:16:52] VOIX: de manière intelligente
-[01:16:53] VOIX: si tu veux
-[01:16:55] VOIX: ok
-[01:16:58] VOIX: et alors lui
-[01:16:59] VOIX: sur ce dossier là
-[01:17:00] VOIX: en fait
-[01:17:00] VOIX: toi tu peux me montrer
-[01:17:01] VOIX: en fait
-[01:17:01] VOIX: ce que vous aviez codé
-[01:17:03] VOIX: comment vous l'avez traité
-[01:17:05] VOIX: en amont
-[01:17:06] VOIX: avant
-[01:17:06] VOIX: regarde
-[01:17:07] VOIX: résultat de radiologie
-[01:17:09] VOIX: tu vois ce que je t'ai dit
-[01:17:10] VOIX: tout à l'heure
-[01:17:10] VOIX: il était caché
-[01:17:12] VOIX: le compte rendu
-[01:17:12] VOIX: du scanner
-[01:17:14] VOIX: il descend encore
-[01:17:15] VOIX: on va voir
-[01:17:16] VOIX: parce que tout ça
-[01:17:17] VOIX: on va le relire
-[01:17:17] VOIX: là je ne fais que
-[01:17:19] VOIX: la partie
-[01:17:20] VOIX: quel document
-[01:17:21] VOIX: qu'est-ce que je dois
-[01:17:22] VOIX: avoir comme résultat
-[01:17:23] VOIX: etc
-[01:17:24] VOIX: mais on va
-[01:17:24] VOIX: on va regarder
-[01:17:25] VOIX: le contenu
-[01:17:27] VOIX: descend encore
-[01:17:28] VOIX: j'ai le compte rendu
-[01:17:31] VOIX: de quoi
-[01:17:32] VOIX: de l'échographie
-[01:17:33] VOIX: tout à l'heure
-[01:17:33] VOIX: on a vu qu'elle a eu
-[01:17:34] VOIX: un scan
-[01:17:34] VOIX: une écho
-[01:17:35] VOIX: donc j'ai les comptes rendus
-[01:17:36] VOIX: j'ai le compte rendu
-[01:17:37] VOIX: de la radio
-[01:17:39] VOIX: descend encore
-[01:17:41] VOIX: j'ai le compte rendu
-[01:17:43] VOIX: de la radio
-[01:17:44] VOIX: écho
-[01:17:45] VOIX: parce qu'elle en a eu
-[01:17:46] VOIX: plusieurs je pense
-[01:17:46] VOIX: à des dates différentes
-[01:17:49] VOIX: parce qu'une fois
-[01:17:49] VOIX: elle dit
-[01:17:50] VOIX: voilà
-[01:17:51] VOIX: j'ai les résultats
-[01:17:52] VOIX: du labo
-[01:17:54] VOIX: quelques résultats
-[01:17:55] VOIX: est-ce que j'ai tout
-[01:17:56] VOIX: est-ce que j'ai une partie
-[01:17:57] VOIX: je n'en sais rien
-[01:17:59] VOIX: descend on va voir
-[01:18:02] VOIX: bon
-[01:18:03] VOIX: j'ai quelques résultats
-[01:18:04] VOIX: et tout ça
-[01:18:06] VOIX: on le lit
-[01:18:07] VOIX: en fait
-[01:18:08] VOIX: on le regarde
-[01:18:09] VOIX: mais avec un ordre
-[01:18:10] VOIX: de priorité
-[01:18:12] VOIX: et
-[01:18:12] VOIX: je vais t'expliquer
-[01:18:13] VOIX: l'ordre de priorité
-[01:18:15] VOIX: et qu'est-ce que tu recherches
-[01:18:17] VOIX: c'est important
-[01:18:17] VOIX: que tu me dises
-[01:18:18] VOIX: en fait ce que tu recherches
-[01:18:19] VOIX: tu le regardes
-[01:18:20] VOIX: mais qu'est-ce que tu
-[01:18:21] VOIX: c'est quoi en fait
-[01:18:22] VOIX: que tu regardes en premier
-[01:18:23] VOIX: là
-[01:18:23] VOIX: donc je vais remonter
-[01:18:24] VOIX: aux observations médicales
-[01:18:27] VOIX: tu sais
-[01:18:27] VOIX: on a vu les urgences
-[01:18:28] VOIX: on a compris
-[01:18:29] VOIX: que la patiente
-[01:18:30] VOIX: était venue
-[01:18:30] VOIX: pour douleur
-[01:18:31] VOIX: et pour diarrhée
-[01:18:34] VOIX: tu te rappelles
-[01:18:34] VOIX: oui oui
-[01:18:36] VOIX: ou pas
-[01:18:36] VOIX: donc je vais remonter
-[01:18:38] VOIX: au tout début
-[01:18:40] VOIX: quand elle a commencé
-[01:18:42] VOIX: quand on a commencé
-[01:18:43] VOIX: les observations médicales
-[01:18:45] VOIX: là
-[01:18:45] VOIX: je note
-[01:18:46] VOIX: dans ma tête
-[01:18:47] VOIX: elle vient aux urgences
-[01:18:49] VOIX: pour douleur
-[01:18:49] VOIX: diarrhée
-[01:18:50] VOIX: et elle est suivie
-[01:18:51] VOIX: pour une anémiférie
-[01:18:53] VOIX: prive
-[01:18:53] VOIX: et maladie du bien-mère
-[01:18:54] VOIX: c'est tout ce que je note
-[01:18:55] VOIX: soit dans un papier
-[01:18:56] VOIX: les types
-[01:18:57] VOIX: parfois
-[01:18:58] VOIX: un papier
-[01:18:58] VOIX: un crayon
-[01:18:59] VOIX: et elle se note
-[01:19:00] VOIX: deux trois trucs
-[01:19:01] VOIX: pour dire
-[01:19:02] VOIX: est-ce que je le code
-[01:19:02] VOIX: après ou pas
-[01:19:03] VOIX: pour ne pas oublier
-[01:19:04] VOIX: tu descends
-[01:19:05] VOIX: tu descends
-[01:19:06] VOIX: on va avoir
-[01:19:07] VOIX: les observations médicales
-[01:19:10] VOIX: après ce truc là
-[01:19:11] VOIX: il y avait
-[01:19:11] VOIX: les observations
-[01:19:12] VOIX: là
-[01:19:12] VOIX: donc là
-[01:19:14] VOIX: c'est hyper important
-[01:19:16] VOIX: de voir
-[01:19:17] VOIX: d'abord
-[01:19:18] VOIX: les observations médicales
-[01:19:20] VOIX: parfois
-[01:19:21] VOIX: tu as
-[01:19:21] VOIX: ceux qui commencent
-[01:19:22] VOIX: du haut en bas
-[01:19:23] VOIX: et parfois
-[01:19:24] VOIX: tu as ceux
-[01:19:25] VOIX: qui commencent
-[01:19:26] VOIX: de bas en haut
-[01:19:26] VOIX: d'accord
-[01:19:28] VOIX: et la date
-[01:19:28] VOIX: elle est
-[01:19:29] VOIX: hyper importante
-[01:19:30] VOIX: le 21
-[01:19:31] VOIX: c'est la fin
-[01:19:32] VOIX: du séjour
-[01:19:36] VOIX: ok
-[01:19:36] VOIX: ok
-[01:19:38] VOIX: il faut que je commence
-[01:20:06] VOIX: par
-[01:20:07] VOIX: par ça
-[01:20:08] VOIX: 9h21
-[01:20:09] VOIX: qu'est-ce qu'il dit
-[01:20:10] VOIX: au début
-[01:20:10] VOIX: il vient pour douleur
-[01:20:12] VOIX: et là
-[01:20:13] VOIX: des antécédents
-[01:20:14] VOIX: DT c'est diabète
-[01:20:16] VOIX: de type 2
-[01:20:19] VOIX: ok
-[01:20:22] VOIX: dyslipidémie
-[01:20:22] VOIX: HTA
-[01:20:23] VOIX: anémie
-[01:20:24] VOIX: phériprive chronique
-[01:20:25] VOIX: suivie
-[01:20:26] VOIX: maladie de bière
-[01:20:27] VOIX: mère
-[01:20:27] VOIX: il te dit
-[01:20:28] VOIX: qu'elle a déjà
-[01:20:29] VOIX: fait des biopsies
-[01:20:30] VOIX: elle a de la vitamine
-[01:20:31] VOIX: P12
-[01:20:31] VOIX: elle est transfusée
-[01:20:33] VOIX: elle a du fer
-[01:20:34] VOIX: injecte
-[01:20:34] VOIX: dernière transfusion
-[01:20:35] VOIX: etc
-[01:20:36] VOIX: mais on ne sait pas
-[01:20:37] VOIX: si elle a un traitement
-[01:20:38] VOIX: poursuivi ou pas
-[01:20:39] VOIX: parce que jusque là
-[01:20:40] VOIX: tout ce qu'il raconte
-[01:20:41] VOIX: sur l'anémie
-[01:20:42] VOIX: phériprive chronique
-[01:20:43] VOIX: c'est un antécédent
-[01:20:45] VOIX: je ne sais pas
-[01:20:46] VOIX: est-ce qu'ils ont
-[01:20:46] VOIX: continué le traitement
-[01:20:49] VOIX: pendant le séjour
-[01:20:50] VOIX: par exemple
-[01:20:51] VOIX: c'est là où parfois
-[01:20:53] VOIX: on va aller
-[01:20:53] VOIX: voir les médicaments
-[01:20:55] VOIX: et on va
-[01:20:56] VOIX: on va dire
-[01:20:57] VOIX: si la patiente
-[01:20:58] VOIX: a de la vitamine
-[01:20:58] VOIX: P12
-[01:20:59] VOIX: comme traitement
-[01:21:00] VOIX: peut-être que je vais
-[01:21:01] VOIX: coder la maladie
-[01:21:02] VOIX: de bière mère
-[01:21:03] VOIX: en diagnostic associé
-[01:21:05] VOIX: ok
-[01:21:06] VOIX: c'est pas en diagnostic
-[01:21:07] VOIX: principal
-[01:21:08] VOIX: parce qu'elle n'est pas
-[01:21:08] VOIX: venue pour ça
-[01:21:09] VOIX: le diagnostic principal
-[01:21:10] VOIX: c'est toujours
-[01:21:11] VOIX: pourquoi la patiente
-[01:21:13] VOIX: est venue
-[01:21:13] VOIX: en hospitalisation
-[01:21:14] VOIX: elle n'est pas venue
-[01:21:15] VOIX: pour son anémie
-[01:21:15] VOIX: ni pour sa maladie
-[01:21:16] VOIX: de bière mère
-[01:21:18] VOIX: il dit qu'elle a
-[01:21:19] VOIX: une cirrhose hépatique
-[01:21:21] VOIX: ok
-[01:21:21] VOIX: les traitements
-[01:21:23] VOIX: elle a
-[01:21:25] VOIX: vas-y
-[01:21:25] VOIX: vas-y
-[01:21:25] VOIX: vas-y
-[01:21:26] VOIX: elle a un certain
-[01:21:28] VOIX: type de traitement
-[01:21:33] VOIX: bisoprolol
-[01:21:33] VOIX: cardégique
-[01:21:34] VOIX: bétformine
-[01:21:35] VOIX: moi dans le diabète
-[01:21:36] VOIX: de type 2
-[01:21:37] VOIX: je regarde
-[01:21:37] VOIX: si elle a de l'insuline
-[01:21:38] VOIX: ou pas
-[01:21:39] VOIX: mais là pour l'instant
-[01:21:40] VOIX: je ne vois pas
-[01:21:42] VOIX: allergie
-[01:21:43] VOIX: il n'y en a pas
-[01:21:45] VOIX: MDV
-[01:21:45] VOIX: mode de vie
-[01:21:46] VOIX: MDV c'est mode de vie
-[01:21:48] VOIX: il vit seul
-[01:21:49] VOIX: autonome
-[01:21:50] VOIX: pas d'aide
-[01:21:50] VOIX: utilise un fauteuil roulant
-[01:21:52] VOIX: depuis peu
-[01:21:52] VOIX: pour les grandes distances
-[01:21:54] VOIX: car
-[01:21:55] VOIX: asthénie
-[01:21:56] VOIX: tabagisme
-[01:21:56] VOIX: pas de OH
-[01:21:57] VOIX: chronique
-[01:21:58] VOIX: OH c'est alcool
-[01:21:59] VOIX: pas d'alcoolisme
-[01:22:00] VOIX: chronique
-[01:22:01] VOIX: histoire de la maladie
-[01:22:02] VOIX: HDM
-[01:22:03] VOIX: patiente de 68 ans
-[01:22:05] VOIX: suivi par docteur
-[01:22:06] VOIX: pour anémiférie
-[01:22:07] VOIX: près de chronique
-[01:22:08] VOIX: sans cause digestif
-[01:22:08] VOIX: franche
-[01:22:09] VOIX: imputable
-[01:22:11] VOIX: elle doit voir
-[01:22:12] VOIX: l'hémato
-[01:22:13] VOIX: tu vois
-[01:22:14] VOIX: elle est en cours
-[01:22:14] VOIX: de bilan
-[01:22:15] VOIX: quand même
-[01:22:15] VOIX: de son anémie
-[01:22:18] VOIX: elle a
-[01:22:19] VOIX: été traité
-[01:22:20] VOIX: pour infection
-[01:22:21] VOIX: urinaire
-[01:22:24] VOIX: en fait
-[01:22:25] VOIX: il ne faut jamais
-[01:22:26] VOIX: perdre de vue
-[01:22:27] VOIX: la date
-[01:22:28] VOIX: de l'entrée
-[01:22:29] VOIX: pour savoir
-[01:22:29] VOIX: est-ce que c'est des choses
-[01:22:30] VOIX: qui sont encore en cours
-[01:22:32] VOIX: ou est-ce que
-[01:22:33] VOIX: c'est des choses
-[01:22:33] VOIX: qui sont passées
-[01:22:34] VOIX: parce que là
-[01:22:35] VOIX: le 7 avril
-[01:22:37] VOIX: l'antibiogramme
-[01:22:38] VOIX: disparaissant de la douleur
-[01:22:39] VOIX: abdo
-[01:22:42] VOIX: donc c'est passé
-[01:22:45] VOIX: elle a des douleurs
-[01:22:47] VOIX: diarrhées depuis 15 jours
-[01:22:48] VOIX: asthénie
-[01:22:49] VOIX: pas de vomissement
-[01:22:50] VOIX: etc
-[01:22:50] VOIX: elle a la biologie
-[01:22:52] VOIX: du 11
-[01:22:52] VOIX: donc avant de rentrer
-[01:22:53] VOIX: elle avait une IRA
-[01:22:55] VOIX: une insuffisance rénale
-[01:22:56] VOIX: aiguë
-[01:22:57] VOIX: parce que IRA
-[01:22:58] VOIX: tu peux la voir
-[01:23:00] VOIX: en insuffisance rénale
-[01:23:01] VOIX: aiguë
-[01:23:01] VOIX: ou en insuffisance respiratoire
-[01:23:02] VOIX: aiguë
-[01:23:05] VOIX: immédiatement
-[01:23:07] VOIX: regarde en haut
-[01:23:07] VOIX: le gras
-[01:23:08] VOIX: en haut de la page
-[01:23:09] VOIX: il y a un truc gras
-[01:23:11] VOIX: créate
-[01:23:11] VOIX: ça veut dire
-[01:23:12] VOIX: rénale
-[01:23:13] VOIX: c'est en lien
-[01:23:14] VOIX: avec le rein
-[01:23:15] VOIX: donc du coup
-[01:23:15] VOIX: je sais que c'est
-[01:23:16] VOIX: insuffisance rénale
-[01:23:17] VOIX: aiguë
-[01:23:18] VOIX: il n'y a pas
-[01:23:19] VOIX: d'hyperlococytose
-[01:23:20] VOIX: pas d'anémie
-[01:23:20] VOIX: etc
-[01:23:22] VOIX: en fait
-[01:23:22] VOIX: ce que je suis en train
-[01:23:23] VOIX: de chercher
-[01:23:24] VOIX: c'est quelle est leur logique
-[01:23:26] VOIX: qu'est-ce qu'ils vont faire
-[01:23:28] VOIX: parce qu'elle vient
-[01:23:29] VOIX: pour douleur diarrhée
-[01:23:30] VOIX: est-ce que
-[01:23:31] VOIX: ça sera un bilan
-[01:23:32] VOIX: diagnostique
-[01:23:33] VOIX: est-ce que ça sera
-[01:23:33] VOIX: un bilan thérapeutique
-[01:23:34] VOIX: qu'est-ce qu'ils vont trouver
-[01:23:35] VOIX: donc elle n'a pas
-[01:23:37] VOIX: de fièvre
-[01:23:38] VOIX: langue grottie
-[01:23:39] VOIX: consciente orientée
-[01:23:41] VOIX: telle amène
-[01:23:43] VOIX: jolie
-[01:23:44] VOIX: sans
-[01:23:45] VOIX: pas de haute voix
-[01:23:46] VOIX: tachycardie
-[01:23:47] VOIX: tension artérielle limite
-[01:23:48] VOIX: pas de marbrures
-[01:23:50] VOIX: pas de souffle
-[01:23:51] VOIX: pas d'hérisipel
-[01:23:53] VOIX: pas de toux
-[01:23:54] VOIX: tu vois
-[01:23:54] VOIX: il n'y a pas
-[01:23:55] VOIX: papa
-[01:23:55] VOIX: et là
-[01:23:56] VOIX: l'abdomen
-[01:23:56] VOIX: pléthorique
-[01:23:57] VOIX: gonflé
-[01:23:58] VOIX: mais sans
-[01:23:58] VOIX: tintanisme
-[01:23:59] VOIX: auréficiaire
-[01:24:00] VOIX: mignonne
-[01:24:00] VOIX: sensibilité
-[01:24:01] VOIX: au hippocondre gauche
-[01:24:02] VOIX: donc il a le ventre sensible
-[01:24:04] VOIX: mais voilà
-[01:24:05] VOIX: bruit
-[01:24:06] VOIX: hydrosairie
-[01:24:06] VOIX: discrète
-[01:24:07] VOIX: pas de douleur
-[01:24:08] VOIX: pas d'hypthère
-[01:24:09] VOIX: jusque là
-[01:24:09] VOIX: tout ça
-[01:24:10] VOIX: moi je ne cherche pas
-[01:24:11] VOIX: à coder
-[01:24:12] VOIX: tout ça
-[01:24:13] VOIX: parce que
-[01:24:13] VOIX: l'objectif
-[01:24:14] VOIX: c'est pas de
-[01:24:15] VOIX: trouver des codes
-[01:24:16] VOIX: pour
-[01:24:17] VOIX: les symptômes
-[01:24:18] VOIX: l'objectif
-[01:24:18] VOIX: c'est de trouver
-[01:24:19] VOIX: une pathologie
-[01:24:20] VOIX: en fait
-[01:24:21] VOIX: ok
-[01:24:22] VOIX: donc au total
-[01:24:23] VOIX: ils disent
-[01:24:25] VOIX: insuffisance rénale
-[01:24:26] VOIX: aiguë
-[01:24:26] VOIX: probablement
-[01:24:27] VOIX: fonctionnelle
-[01:24:28] VOIX: sur déshydratation
-[01:24:30] VOIX: majeure
-[01:24:30] VOIX: en contexte
-[01:24:31] VOIX: de diarrhée
-[01:24:32] VOIX: et des faux
-[01:24:32] VOIX: d'hydratation
-[01:24:35] VOIX: peut-être
-[01:24:36] VOIX: que mon
-[01:24:37] VOIX: DP
-[01:24:38] VOIX: je peux
-[01:24:39] VOIX: être tenté
-[01:24:40] VOIX: par mettre
-[01:24:40] VOIX: la déshydratation
-[01:24:42] VOIX: en diagnostic
-[01:24:42] VOIX: principal
-[01:24:44] VOIX: pourquoi ?
-[01:24:45] VOIX: parce que
-[01:24:47] VOIX: elle a
-[01:24:47] VOIX: une insuffisance
-[01:24:48] VOIX: rénale
-[01:24:48] VOIX: aiguë
-[01:24:49] VOIX: mais il me dit
-[01:24:50] VOIX: qu'elle est
-[01:24:51] VOIX: plutôt fonctionnelle
-[01:24:52] VOIX: et la cause
-[01:24:53] VOIX: c'est la déshydratation
-[01:24:55] VOIX: et c'est moi
-[01:24:56] VOIX: qui partage
-[01:24:57] VOIX: mon écran
-[01:24:58] VOIX: pour te montrer
-[01:24:58] VOIX: une règle
-[01:24:59] VOIX: d'accord
-[01:25:01] VOIX: tu peux ?
-[01:25:02] VOIX: vas-y
-[01:25:02] VOIX: vas-y
-[01:25:04] VOIX: tu peux
-[01:25:07] VOIX: allez
-[01:25:08] VOIX: dans le guide
-[01:25:09] VOIX: méthodo
-[01:25:10] VOIX: que je t'ai
-[01:25:10] VOIX: montré
-[01:25:11] VOIX: tout à l'heure
-[01:25:11] VOIX: si je reviens
-[01:25:12] VOIX: à mon sommaire
-[01:25:15] VOIX: on avait parlé
-[01:25:17] VOIX: des situations
-[01:25:17] VOIX: cliniques
-[01:25:18] VOIX: on a dit
-[01:25:19] VOIX: hospitalisation
-[01:25:20] VOIX: pour diagnostic
-[01:25:20] VOIX: pour traitement
-[01:25:21] VOIX: pour surveillance
-[01:25:22] VOIX: là
-[01:25:25] VOIX: je suis
-[01:25:26] VOIX: sur
-[01:25:27] VOIX: à l'heure
-[01:25:28] VOIX: où on parle
-[01:25:29] VOIX: sur une situation
-[01:25:29] VOIX: de diagnostic
-[01:25:30] VOIX: elle vient
-[01:25:31] VOIX: pour douleur
-[01:25:32] VOIX: diarrhée
-[01:25:32] VOIX: avec insuffisance
-[01:25:33] VOIX: rénale aiguë
-[01:25:35] VOIX: ils disent
-[01:25:35] VOIX: que la patiente
-[01:25:36] VOIX: elle avait
-[01:25:37] VOIX: une diarrhée
-[01:25:37] VOIX: elle a fait
-[01:25:38] VOIX: une déshydratation
-[01:25:39] VOIX: qui a provoqué
-[01:25:40] VOIX: l'insuffisance
-[01:25:41] VOIX: rénale aiguë
-[01:25:41] VOIX: fonctionnelle
-[01:25:43] VOIX: donc je peux
-[01:25:45] VOIX: être tentée
-[01:25:45] VOIX: en DP
-[01:25:46] VOIX: soit par la diarrhée
-[01:25:47] VOIX: soit par la déshydratation
-[01:25:48] VOIX: tu vois
-[01:25:49] VOIX: c'est pas
-[01:25:49] VOIX: aussi formel
-[01:25:50] VOIX: et donc
-[01:25:52] VOIX: j'ai une règle
-[01:25:53] VOIX: ici
-[01:25:54] VOIX: qui me dit
-[01:25:55] VOIX: que le diagnostic
-[01:25:56] VOIX: s'accompagne
-[01:25:57] VOIX: au nom
-[01:25:57] VOIX: d'un traitement
-[01:25:58] VOIX: oui pardon
-[01:26:00] VOIX: hospitalisation
-[01:26:01] VOIX: pour diagnostic
-[01:26:02] VOIX: la situation
-[01:26:03] VOIX: est celle
-[01:26:03] VOIX: d'un patient
-[01:26:04] VOIX: hospitalisé
-[01:26:05] VOIX: en raison
-[01:26:06] VOIX: d'une symptomatologie
-[01:26:07] VOIX: pour un diagnostic
-[01:26:08] VOIX: éthiologique
-[01:26:09] VOIX: donc le motif
-[01:26:10] VOIX: d'hospitalisation
-[01:26:11] VOIX: il a un symptôme
-[01:26:12] VOIX: une douleur
-[01:26:12] VOIX: une diarrhée
-[01:26:13] VOIX: etc
-[01:26:13] VOIX: et je vais essayer
-[01:26:15] VOIX: de comprendre
-[01:26:16] VOIX: pourquoi
-[01:26:20] VOIX: le diagnostic
-[01:26:21] VOIX: qu'il soit traité
-[01:26:23] VOIX: ou pas
-[01:26:23] VOIX: à partir du moment
-[01:26:24] VOIX: où je fais le diagnostic
-[01:26:27] VOIX: sauf en cas
-[01:26:28] VOIX: de chirurgie
-[01:26:28] VOIX: parce que là
-[01:26:29] VOIX: il y a le cas
-[01:26:29] VOIX: de traitement chirurgical
-[01:26:30] VOIX: qui change
-[01:26:32] VOIX: cette règle
-[01:26:35] VOIX: j'ai deux cas
-[01:26:37] VOIX: soit
-[01:26:38] VOIX: le séjour
-[01:26:39] VOIX: a permis
-[01:26:39] VOIX: le diagnostic
-[01:26:40] VOIX: de l'infection causale
-[01:26:41] VOIX: soit
-[01:26:41] VOIX: ils vont faire
-[01:26:42] VOIX: une biologie
-[01:26:43] VOIX: une bactériale
-[01:26:44] VOIX: de la déshydratation
-[01:26:45] VOIX: il va me dire
-[01:26:47] VOIX: diarrhée
-[01:26:48] VOIX: liée à telle bactérie
-[01:26:49] VOIX: ou tel virus
-[01:26:50] VOIX: peut-être que
-[01:26:51] VOIX: dans ce cas-là
-[01:26:51] VOIX: je poterai la diarrhée
-[01:26:53] VOIX: soit
-[01:26:53] VOIX: je me contenterai
-[01:26:55] VOIX: de la déshydratation
-[01:26:56] VOIX: mais déjà
-[01:26:56] VOIX: je commence à réfléchir
-[01:26:58] VOIX: tu vois
-[01:26:59] VOIX: et ça
-[01:27:00] VOIX: c'est les règles
-[01:27:01] VOIX: que j'utilise
-[01:27:01] VOIX: c'est-à-dire
-[01:27:02] VOIX: si le séjour
-[01:27:03] VOIX: a permis
-[01:27:03] VOIX: le diagnostic
-[01:27:04] VOIX: de l'infection causale
-[01:27:05] VOIX: elle va être le DP
-[01:27:06] VOIX: par exemple
-[01:27:08] VOIX: hospice
-[01:27:08] VOIX: pour une confusion
-[01:27:09] VOIX: mais on découvre
-[01:27:10] VOIX: une tumeur cérébrale
-[01:27:11] VOIX: le DP
-[01:27:12] VOIX: c'est la tumeur cérébrale
-[01:27:13] VOIX: hospice
-[01:27:14] VOIX: pour douleurs thoraciques
-[01:27:15] VOIX: le diagnostic
-[01:27:16] VOIX: de l'angine de poitrine
-[01:27:17] VOIX: le DP
-[01:27:17] VOIX: c'est l'angine de poitrine
-[01:27:19] VOIX: hospice
-[01:27:20] VOIX: pour anémie
-[01:27:21] VOIX: occlusion
-[01:27:22] VOIX: découverte
-[01:27:23] VOIX: d'un cancer colique
-[01:27:24] VOIX: c'est le cancer colique
-[01:27:25] VOIX: là on a l'exemple
-[01:27:26] VOIX: de hospice
-[01:27:27] VOIX: pour douleurs
-[01:27:29] VOIX: diarrhées
-[01:27:29] VOIX: insuffisance rénale
-[01:27:31] VOIX: donc soit
-[01:27:32] VOIX: c'est la diarrhée
-[01:27:32] VOIX: qui cause tout ça
-[01:27:33] VOIX: soit c'est
-[01:27:34] VOIX: il y a une infection
-[01:27:36] VOIX: etc
-[01:27:37] VOIX: soit c'est la déshydratation
-[01:27:39] VOIX: qui cause
-[01:27:39] VOIX: l'insuffisance rénale
-[01:27:40] VOIX: je peux avoir
-[01:27:41] VOIX: plusieurs niveaux
-[01:27:42] VOIX: de cause à effet
-[01:27:43] VOIX: en tout cas
-[01:27:44] VOIX: ça c'est une règle
-[01:27:44] VOIX: par contre
-[01:27:46] VOIX: si je ne trouve pas
-[01:27:47] VOIX: de cause
-[01:27:48] VOIX: dans ce cas là
-[01:27:50] VOIX: hospice pour céphalée
-[01:27:51] VOIX: je garde des céphalées
-[01:27:53] VOIX: hospice pour douleurs
-[01:27:54] VOIX: abdos
-[01:27:54] VOIX: mais je ne trouve pas
-[01:27:55] VOIX: pourquoi
-[01:27:56] VOIX: ça sera douleurs abdos
-[01:27:58] VOIX: j'ai l'état de choc
-[01:27:59] VOIX: mais je ne trouve pas
-[01:27:59] VOIX: pourquoi
-[01:28:00] VOIX: le DP reste
-[01:28:01] VOIX: l'état de choc
-[01:28:03] VOIX: syndrome inflammatoire
-[01:28:04] VOIX: je ne sais pas pourquoi
-[01:28:05] VOIX: ça reste
-[01:28:05] VOIX: le syndrome inflammatoire
-[01:28:07] VOIX: donc c'est un peu ça
-[01:28:08] VOIX: les règles
-[01:28:11] VOIX: tu vois
-[01:28:11] VOIX: tu n'as pas la cause
-[01:28:12] VOIX: et bien là
-[01:28:13] VOIX: la douleur abdos
-[01:28:14] VOIX: il a une polyarthrite
-[01:28:15] VOIX: de l'homatoïde
-[01:28:16] VOIX: en raison de douleurs
-[01:28:17] VOIX: abdos
-[01:28:18] VOIX: disparition des douleurs
-[01:28:19] VOIX: en 48 heures
-[01:28:20] VOIX: mais pas de cause
-[01:28:20] VOIX: trouvée
-[01:28:22] VOIX: c'est les douleurs
-[01:28:23] VOIX: abdos
-[01:28:23] VOIX: parce qu'on n'a pas
-[01:28:24] VOIX: compris pourquoi
-[01:28:26] VOIX: ok
-[01:28:29] VOIX: ça c'est les règles
-[01:28:30] VOIX: du guide
-[01:28:30] VOIX: tu as vu
-[01:28:31] VOIX: que je reviens
-[01:28:31] VOIX: aux règles
-[01:28:32] VOIX: parce que moi
-[01:28:33] VOIX: perso
-[01:28:34] VOIX: et j'ai formé
-[01:28:35] VOIX: Pauline à ça
-[01:28:36] VOIX: les gens
-[01:28:37] VOIX: qui réfléchissent
-[01:28:39] VOIX: et qui sont capables
-[01:28:40] VOIX: d'argumenter le dossier
-[01:28:41] VOIX: c'est les gens
-[01:28:43] VOIX: qui rapportent
-[01:28:43] VOIX: la logique
-[01:28:44] VOIX: au guide méthodologique
-[01:28:50] VOIX: est-ce que jusque là
-[01:28:51] VOIX: ça va
-[01:28:51] VOIX: oui ça va
-[01:28:52] VOIX: mais ça fait
-[01:28:52] VOIX: beaucoup
-[01:28:53] VOIX: beaucoup
-[01:28:53] VOIX: beaucoup
-[01:28:53] VOIX: d'informations
-[01:28:54] VOIX: en fait
-[01:28:54] VOIX: que je dois
-[01:28:57] VOIX: retenir
-[01:28:57] VOIX: tu vas enregistrer
-[01:28:59] VOIX: enregistre
-[01:29:00] VOIX: oui mais je t'enregistre
-[01:29:01] VOIX: mais après
-[01:29:01] VOIX: en fait
-[01:29:01] VOIX: il faut que je traduise
-[01:29:02] VOIX: tout ce que tu viens
-[01:29:03] VOIX: de me dire
-[01:29:03] VOIX: il faut que je la traduise
-[01:29:04] VOIX: dans un autre langage
-[01:29:05] VOIX: que je ne pourrais pas
-[01:29:07] VOIX: en fait apprendre
-[01:29:07] VOIX: surtout aujourd'hui
-[01:29:08] VOIX: donc on peut s'arrêter là
-[01:29:09] VOIX: déjà
-[01:29:10] VOIX: oui je pense en fait
-[01:29:10] VOIX: que oui
-[01:29:11] VOIX: j'ai peut-être
-[01:29:11] VOIX: 2-3 questions
-[01:29:12] VOIX: à te poser
-[01:29:13] VOIX: si tu permets
-[01:29:13] VOIX: parce que c'est important
-[01:29:14] VOIX: pour moi
-[01:29:15] VOIX: non
-[01:29:16] VOIX: en fait
-[01:29:17] VOIX: attends
-[01:29:18] VOIX: ne bouge pas
-[01:29:20] VOIX: je regarde mon écran
-[01:29:21] VOIX: donc je ne te vois plus
-[01:29:22] VOIX: c'est dommage
-[01:29:23] VOIX: mais bon
-[01:29:24] VOIX: c'est pas grave
-[01:29:25] VOIX: donc tu m'as répondu
-[01:29:26] VOIX: je regarde en fait
-[01:29:27] VOIX: parce que je me suis noté
-[01:29:27] VOIX: plein plein de questions
-[01:29:28] VOIX: et tu m'as répondu
-[01:29:30] VOIX: en fait
-[01:29:30] VOIX: à une grosse partie
-[01:29:31] VOIX: quel est le critère
-[01:29:32] VOIX: pour toi
-[01:29:33] VOIX: oui
-[01:29:33] VOIX: si tu as en fait
-[01:29:35] VOIX: 2 d'épées plausibles
-[01:29:37] VOIX: qu'est-ce que c'est ton critère
-[01:29:42] VOIX: mon critère
-[01:29:43] VOIX: je repartage mon écran
-[01:29:46] VOIX: non mais si tu peux me le dire
-[01:29:47] VOIX: comme ça
-[01:29:48] VOIX: c'est celui qui a mobilisé
-[01:29:50] VOIX: le plus d'efforts
-[01:29:54] VOIX: je ne sais pas comment
-[01:29:55] VOIX: en quelle règle
-[01:29:57] VOIX: en fait petit à petit
-[01:29:59] VOIX: il faudra qu'on intègre
-[01:30:02] VOIX: dans la notion de règle
-[01:30:03] VOIX: dans la logique
-[01:30:05] VOIX: c'est hyper important
-[01:30:12] VOIX: ok
-[01:30:12] VOIX: ok ça je l'ai intégré
-[01:30:14] VOIX: mais en fait
-[01:30:15] VOIX: on reverra
-[01:30:16] VOIX: de toute façon
-[01:30:16] VOIX: il va falloir
-[01:30:17] VOIX: que je pense
-[01:30:18] VOIX: c'est pas fini
-[01:30:22] VOIX: c'est celui
-[01:30:23] VOIX: en fait
-[01:30:25] VOIX: soit celui
-[01:30:25] VOIX: qui a mobilisé
-[01:30:26] VOIX: le plus d'efforts
-[01:30:27] VOIX: et il te dit
-[01:30:28] VOIX: si les deux
-[01:30:28] VOIX: ont mobilisé
-[01:30:30] VOIX: le même niveau d'efforts
-[01:30:31] VOIX: ça reste
-[01:30:32] VOIX: au choix
-[01:30:33] VOIX: de l'établissement
-[01:30:34] VOIX: d'accord
-[01:30:34] VOIX: parfois on simule
-[01:30:36] VOIX: et on choisit
-[01:30:36] VOIX: celui qui valorise
-[01:30:37] VOIX: le mieux
-[01:30:38] VOIX: et on peut le défendre
-[01:30:39] VOIX: alors attendez
-[01:30:41] VOIX: mobiliser
-[01:30:45] VOIX: je vais regarder
-[01:30:46] VOIX: moi si je te trouve
-[01:30:47] VOIX: la règle
-[01:30:47] VOIX: ça sera plus simple
-[01:31:08] VOIX: en même temps
-[01:31:09] VOIX: en fait
-[01:31:09] VOIX: je peux te poser
-[01:31:10] VOIX: des questions
-[01:31:11] VOIX: ben bien sûr
-[01:31:17] VOIX: une question
-[01:31:18] VOIX: en fait
-[01:31:18] VOIX: c'est
-[01:31:19] VOIX: c'est des symptômes
-[01:31:20] VOIX: versus
-[01:31:21] VOIX: causes
-[01:31:21] VOIX: pardon pardon
-[01:31:22] VOIX: je peux t'embêter
-[01:31:23] VOIX: je vais juste partager
-[01:31:25] VOIX: un truc
-[01:31:25] VOIX: pour te montrer
-[01:31:27] VOIX: toujours un peu
-[01:31:27] VOIX: la logique
-[01:31:28] VOIX: tu vois
-[01:31:29] VOIX: tu m'as posé
-[01:31:30] VOIX: la question
-[01:31:30] VOIX: que faire
-[01:31:32] VOIX: si l'analyse
-[01:31:32] VOIX: en termes
-[01:31:33] VOIX: de situation clinique
-[01:31:33] VOIX: propose plus
-[01:31:34] VOIX: d'un diagnostic
-[01:31:35] VOIX: principal
-[01:31:36] VOIX: tu vois
-[01:31:37] VOIX: le DP
-[01:31:38] VOIX: étant le problème
-[01:31:39] VOIX: de santé
-[01:31:39] VOIX: qui a motivé
-[01:31:40] VOIX: l'admission
-[01:31:41] VOIX: une telle circonstance
-[01:31:42] VOIX: ne peut être que rare
-[01:31:43] VOIX: le DP déterminé
-[01:31:44] VOIX: à la sortie
-[01:31:45] VOIX: est alors celui
-[01:31:45] VOIX: du problème
-[01:31:46] VOIX: qui a mobilisé
-[01:31:47] VOIX: l'essentiel
-[01:31:47] VOIX: des efforts
-[01:31:48] VOIX: de soins
-[01:31:48] VOIX: si tu en as plus
-[01:31:50] VOIX: dans les cas
-[01:31:51] VOIX: où les deux problèmes
-[01:31:52] VOIX: auraient mobilisé
-[01:31:52] VOIX: des efforts
-[01:31:53] VOIX: d'importance
-[01:31:53] VOIX: comparables
-[01:31:54] VOIX: c'est à dire
-[01:31:55] VOIX: dans le cas
-[01:31:55] VOIX: de prise en charge
-[01:31:57] VOIX: équivalente
-[01:31:57] VOIX: et dans ce cas
-[01:31:58] VOIX: seulement
-[01:31:58] VOIX: le choix du DP
-[01:31:59] VOIX: parmi les ex-ecos
-[01:32:00] VOIX: est laissé à l'établissement
-[01:32:01] VOIX: de santé
-[01:32:02] VOIX: d'accord
-[01:32:02] VOIX: ok
-[01:32:03] VOIX: tu vois
-[01:32:04] VOIX: et c'est la règle
-[01:32:05] VOIX: M2
-[01:32:05] VOIX: ok
-[01:32:09] VOIX: tiens
-[01:32:09] VOIX: que si tu vois
-[01:32:10] VOIX: je pourrais te dire
-[01:32:11] VOIX: non mais c'est
-[01:32:12] VOIX: c'est
-[01:32:13] VOIX: c'est la règle
-[01:32:13] VOIX: M2
-[01:32:14] VOIX: et c'est telle réponse
-[01:32:16] VOIX: ouais
-[01:32:16] VOIX: c'est impeccable
-[01:32:17] VOIX: pour moi
-[01:32:17] VOIX: alors donc
-[01:32:18] VOIX: ça c'est fait
-[01:32:19] VOIX: le symptôme
-[01:32:20] VOIX: versus cause
-[01:32:22] VOIX: c'est
-[01:32:23] VOIX: dans quel cas
-[01:32:24] VOIX: en fait
-[01:32:24] VOIX: le symptôme
-[01:32:25] VOIX: reste un DP
-[01:32:27] VOIX: si le bilan
-[01:32:29] VOIX: ne trouve
-[01:32:30] VOIX: aucun
-[01:32:30] VOIX: aucune cause
-[01:32:31] VOIX: d'accord
-[01:32:32] VOIX: le symptôme
-[01:32:33] VOIX: il reste le DP
-[01:32:35] VOIX: si par exemple
-[01:32:36] VOIX: la douleur
-[01:32:37] VOIX: abdo
-[01:32:37] VOIX: mais finalement
-[01:32:38] VOIX: ils font le scan
-[01:32:39] VOIX: ils font l'écho
-[01:32:40] VOIX: etc
-[01:32:40] VOIX: il n'y a aucune cause
-[01:32:41] VOIX: à la douleur
-[01:32:42] VOIX: ils vont traiter
-[01:32:43] VOIX: la douleur
-[01:32:44] VOIX: et ils disent
-[01:32:45] VOIX: douleur
-[01:32:46] VOIX: sans cause précisée
-[01:32:52] VOIX: ok
-[01:32:55] VOIX: ça c'était
-[01:32:56] VOIX: la règle
-[01:32:58] VOIX: D1
-[01:32:59] VOIX: je pense
-[01:33:01] VOIX: on va regarder
-[01:33:05] VOIX: situation
-[01:33:06] VOIX: attends
-[01:33:06] VOIX: je vais revenir
-[01:33:08] VOIX: un peu avant
-[01:33:09] VOIX: parce que c'est important
-[01:33:10] VOIX: que je te montre
-[01:33:11] VOIX: en fonction
-[01:33:12] VOIX: de tes questions
-[01:33:14] VOIX: attends
-[01:33:14] VOIX: je partage
-[01:33:17] VOIX: par exemple
-[01:33:18] VOIX: là c'était
-[01:33:19] VOIX: la situation
-[01:33:20] VOIX: le séjour
-[01:33:21] VOIX: a permis
-[01:33:21] VOIX: le diagnostic
-[01:33:22] VOIX: de l'infection
-[01:33:23] VOIX: causale
-[01:33:24] VOIX: il n'a pas été
-[01:33:26] VOIX: découvert
-[01:33:26] VOIX: de cause
-[01:33:27] VOIX: à la symptomatologie
-[01:33:28] VOIX: c'est le cas
-[01:33:28] VOIX: dont tu parles
-[01:33:29] VOIX: d'accord
-[01:33:31] VOIX: il te dit
-[01:33:32] VOIX: il n'a pas été
-[01:33:32] VOIX: découvert
-[01:33:33] VOIX: de cause
-[01:33:33] VOIX: à la symptomatologie
-[01:33:34] VOIX: lorsqu'il n'a pas
-[01:33:35] VOIX: été découvert
-[01:33:36] VOIX: de cause
-[01:33:36] VOIX: à la symptomatologie
-[01:33:38] VOIX: et les DP
-[01:33:38] VOIX: c'est le symptôme
-[01:33:39] VOIX: c'est à dire
-[01:33:41] VOIX: au speech
-[01:33:41] VOIX: le patient
-[01:33:42] VOIX: il vient pour
-[01:33:42] VOIX: céphalée
-[01:33:43] VOIX: mais la conclusion
-[01:33:45] VOIX: c'est céphalée
-[01:33:46] VOIX: sans cause
-[01:33:46] VOIX: trouvée
-[01:33:47] VOIX: le DP
-[01:33:48] VOIX: reste
-[01:33:48] VOIX: les céphalées
-[01:33:49] VOIX: tu vois
-[01:33:51] VOIX: c'est clair
-[01:33:52] VOIX: c'est clair
-[01:33:53] VOIX: ou tu veux plus
-[01:33:54] VOIX: d'exemple
-[01:33:54] VOIX: non mais
-[01:33:55] VOIX: là pour l'instant
-[01:33:55] VOIX: la douleur abdo
-[01:33:56] VOIX: c'est ce que je t'ai montré
-[01:33:57] VOIX: ici
-[01:33:58] VOIX: il vient pour douleur abdo
-[01:34:00] VOIX: mais disparaissant des douleurs
-[01:34:01] VOIX: à 48 heures
-[01:34:02] VOIX: pas de cause
-[01:34:03] VOIX: trouvée
-[01:34:04] VOIX: le DP
-[01:34:05] VOIX: c'est les douleurs abdo
-[01:34:06] VOIX: tout à l'heure
-[01:34:08] VOIX: dans ton dossier
-[01:34:09] VOIX: on avait
-[01:34:10] VOIX: la pancréatite
-[01:34:11] VOIX: on avait
-[01:34:12] VOIX: la lithiase
-[01:34:12] VOIX: de la vésicule biliaire
-[01:34:14] VOIX: etc
-[01:34:14] VOIX: c'est ce qui expliquait
-[01:34:16] VOIX: la douleur
-[01:34:16] VOIX: là
-[01:34:17] VOIX: ta douleur
-[01:34:18] VOIX: mais il fait le bilan
-[01:34:19] VOIX: il trouve rien
-[01:34:20] VOIX: ça reste la douleur
-[01:34:22] VOIX: ça c'est fréquent aussi
-[01:34:23] VOIX: surtout dans les services
-[01:34:24] VOIX: des urgences
-[01:34:25] VOIX: etc
-[01:34:27] VOIX: ok
-[01:34:30] VOIX: ok
-[01:34:30] VOIX: ok
-[01:34:31] VOIX: c'est la D2
-[01:34:32] VOIX: la D2
-[01:34:33] VOIX: pour
-[01:34:34] VOIX: cette deuxième réponse
-[01:34:36] VOIX: c'est noté
-[01:34:38] VOIX: c'est noté
-[01:34:38] VOIX: quel rôle
-[01:34:39] VOIX: jouent les actes
-[01:34:40] VOIX: traitement
-[01:34:41] VOIX: en fait
-[01:34:41] VOIX: dans la décision
-[01:34:42] VOIX: en fait
-[01:34:42] VOIX: du DP
-[01:34:45] VOIX: en fait
-[01:34:45] VOIX: ça rejouit
-[01:34:46] VOIX: à peu près
-[01:34:46] VOIX: ce que tu viens
-[01:34:46] VOIX: de me dire
-[01:34:47] VOIX: aussi
-[01:34:47] VOIX: tu passes en fait
-[01:34:48] VOIX: sur le guide
-[01:34:49] VOIX: méthodologique
-[01:34:50] VOIX: ok
-[01:34:50] VOIX: ta question
-[01:34:51] VOIX: j'ai pas forcément
-[01:34:52] VOIX: compris
-[01:34:52] VOIX: le rôle
-[01:34:53] VOIX: des actes
-[01:34:54] VOIX: voilà
-[01:34:57] VOIX: actes et traitement
-[01:34:58] VOIX: dans la décision
-[01:34:59] VOIX: du DP
-[01:34:59] VOIX: ah
-[01:35:00] VOIX: c'était la situation
-[01:35:01] VOIX: de traitement
-[01:35:02] VOIX: en fait
-[01:35:03] VOIX: dans les actes
-[01:35:03] VOIX: alors
-[01:35:04] VOIX: je vais faire
-[01:35:04] VOIX: une petite pause
-[01:35:05] VOIX: parce que ça
-[01:35:06] VOIX: tu ne le trouveras pas
-[01:35:07] VOIX: dans le guide méthodo
-[01:35:08] VOIX: parce que je ne t'ai pas
-[01:35:08] VOIX: donné
-[01:35:09] VOIX: toute la logique
-[01:35:11] VOIX: sur les actes
-[01:35:12] VOIX: tu n'as pas encore
-[01:35:13] VOIX: la réglementation
-[01:35:14] VOIX: là-dessus
-[01:35:14] VOIX: mais dans les actes
-[01:35:16] VOIX: CCAM
-[01:35:17] VOIX: il y en a
-[01:35:18] VOIX: qu'on appelle
-[01:35:19] VOIX: les actes
-[01:35:19] VOIX: classants
-[01:35:20] VOIX: et les actes
-[01:35:21] VOIX: non classants
-[01:35:22] VOIX: les actes
-[01:35:23] VOIX: classants
-[01:35:24] VOIX: par exemple
-[01:35:25] VOIX: la cholycystectomie
-[01:35:26] VOIX: tu vas lui donner
-[01:35:28] VOIX: la cholycystite
-[01:35:29] VOIX: ou la pancréatite
-[01:35:30] VOIX: etc
-[01:35:31] VOIX: elle va t'orienter
-[01:35:32] VOIX: vers un GHM
-[01:35:33] VOIX: chirurgical
-[01:35:34] VOIX: puisque
-[01:35:35] VOIX: l'acte
-[01:35:36] VOIX: il joue
-[01:35:37] VOIX: un rôle
-[01:35:37] VOIX: dans le groupage
-[01:35:39] VOIX: dans le GHM
-[01:35:40] VOIX: et dans le GHS
-[01:35:41] VOIX: et dans le tarif
-[01:35:42] VOIX: les actes
-[01:35:43] VOIX: non classants
-[01:35:44] VOIX: c'est toutes
-[01:35:45] VOIX: les imageries
-[01:35:46] VOIX: l'acte
-[01:35:50] VOIX: d'anapath
-[01:35:51] VOIX: le scanner
-[01:35:52] VOIX: l'IRM
-[01:35:53] VOIX: les échos
-[01:35:53] VOIX: etc
-[01:35:54] VOIX: c'est des actes
-[01:35:55] VOIX: d'imagerie
-[01:35:56] VOIX: mais qui ne vont pas
-[01:35:56] VOIX: influencer
-[01:35:57] VOIX: le groupage
-[01:35:59] VOIX: mais par contre
-[01:36:00] VOIX: ils peuvent avoir
-[01:36:01] VOIX: un rôle
-[01:36:01] VOIX: parce que
-[01:36:02] VOIX: peut-être que c'est
-[01:36:03] VOIX: dans le scanner
-[01:36:04] VOIX: ou dans l'échographie
-[01:36:05] VOIX: etc
-[01:36:05] VOIX: le compte-rendu
-[01:36:06] VOIX: qui va te dire
-[01:36:08] VOIX: au fait
-[01:36:09] VOIX: le scan
-[01:36:10] VOIX: il retrouve
-[01:36:11] VOIX: une tumeur
-[01:36:12] VOIX: du côlon
-[01:36:13] VOIX: et elle peut
-[01:36:14] VOIX: expliquer très bien
-[01:36:15] VOIX: la douleur abdominale
-[01:36:16] VOIX: et dans ce cas-là
-[01:36:17] VOIX: c'est la tumeur
-[01:36:18] VOIX: du côlon
-[01:36:19] VOIX: qu'il faut suivre
-[01:36:20] VOIX: pour donner
-[01:36:23] VOIX: de la précision
-[01:36:24] VOIX: au codage
-[01:36:24] VOIX: du diagnostic
-[01:36:25] VOIX: donc
-[01:36:26] VOIX: c'est soit
-[01:36:27] VOIX: à travers
-[01:36:27] VOIX: l'acte classant
-[01:36:30] VOIX: tu sais que
-[01:36:31] VOIX: si ton acte
-[01:36:34] VOIX: c'est une cholycystectomie
-[01:36:35] VOIX: tu vas regarder
-[01:36:37] VOIX: le résultat
-[01:36:38] VOIX: le compte-rendu
-[01:36:39] VOIX: opératoire
-[01:36:40] VOIX: tu vas regarder
-[01:36:41] VOIX: pourquoi il a enlevé
-[01:36:42] VOIX: il a fait une cholycystectomie
-[01:36:45] VOIX: et ça t'oriente
-[01:36:47] VOIX: vers le DP
-[01:36:48] VOIX: avec la règle T
-[01:36:49] VOIX: alors je ne sais plus
-[01:36:50] VOIX: comment il s'appelle
-[01:36:51] VOIX: traitement unique
-[01:36:51] VOIX: chirurgical
-[01:36:52] VOIX: et la deuxième situation
-[01:36:55] VOIX: c'est tous les résultats
-[01:36:56] VOIX: des examens
-[01:36:58] VOIX: des actes
-[01:36:59] VOIX: des examens
-[01:37:00] VOIX: surtout les examens
-[01:37:01] VOIX: d'imagerie
-[01:37:01] VOIX: qui peuvent être codés
-[01:37:04] VOIX: parce qu'ils montrent
-[01:37:05] VOIX: des diagnostics
-[01:37:06] VOIX: qu'ils posent des diagnostics
-[01:37:07] VOIX: qu'ils aident
-[01:37:07] VOIX: à poser des diagnostics
-[01:37:10] VOIX: et bien ok
-[01:37:12] VOIX: moi qui croyais en fait
-[01:37:13] VOIX: que mon métier
-[01:37:13] VOIX: était super compliqué
-[01:37:18] VOIX: en fait c'est simple
-[01:37:19] VOIX: et je comprends
-[01:37:20] VOIX: pourquoi tu as envie
-[01:37:20] VOIX: en fait d'apprendre
-[01:37:21] VOIX: en fait à lire
-[01:37:23] VOIX: j'ai compris
-[01:37:24] VOIX: j'ai compris le bordel
-[01:37:25] VOIX: pardon
-[01:37:27] VOIX: non non mais
-[01:37:28] VOIX: tu as raison
-[01:37:29] VOIX: c'est un vrai bordel
-[01:37:29] VOIX: et c'est pour ça
-[01:37:31] VOIX: que quand je tempère
-[01:37:32] VOIX: parfois
-[01:37:33] VOIX: certaines situations
-[01:37:34] VOIX: en disant
-[01:37:34] VOIX: c'est pas
-[01:37:35] VOIX: A plus B
-[01:37:36] VOIX: égal à C
-[01:37:37] VOIX: il faut quand même
-[01:37:38] VOIX: qu'on réfléchisse
-[01:37:39] VOIX: proprement
-[01:37:41] VOIX: pour ne pas
-[01:37:42] VOIX: se planter
-[01:37:43] VOIX: quitte
-[01:37:44] VOIX: par exemple
-[01:37:45] VOIX: dans
-[01:37:45] VOIX: on peut procéder
-[01:37:48] VOIX: aussi
-[01:37:49] VOIX: par méthode
-[01:37:50] VOIX: c'est à dire
-[01:37:51] VOIX: on prend les séjours
-[01:37:52] VOIX: chirurgicaux
-[01:37:53] VOIX: et
-[01:37:54] VOIX: tu travailles
-[01:37:55] VOIX: sur une règle
-[01:37:57] VOIX: sur une règle
-[01:37:58] VOIX: traitement unique
-[01:37:59] VOIX: chirurgical
-[01:38:00] VOIX: pour dire
-[01:38:03] VOIX: il n'y en a pas
-[01:38:03] VOIX: beaucoup dans ce contrôle
-[01:38:05] VOIX: par contre
-[01:38:05] VOIX: parce qu'il y a
-[01:38:06] VOIX: quelques-uns
-[01:38:07] VOIX: mais c'est pas très fréquent
-[01:38:08] VOIX: mais si on avait
-[01:38:09] VOIX: un établissement MCO
-[01:38:10] VOIX: ça serait très intéressant
-[01:38:12] VOIX: d'exploiter
-[01:38:13] VOIX: cette notion
-[01:38:14] VOIX: de dire
-[01:38:14] VOIX: je prends les séjours
-[01:38:15] VOIX: de chirurgie
-[01:38:16] VOIX: parce que tu vas voir
-[01:38:17] VOIX: le patient qui a une
-[01:38:18] VOIX: prothèse de hanche
-[01:38:20] VOIX: il a deux
-[01:38:21] VOIX: trois diagnostics
-[01:38:21] VOIX: c'est soit il a fait
-[01:38:22] VOIX: une fracture
-[01:38:22] VOIX: soit qu'il a une arthrose
-[01:38:24] VOIX: de la hanche
-[01:38:25] VOIX: voilà
-[01:38:26] VOIX: ça va être vite vu
-[01:38:28] VOIX: si tu veux
-[01:38:28] VOIX: et ça permet
-[01:38:29] VOIX: d'entraîner
-[01:38:30] VOIX: aussi à dire
-[01:38:31] VOIX: avec tel acte
-[01:38:32] VOIX: je vais avoir
-[01:38:33] VOIX: souvent
-[01:38:34] VOIX: tel diagnostic
-[01:38:35] VOIX: avec la règle
-[01:38:36] VOIX: traitement unique
-[01:38:37] VOIX: chirurgical
-[01:38:37] VOIX: tu vois
-[01:38:38] VOIX: parce que ça
-[01:38:39] VOIX: c'est les séjours
-[01:38:41] VOIX: les plus faciles
-[01:38:42] VOIX: à coder
-[01:38:42] VOIX: c'est les traitements
-[01:38:43] VOIX: uniques chirurgical
-[01:38:44] VOIX: quand on forme
-[01:38:45] VOIX: les teams
-[01:38:45] VOIX: on commence par
-[01:38:46] VOIX: les séjours
-[01:38:47] VOIX: de zéro jour
-[01:38:47] VOIX: et puis par
-[01:38:49] VOIX: la chirurgie
-[01:38:49] VOIX: parce que la chirurgie
-[01:38:50] VOIX: la logique
-[01:38:51] VOIX: elle est plus facile
-[01:38:52] VOIX: que dire
-[01:38:52] VOIX: le symptôme
-[01:38:53] VOIX: la cause
-[01:38:54] VOIX: machin
-[01:38:54] VOIX: en plus
-[01:38:55] VOIX: elles n'ont pas
-[01:38:55] VOIX: de notion médical
-[01:38:56] VOIX: donc c'est pas évident
-[01:38:59] VOIX: il faut avoir
-[01:39:00] VOIX: quand même
-[01:39:01] VOIX: médical
-[01:39:02] VOIX: oui j'ai l'impression
-[01:39:04] VOIX: j'ai un peu l'impression
-[01:39:08] VOIX: je continue mes questions
-[01:39:11] VOIX: bien sûr
-[01:39:12] VOIX: alors
-[01:39:13] VOIX: parce que je suis organisé
-[01:39:14] VOIX: ça se voit pas
-[01:39:15] VOIX: mais je suis un garçon
-[01:39:16] VOIX: organisé
-[01:39:17] VOIX: c'est parfait ça
-[01:39:19] VOIX: ça on le verra
-[01:39:19] VOIX: après en fait
-[01:39:20] VOIX: c'est des problèmes
-[01:39:21] VOIX: que je suis en train
-[01:39:22] VOIX: non mais ça va aller
-[01:39:24] VOIX: peut-être un peu plus vite
-[01:39:25] VOIX: tu vois
-[01:39:29] VOIX: alors ça
-[01:39:30] VOIX: est-ce que je te le pose
-[01:39:31] VOIX: là aujourd'hui
-[01:39:33] VOIX: non
-[01:39:34] VOIX: vas-y vas-y
-[01:39:35] VOIX: non non
-[01:39:36] VOIX: c'est pas une question
-[01:39:36] VOIX: de peur
-[01:39:37] VOIX: en fait
-[01:39:37] VOIX: si tu veux
-[01:39:38] VOIX: tu m'as donné
-[01:39:38] VOIX: tellement d'informations
-[01:39:39] VOIX: sur des sujets
-[01:39:40] VOIX: que j'avais prévu
-[01:39:41] VOIX: un peu plus tard
-[01:39:42] VOIX: c'est pour ça en fait
-[01:39:43] VOIX: que j'aime bien
-[01:39:44] VOIX: cadrer un tout petit peu
-[01:39:45] VOIX: c'est parce qu'il faut
-[01:39:46] VOIX: que j'avance
-[01:39:47] VOIX: en fait sur le truc
-[01:39:47] VOIX: mais
-[01:39:49] VOIX: tu m'as répondu
-[01:39:50] VOIX: en fait
-[01:39:50] VOIX: sur plein de choses
-[01:39:51] VOIX: donc il y a plein
-[01:39:52] VOIX: de réponses
-[01:39:53] VOIX: en fait que j'ai déjà
-[01:39:55] VOIX: moi j'ai des problèmes
-[01:39:56] VOIX: en fait
-[01:40:06] VOIX: ah si
-[01:40:07] VOIX: pour morbidité
-[01:40:09] VOIX: je peux te donner
-[01:40:10] VOIX: une réponse
-[01:40:10] VOIX: vas-y vas-y
-[01:40:11] VOIX: tout à l'heure
-[01:40:11] VOIX: on a vu le patient
-[01:40:13] VOIX: qui est suivi
-[01:40:13] VOIX: pour anémie
-[01:40:14] VOIX: pour maladie de bière-mère
-[01:40:17] VOIX: ma logique
-[01:40:17] VOIX: je vais dire
-[01:40:19] VOIX: soit
-[01:40:19] VOIX: il a un traitement
-[01:40:21] VOIX: pour son anémie
-[01:40:22] VOIX: et pour sa maladie
-[01:40:23] VOIX: de bière-mère
-[01:40:23] VOIX: je vais regarder
-[01:40:25] VOIX: les médicaments
-[01:40:25] VOIX: je vais voir
-[01:40:26] VOIX: est-ce qu'il a dû faire
-[01:40:28] VOIX: est-ce qu'il a de la vitamine
-[01:40:29] VOIX: B12
-[01:40:29] VOIX: etc
-[01:40:30] VOIX: soit
-[01:40:30] VOIX: je vais avoir
-[01:40:31] VOIX: un examen complémentaire
-[01:40:33] VOIX: où ils disent
-[01:40:33] VOIX: je fais le bilan
-[01:40:34] VOIX: de la maladie de bière-mère
-[01:40:36] VOIX: pour voir son hémoglobine
-[01:40:37] VOIX: etc
-[01:40:37] VOIX: donc en fait
-[01:40:39] VOIX: je vais chercher
-[01:40:39] VOIX: dans le dossier
-[01:40:40] VOIX: est-ce que cet antécédent
-[01:40:42] VOIX: ou cette maladie chronique
-[01:40:44] VOIX: qui est
-[01:40:44] VOIX: qui est avec le patient
-[01:40:47] VOIX: a été prise en charge
-[01:40:49] VOIX: d'un point de vue
-[01:40:51] VOIX: diagnostic thérapeutique
-[01:40:52] VOIX: aux surveillances
-[01:40:53] VOIX: pendant le séjour
-[01:40:55] VOIX: si elle a été prise en charge
-[01:40:57] VOIX: je peux la coder
-[01:40:59] VOIX: en diagnostic associé
-[01:41:00] VOIX: qui représente
-[01:41:01] VOIX: les comorbidités
-[01:41:02] VOIX: s'il n'y a eu
-[01:41:03] VOIX: aucune prise en charge
-[01:41:05] VOIX: ou aucune mention
-[01:41:06] VOIX: et qu'il est écrite
-[01:41:07] VOIX: exclusivement
-[01:41:08] VOIX: comme antécédent
-[01:41:09] VOIX: je ne vais pas la coder
-[01:41:12] VOIX: ok
-[01:41:14] VOIX: et je
-[01:41:15] VOIX: je
-[01:41:16] VOIX: ou les dates
-[01:41:16] VOIX: dont on a parlé
-[01:41:17] VOIX: au tout début
-[01:41:18] VOIX: il te dit
-[01:41:18] VOIX: en 1928
-[01:41:20] VOIX: il a eu
-[01:41:21] VOIX: une fracture
-[01:41:23] VOIX: ben là
-[01:41:23] VOIX: tu sais que tu ne vas pas
-[01:41:25] VOIX: coder sa fracture
-[01:41:27] VOIX: tu peux coder
-[01:41:28] VOIX: peut-être des conséquences
-[01:41:29] VOIX: des choses qui sont
-[01:41:30] VOIX: encore actives
-[01:41:32] VOIX: aujourd'hui
-[01:41:32] VOIX: en fait il faut regarder
-[01:41:34] VOIX: la notion de comorbidité
-[01:41:35] VOIX: elle est liée au séjour
-[01:41:37] VOIX: c'est qu'est-ce qui a été fait
-[01:41:39] VOIX: pendant le séjour
-[01:41:40] VOIX: qu'on est en train de coder
-[01:41:41] VOIX: c'est pas
-[01:41:42] VOIX: le patient
-[01:41:43] VOIX: son histoire
-[01:41:44] VOIX: c'est pas
-[01:41:47] VOIX: c'est un codage
-[01:41:47] VOIX: le codage
-[01:41:48] VOIX: c'est vraiment
-[01:41:48] VOIX: lié au séjour
-[01:41:51] VOIX: ok
-[01:41:52] VOIX: est-ce que
-[01:41:53] VOIX: alors attends
-[01:41:54] VOIX: je continue
-[01:41:55] VOIX: mes questions
-[01:41:56] VOIX: et on a bien avancé
-[01:41:58] VOIX: c'est une bonne question
-[01:41:59] VOIX: en fait
-[01:41:59] VOIX: ah mais je te remercie
-[01:42:02] VOIX: ah oui
-[01:42:02] VOIX: non mais c'est vrai
-[01:42:03] VOIX: ça montre que
-[01:42:04] VOIX: voilà
-[01:42:05] VOIX: c'est
-[01:42:06] VOIX: la logique
-[01:42:07] VOIX: elle est en train de
-[01:42:08] VOIX: alors
-[01:42:11] VOIX: donc
-[01:42:11] VOIX: j'aime beaucoup
-[01:42:15] VOIX: cette manière
-[01:42:16] VOIX: de
-[01:42:17] VOIX: de se projeter
-[01:42:18] VOIX: dans le codage
-[01:42:19] VOIX: et qu'on ne soit pas
-[01:42:20] VOIX: que sur des mots clés
-[01:42:21] VOIX: au fait
-[01:42:22] VOIX: ah mais disons que
-[01:42:23] VOIX: si tu veux
-[01:42:24] VOIX: moi la partie
-[01:42:24] VOIX: mots clés
-[01:42:25] VOIX: pour tout te dire
-[01:42:26] VOIX: en fait moi je l'ai terminé
-[01:42:27] VOIX: les mots clés
-[01:42:28] VOIX: tu sais
-[01:42:28] VOIX: en fait avec le
-[01:42:30] VOIX: SL
-[01:42:31] VOIX: NMP
-[01:42:31] VOIX: en fait j'arrive
-[01:42:32] VOIX: à les choper
-[01:42:33] VOIX: tout ce qui est négation
-[01:42:34] VOIX: machin et tout ça
-[01:42:35] VOIX: tout ça en fait j'y arrive
-[01:42:36] VOIX: et c'est justement
-[01:42:37] VOIX: ce que l'on est en train de faire
-[01:42:38] VOIX: qui va amener
-[01:42:40] VOIX: en fait la plus value
-[01:42:40] VOIX: en fait au système
-[01:42:41] VOIX: parce qu'en réalité
-[01:42:42] VOIX: tout le reste
-[01:42:43] VOIX: en fait ce sont des règles
-[01:42:44] VOIX: mais il y a en fait
-[01:42:45] VOIX: des petites subtilités
-[01:42:46] VOIX: que tu m'as
-[01:42:47] VOIX: que tu m'as donné
-[01:42:48] VOIX: et ce sont en fait
-[01:42:49] VOIX: non mais
-[01:42:50] VOIX: quand je dis petite
-[01:42:51] VOIX: c'est pas nécessairement
-[01:42:52] VOIX: c'est justement
-[01:42:53] VOIX: c'est le grain de sable
-[01:42:54] VOIX: en fait qui fait bloquer
-[01:42:55] VOIX: en fait l'ensemble
-[01:42:55] VOIX: d'une machine
-[01:42:56] VOIX: qui est cadrée
-[01:42:57] VOIX: le fait que les contrôles
-[01:42:58] VOIX: c'est plus
-[01:42:59] VOIX: c'est le bazar absolu
-[01:43:00] VOIX: c'est les petits grains de sable
-[01:43:02] VOIX: et l'interprétation
-[01:43:03] VOIX: et comment t'orientes
-[01:43:05] VOIX: ton dossier
-[01:43:05] VOIX: c'est toute la nuance
-[01:43:07] VOIX: avec l'avocat hier
-[01:43:08] VOIX: on s'est dit
-[01:43:09] VOIX: bon on va faire ça comme ça
-[01:43:10] VOIX: on va jouer
-[01:43:11] VOIX: parce que c'est vrai
-[01:43:14] VOIX: ben oui
-[01:43:14] VOIX: mais bien sûr
-[01:43:15] VOIX: mais là aussi
-[01:43:16] VOIX: tu vois en fait
-[01:43:17] VOIX: dans ce que tu m'as
-[01:43:18] VOIX: ce que tu m'as dit
-[01:43:19] VOIX: tout à l'heure
-[01:43:20] VOIX: en fait
-[01:43:20] VOIX: le premier tableau
-[01:43:23] VOIX: en fait que tu m'as montré
-[01:43:24] VOIX: sur ta vision
-[01:43:24] VOIX: en fait des choses
-[01:43:25] VOIX: avec en fait
-[01:43:25] VOIX: ce que on fait aujourd'hui
-[01:43:27] VOIX: ce qu'on veut faire
-[01:43:28] VOIX: en fait demain
-[01:43:29] VOIX: il y a déjà
-[01:43:30] VOIX: tout ce que
-[01:43:31] VOIX: c'est ce que j'ai montré
-[01:43:32] VOIX: l'autre jour
-[01:43:32] VOIX: j'ai déjà pris en compte
-[01:43:34] VOIX: une partie
-[01:43:34] VOIX: une grosse partie
-[01:43:35] VOIX: de ce que tu m'as montré
-[01:43:36] VOIX: et notamment
-[01:43:37] VOIX: en fait le fait
-[01:43:38] VOIX: que quand on fait
-[01:43:39] VOIX: le codage
-[01:43:39] VOIX: d'accord
-[01:43:40] VOIX: quand on fait le codage
-[01:43:41] VOIX: on fait pas en fait
-[01:43:42] VOIX: simplement en fait
-[01:43:43] VOIX: le codage primaire
-[01:43:45] VOIX: c'est en fait
-[01:43:46] VOIX: on va justifier
-[01:43:47] VOIX: parce que la machine
-[01:43:48] VOIX: te permet de faire ça
-[01:43:49] VOIX: en même temps
-[01:43:50] VOIX: pourquoi tu as mis ça
-[01:43:51] VOIX: c'est pour ça
-[01:43:52] VOIX: que je pose des questions
-[01:43:53] VOIX: peut-être
-[01:43:53] VOIX: peut-être très bêtes
-[01:43:54] VOIX: c'est parfait
-[01:43:55] VOIX: voilà
-[01:43:55] VOIX: parce que
-[01:43:56] VOIX: c'est pas bête
-[01:43:57] VOIX: parce que dans la logique
-[01:43:59] VOIX: je vais te donner
-[01:43:59] VOIX: une astuce aussi
-[01:44:00] VOIX: de raisonnement
-[01:44:02] VOIX: qui est utile
-[01:44:03] VOIX: dans un des mails
-[01:44:04] VOIX: je sais pas
-[01:44:05] VOIX: si tu t'en rappelles
-[01:44:06] VOIX: je t'ai dit
-[01:44:07] VOIX: il y a la règle
-[01:44:08] VOIX: de la TIH
-[01:44:10] VOIX: il y a la règle
-[01:44:11] VOIX: médicale
-[01:44:11] VOIX: c'est-à-dire
-[01:44:12] VOIX: si tu donnes
-[01:44:14] VOIX: par exemple
-[01:44:15] VOIX: le site de l'HAS
-[01:44:16] VOIX: ou des sites
-[01:44:17] VOIX: de recommandations
-[01:44:17] VOIX: médicales
-[01:44:18] VOIX: lui va te dire
-[01:44:20] VOIX: par exemple
-[01:44:21] VOIX: j'ai lu
-[01:44:21] VOIX: dans un de tes documents
-[01:44:22] VOIX: l'insuffisance rénale
-[01:44:24] VOIX: aiguë
-[01:44:25] VOIX: obstruxive
-[01:44:25] VOIX: elle est liée
-[01:44:26] VOIX: souvent
-[01:44:27] VOIX: à l'obstruction
-[01:44:29] VOIX: urétérale
-[01:44:30] VOIX: ou je sais pas quoi
-[01:44:32] VOIX: médicalement parlant
-[01:44:32] VOIX: ça tient la route
-[01:44:33] VOIX: mais c'est quelque chose
-[01:44:35] VOIX: qui est peut-être
-[01:44:36] VOIX: non écrit
-[01:44:37] VOIX: par le médecin
-[01:44:37] VOIX: par exemple
-[01:44:38] VOIX: le patient
-[01:44:39] VOIX: il va rentrer
-[01:44:39] VOIX: pour douleur
-[01:44:41] VOIX: je vais essayer
-[01:44:42] VOIX: de trouver des dossiers
-[01:44:43] VOIX: dans cette logique là
-[01:44:43] VOIX: le médecin lui
-[01:44:44] VOIX: il va dire
-[01:44:45] VOIX: est-ce que c'est mon cause
-[01:44:46] VOIX: ma cause
-[01:44:46] VOIX: c'est une tumeur
-[01:44:47] VOIX: est-ce que c'est une infection
-[01:44:48] VOIX: est-ce que
-[01:44:49] VOIX: c'est une inflammation
-[01:44:50] VOIX: est-ce que
-[01:44:51] VOIX: donc dans sa tête lui
-[01:44:52] VOIX: il a fait
-[01:44:53] VOIX: mais il va pas
-[01:44:54] VOIX: les écrire forcément
-[01:44:55] VOIX: il y a des médecins
-[01:44:56] VOIX: qui vont les écrire
-[01:44:57] VOIX: et il y a des médecins
-[01:44:57] VOIX: qui vont pas les écrire
-[01:44:58] VOIX: et leur réflexion
-[01:45:01] VOIX: va rejoindre
-[01:45:02] VOIX: la logique médicale
-[01:45:03] VOIX: les recommandations médicales
-[01:45:04] VOIX: devant tel symptôme
-[01:45:05] VOIX: voici la démarche
-[01:45:07] VOIX: pour éliminer
-[01:45:08] VOIX: la première hypothèse
-[01:45:09] VOIX: il faut faire
-[01:45:10] VOIX: tel bilan
-[01:45:10] VOIX: tel bilan
-[01:45:11] VOIX: éliminer
-[01:45:12] VOIX: ou confirmer
-[01:45:13] VOIX: confirmer ou infirmer
-[01:45:14] VOIX: la deuxième hypothèse
-[01:45:15] VOIX: tel bilan
-[01:45:15] VOIX: tel bilan
-[01:45:16] VOIX: c'est comme ça
-[01:45:16] VOIX: qu'ils arrivent
-[01:45:17] VOIX: à la démarche diagnostique
-[01:45:19] VOIX: et en fait
-[01:45:20] VOIX: parfois
-[01:45:22] VOIX: ils vont pas
-[01:45:23] VOIX: forcément
-[01:45:24] VOIX: nommer
-[01:45:25] VOIX: la conclusion
-[01:45:28] VOIX: parce qu'elle est
-[01:45:29] VOIX: tellement évidente
-[01:45:30] VOIX: pour eux
-[01:45:30] VOIX: et nous
-[01:45:32] VOIX: on se retrouve
-[01:45:32] VOIX: coincés
-[01:45:33] VOIX: même si elle est
-[01:45:34] VOIX: évidente
-[01:45:35] VOIX: à la lecture
-[01:45:35] VOIX: mais si elle n'est
-[01:45:37] VOIX: pas écrite
-[01:45:38] VOIX: tu peux pas
-[01:45:39] VOIX: ça compte ou pas
-[01:45:41] VOIX: sauf si
-[01:45:42] VOIX: t'as un argument
-[01:45:44] VOIX: médical
-[01:45:45] VOIX: indiscutable
-[01:45:46] VOIX: c'est à dire
-[01:45:46] VOIX: tu dis
-[01:45:47] VOIX: c'est pas le plus probable
-[01:45:48] VOIX: tu dis
-[01:45:49] VOIX: j'ai tel argument
-[01:45:50] VOIX: tel symptôme
-[01:45:51] VOIX: tel bilan
-[01:45:52] VOIX: tel ça
-[01:45:53] VOIX: ça ça veut dire
-[01:45:54] VOIX: que c'est telle maladie
-[01:45:55] VOIX: et t'arrives à te défendre
-[01:45:57] VOIX: tu peux aller au tribunal
-[01:45:59] VOIX: avec
-[01:45:59] VOIX: il y a personne
-[01:45:59] VOIX: qui te dit
-[01:46:00] VOIX: non c'est pas telle pathologie
-[01:46:02] VOIX: donc ça c'est des arguments
-[01:46:04] VOIX: qu'on peut avancer
-[01:46:05] VOIX: ça veut pas dire
-[01:46:05] VOIX: qu'ils sont acceptés
-[01:46:07] VOIX: lors des contrôles
-[01:46:08] VOIX: mais ça non plus
-[01:46:09] VOIX: que c'est des arguments
-[01:46:10] VOIX: qu'on utilise
-[01:46:11] VOIX: les médecins
-[01:46:11] VOIX: ils les utilisent
-[01:46:12] VOIX: beaucoup d'ailleurs
-[01:46:14] VOIX: ok
-[01:46:14] VOIX: alors je reprends
-[01:46:17] VOIX: mes petites notes
-[01:46:18] VOIX: oui
-[01:46:18] VOIX: les questions
-[01:46:19] VOIX: alors les diagnostics
-[01:46:20] VOIX: alors attends
-[01:46:22] VOIX: je t'ai parlé
-[01:46:22] VOIX: les comorbidités
-[01:46:25] VOIX: et ensuite
-[01:46:25] VOIX: des diagnostics
-[01:46:27] VOIX: fréquents
-[01:46:27] VOIX: dans les CHR
-[01:46:29] VOIX: mais rarement
-[01:46:30] VOIX: en DP
-[01:46:34] VOIX: est-ce que tu en as
-[01:46:35] VOIX: en tête
-[01:46:36] VOIX: quelques-uns
-[01:46:37] VOIX: oui
-[01:46:40] VOIX: tu peux
-[01:46:41] VOIX: par exemple
-[01:46:43] VOIX: le sepsis
-[01:46:45] VOIX: la dénutrition
-[01:46:46] VOIX: le patient
-[01:46:48] VOIX: ne vient pas
-[01:46:48] VOIX: pour ça
-[01:46:50] VOIX: ça revient souvent
-[01:46:52] VOIX: dénutrition
-[01:46:54] VOIX: oui
-[01:46:55] VOIX: parce que le patient
-[01:46:56] VOIX: il peut être suivi
-[01:46:57] VOIX: pour un problème
-[01:46:58] VOIX: oncologique
-[01:46:59] VOIX: il vient
-[01:47:01] VOIX: parce que
-[01:47:01] VOIX: il fait une complication
-[01:47:03] VOIX: de son traitement
-[01:47:04] VOIX: de sa chimiothérapie
-[01:47:05] VOIX: par exemple
-[01:47:06] VOIX: il fait une aplasie
-[01:47:09] VOIX: médulaire
-[01:47:09] VOIX: je dis n'importe quoi
-[01:47:10] VOIX: ou une anémie
-[01:47:11] VOIX: etc
-[01:47:12] VOIX: il vient pour ça
-[01:47:14] VOIX: en même temps
-[01:47:15] VOIX: on va se rendre compte
-[01:47:17] VOIX: qu'il a une dénutrition
-[01:47:18] VOIX: on peut la prendre
-[01:47:19] VOIX: en charge
-[01:47:20] VOIX: elle va être plus ou moins
-[01:47:22] VOIX: traitée
-[01:47:23] VOIX: plus ou moins prise en charge
-[01:47:25] VOIX: plus ou moins tracée
-[01:47:25] VOIX: c'est pas traité
-[01:47:26] VOIX: il peut être traité
-[01:47:27] VOIX: mais pas forcément
-[01:47:28] VOIX: tracé dans le dossier
-[01:47:29] VOIX: donc dans ce cas là
-[01:47:31] VOIX: la dénutrition
-[01:47:32] VOIX: il faut qu'elle soit codée
-[01:47:33] VOIX: pour la coder
-[01:47:35] VOIX: il y a des règles
-[01:47:36] VOIX: dans le guide méthodologique
-[01:47:38] VOIX: qui définissent
-[01:47:39] VOIX: la dénutrition
-[01:47:40] VOIX: donc il faut que le dossier
-[01:47:41] VOIX: puisse retrouver
-[01:47:42] VOIX: si je t'ai envoyé
-[01:47:44] VOIX: un linkedin
-[01:47:45] VOIX: sur parallèle
-[01:47:47] VOIX: qui parlait
-[01:47:47] VOIX: de la dénutrition
-[01:47:48] VOIX: parce que voilà
-[01:47:49] VOIX: ils ont fait un travail
-[01:47:59] VOIX: avec risque
-[01:48:01] VOIX: ou avec une potentielle
-[01:48:02] VOIX: dénutrition
-[01:48:03] VOIX: mais peut-être qu'avec
-[01:48:05] VOIX: une meilleure traçabilité
-[01:48:06] VOIX: on pourra mieux
-[01:48:07] VOIX: les valoriser
-[01:48:07] VOIX: tu vois ça
-[01:48:08] VOIX: c'est les contrôles
-[01:48:09] VOIX: après qu'on pourra
-[01:48:10] VOIX: intégrer
-[01:48:11] VOIX: ou à défaut
-[01:48:12] VOIX: de pouvoir les coder
-[01:48:14] VOIX: tu peux me donner
-[01:48:15] VOIX: là on a parlé
-[01:48:16] VOIX: de la dénutrition
-[01:48:17] VOIX: tu peux m'en donner
-[01:48:18] VOIX: en fait un ou deux
-[01:48:19] VOIX: de plus
-[01:48:20] VOIX: comme ça
-[01:48:21] VOIX: le sepsis
-[01:48:22] VOIX: le sepsis
-[01:48:23] VOIX: s-e-p-s-i-s
-[01:48:33] VOIX: ok
-[01:48:33] VOIX: le diabète
-[01:48:37] VOIX: hypertension
-[01:48:37] VOIX: tu vois
-[01:48:38] VOIX: ah d'accord
-[01:48:39] VOIX: mais c'est des trucs
-[01:48:39] VOIX: fréquents en fait
-[01:48:40] VOIX: ça
-[01:48:40] VOIX: par exemple
-[01:48:44] VOIX: l'hypertension
-[01:48:45] VOIX: c'est hyper fréquemment
-[01:48:46] VOIX: codé en diagnostics
-[01:48:47] VOIX: associés
-[01:48:47] VOIX: d'accord
-[01:48:48] VOIX: c'est pas forcément
-[01:48:49] VOIX: c'est rare
-[01:48:50] VOIX: les patients
-[01:48:51] VOIX: qui viennent
-[01:48:51] VOIX: pour une poussée
-[01:48:52] VOIX: d'hypertension
-[01:48:54] VOIX: bon moi
-[01:48:54] VOIX: je l'ai fait
-[01:48:56] VOIX: oui
-[01:48:57] VOIX: mais je suis
-[01:48:58] VOIX: un oiseau rare
-[01:49:01] VOIX: si le patient
-[01:49:02] VOIX: il vient
-[01:49:02] VOIX: pour la poussée
-[01:49:03] VOIX: d'hypertension
-[01:49:03] VOIX: dans ce cas là
-[01:49:04] VOIX: ça sera le DP
-[01:49:05] VOIX: et si on regarde
-[01:49:07] VOIX: les diagnostics
-[01:49:07] VOIX: associés
-[01:49:08] VOIX: les plus fréquemment
-[01:49:08] VOIX: codés
-[01:49:09] VOIX: ça peut être
-[01:49:10] VOIX: l'hypertension
-[01:49:10] VOIX: par exemple
-[01:49:14] VOIX: bon
-[01:49:14] VOIX: t'es prête
-[01:49:14] VOIX: pour l'autre question
-[01:49:17] VOIX: alors
-[01:49:18] VOIX: quels diagnostics
-[01:49:19] VOIX: sont quasi
-[01:49:20] VOIX: toujours
-[01:49:21] VOIX: des DAS
-[01:49:22] VOIX: sauf exception
-[01:49:27] VOIX: tu as une règle
-[01:49:30] VOIX: déjà
-[01:49:31] VOIX: liée
-[01:49:32] VOIX: au code
-[01:49:32] VOIX: qui te dit
-[01:49:33] VOIX: que ces codes là
-[01:49:34] VOIX: ne pourraient pas
-[01:49:34] VOIX: être acceptés
-[01:49:35] VOIX: en DP
-[01:49:36] VOIX: donc ça
-[01:49:37] VOIX: c'est lié
-[01:49:37] VOIX: à la CIN10
-[01:49:38] VOIX: mais on a un fichier
-[01:49:39] VOIX: Excel référentiel
-[01:49:40] VOIX: qui dit
-[01:49:42] VOIX: diagnostic
-[01:49:43] VOIX: accepté
-[01:49:44] VOIX: ou pas
-[01:49:44] VOIX: en DP
-[01:49:45] VOIX: donc ça
-[01:49:46] VOIX: c'est facile
-[01:49:46] VOIX: c'est un référentiel
-[01:49:55] VOIX: franchement
-[01:49:56] VOIX: là
-[01:49:56] VOIX: comme ça
-[01:49:57] VOIX: c'est pas grave
-[01:50:01] VOIX: là
-[01:50:02] VOIX: comme ça
-[01:50:02] VOIX: puisque
-[01:50:03] VOIX: par définition
-[01:50:06] VOIX: tu sais
-[01:50:07] VOIX: le patient
-[01:50:08] VOIX: il peut venir
-[01:50:09] VOIX: pour un diabète
-[01:50:10] VOIX: il peut venir
-[01:50:10] VOIX: pour un
-[01:50:11] VOIX: hypertension
-[01:50:12] VOIX: il peut venir
-[01:50:13] VOIX: peut-être
-[01:50:14] VOIX: un peu
-[01:50:14] VOIX: plus rarement
-[01:50:15] VOIX: tout ce qui est
-[01:50:16] VOIX: hypercholestérolémie
-[01:50:17] VOIX: mais bon
-[01:50:18] VOIX: c'est pas
-[01:50:19] VOIX: interdit
-[01:50:19] VOIX: c'est pas
-[01:50:20] VOIX: impossible
-[01:50:21] VOIX: il y a
-[01:50:22] VOIX: les diagnostics
-[01:50:23] VOIX: interdits
-[01:50:23] VOIX: en DP
-[01:50:24] VOIX: mais ça
-[01:50:24] VOIX: c'est un référentiel
-[01:50:25] VOIX: tu ne les trouveras
-[01:50:26] VOIX: jamais en DP
-[01:50:27] VOIX: donc ça
-[01:50:27] VOIX: c'est facile
-[01:50:29] VOIX: donc ça
-[01:50:30] VOIX: en fait
-[01:50:30] VOIX: c'est mes exceptions
-[01:50:31] VOIX: c'est ça
-[01:50:32] VOIX: ok
-[01:50:33] VOIX: c'est peut-être
-[01:50:34] VOIX: si tu cherches
-[01:50:35] VOIX: l'exception
-[01:50:36] VOIX: ça peut être
-[01:50:36] VOIX: les diagnostics
-[01:50:37] VOIX: interdits
-[01:50:37] VOIX: en DP
-[01:50:38] VOIX: mais ça
-[01:50:38] VOIX: le référentiel
-[01:50:39] VOIX: on peut te le donner
-[01:50:40] VOIX: il n'y a pas de soucis
-[01:50:47] VOIX: j'arrive
-[01:50:49] VOIX: parce que
-[01:50:49] VOIX: j'ai une autre question
-[01:50:50] VOIX: et là
-[01:50:51] VOIX: celle-là
-[01:50:51] VOIX: elle m'intéresse
-[01:50:52] VOIX: beaucoup
-[01:50:53] VOIX: c'est
-[01:50:54] VOIX: ce que j'appelle
-[01:50:55] VOIX: en fait
-[01:50:55] VOIX: des codes
-[01:50:56] VOIX: toxiques
-[01:50:57] VOIX: en contrôle
-[01:50:58] VOIX: c'est en fait
-[01:50:59] VOIX: qu'ils sont souvent
-[01:51:00] VOIX: en fait
-[01:51:00] VOIX: contestés
-[01:51:02] VOIX: ah ben
-[01:51:03] VOIX: ça c'est facile
-[01:51:04] VOIX: attends
-[01:51:05] VOIX: je vais les écrire
-[01:51:07] VOIX: mais je peux te faire
-[01:51:08] VOIX: avec les écrits
-[01:51:08] VOIX: déjà
-[01:51:09] VOIX: la dénutrition
-[01:51:10] VOIX: les sepsis
-[01:51:12] VOIX: on a eu
-[01:51:13] VOIX: ce qu'on appelle
-[01:51:14] VOIX: les T80
-[01:51:14] VOIX: les complications
-[01:51:16] VOIX: les complications
-[01:51:17] VOIX: post-opératoires
-[01:51:19] VOIX: mais pas tous
-[01:51:20] VOIX: c'est certain
-[01:51:21] VOIX: mais je peux te donner
-[01:51:22] VOIX: ces quelques codes
-[01:51:23] VOIX: ouais
-[01:51:23] VOIX: je veux
-[01:51:24] VOIX: 74 2
-[01:51:26] VOIX: pas maintenant
-[01:51:27] VOIX: en fait
-[01:51:28] VOIX: mais
-[01:51:28] VOIX: tu vois en fait
-[01:51:29] VOIX: les questions
-[01:51:30] VOIX: bêtes que je me pose
-[01:51:31] VOIX: pour gérer le truc
-[01:51:32] VOIX: parce qu'en fait
-[01:51:33] VOIX: non non
-[01:51:33] VOIX: elles ne sont pas bêtes
-[01:51:34] VOIX: les questions
-[01:51:35] VOIX: ne sont pas bêtes
-[01:51:35] VOIX: Dominique
-[01:51:36] VOIX: ce sont en fait
-[01:51:37] VOIX: des questions
-[01:51:37] VOIX: qui me permettent
-[01:51:39] VOIX: je t'expliquerai quand même
-[01:51:40] VOIX: comment ça marche
-[01:51:41] VOIX: parce que
-[01:51:42] VOIX: ouais
-[01:51:43] VOIX: pour que tu comprennes
-[01:51:43] VOIX: en fait
-[01:51:44] VOIX: j'établis en fait
-[01:51:45] VOIX: justement
-[01:51:46] VOIX: il y a une partie
-[01:51:47] VOIX: en fait
-[01:51:47] VOIX: c'est l'IA qui le gère
-[01:51:49] VOIX: mais une grosse partie
-[01:51:50] VOIX: en fait
-[01:51:50] VOIX: ce sont en fait
-[01:51:51] VOIX: des règles
-[01:51:52] VOIX: qui sont fixées
-[01:51:54] VOIX: et on ne peut pas
-[01:51:55] VOIX: y déroger
-[01:51:56] VOIX: parce que de toute façon
-[01:51:57] VOIX: en fait
-[01:51:57] VOIX: ça rejouit
-[01:51:58] VOIX: en fait
-[01:51:58] VOIX: ce que
-[01:51:59] VOIX: voilà
-[01:51:59] VOIX: tu me renvoies
-[01:52:00] VOIX: en fait
-[01:52:00] VOIX: au document
-[01:52:01] VOIX: en fait
-[01:52:01] VOIX: de la TIH
-[01:52:03] VOIX: à chaque fois
-[01:52:04] VOIX: donc
-[01:52:04] VOIX: mais par contre
-[01:52:06] VOIX: j'ai chopé
-[01:52:07] VOIX: en fait
-[01:52:07] VOIX: alors j'ai chopé
-[01:52:08] VOIX: on va le voir
-[01:52:08] VOIX: on va le vérifier
-[01:52:10] VOIX: ta logique
-[01:52:11] VOIX: en fait
-[01:52:11] VOIX: ta façon
-[01:52:12] VOIX: en fait
-[01:52:12] VOIX: parce que
-[01:52:12] VOIX: tout ce que
-[01:52:13] VOIX: tu as dit
-[01:52:15] VOIX: je l'avais
-[01:52:16] VOIX: entreaperçu
-[01:52:16] VOIX: avec toutes les discussions
-[01:52:17] VOIX: qu'on a eu
-[01:52:18] VOIX: mais je les ai validées
-[01:52:21] VOIX: donc bon
-[01:52:21] VOIX: après
-[01:52:23] VOIX: c'est pour ça
-[01:52:24] VOIX: que je t'ai dit
-[01:52:24] VOIX: on fera des sessions
-[01:52:25] VOIX: autant de fois
-[01:52:26] VOIX: que nécessaire
-[01:52:27] VOIX: ah oui
-[01:52:27] VOIX: non mais là
-[01:52:28] VOIX: tu ne vas pas y couper
-[01:52:32] VOIX: non parce que
-[01:52:33] VOIX: alors attends
-[01:52:34] VOIX: attends
-[01:52:34] VOIX: parce qu'il me reste
-[01:52:35] VOIX: une ou deux
-[01:52:36] VOIX: et après
-[01:52:36] VOIX: je suis tranquille
-[01:52:38] VOIX: donc ok
-[01:52:40] VOIX: ah oui
-[01:52:41] VOIX: si
-[01:52:41] VOIX: tu sais
-[01:52:42] VOIX: oui pardon
-[01:52:42] VOIX: juste
-[01:52:43] VOIX: tes questions là
-[01:52:44] VOIX: tu les as notées
-[01:52:45] VOIX: ah bah oui
-[01:52:46] VOIX: je suis en fait
-[01:52:47] VOIX: mon protocole
-[01:52:47] VOIX: parce que tu sais
-[01:52:48] VOIX: moi tes questions
-[01:52:50] VOIX: mais c'est une formation
-[01:52:52] VOIX: que
-[01:52:52] VOIX: c'est génial
-[01:52:54] VOIX: tes questions
-[01:52:54] VOIX: tu vois
-[01:52:55] VOIX: quand tu
-[01:52:56] VOIX: non mais c'est vrai
-[01:52:57] VOIX: parce que
-[01:52:58] VOIX: nous en formation
-[01:52:59] VOIX: on travaille beaucoup
-[01:53:00] VOIX: avec
-[01:53:02] VOIX: est-ce que tu peux mettre ça
-[01:53:03] VOIX: c'est quoi la règle
-[01:53:05] VOIX: pour coder ça
-[01:53:06] VOIX: etc
-[01:53:07] VOIX: donc c'est une manière
-[01:53:08] VOIX: aussi
-[01:53:08] VOIX: tu vois
-[01:53:08] VOIX: moi aussi
-[01:53:10] VOIX: ta manière
-[01:53:11] VOIX: de te poser les questions
-[01:53:11] VOIX: m'intéresse énormément
-[01:53:12] VOIX: parce que
-[01:53:14] VOIX: je vais te l'envoyer
-[01:53:15] VOIX: en fait
-[01:53:15] VOIX: je vais t'envoyer
-[01:53:16] VOIX: en fait
-[01:53:16] VOIX: le truc
-[01:53:17] VOIX: oui
-[01:53:18] VOIX: je te l'envoie en fait
-[01:53:19] VOIX: ben là quand on a fini
-[01:53:20] VOIX: je t'envoie en fait
-[01:53:20] VOIX: le document de vierge
-[01:53:22] VOIX: c'est très intéressant
-[01:53:24] VOIX: et en fait
-[01:53:25] VOIX: si tu
-[01:53:25] VOIX: et je t'expliquerai
-[01:53:26] VOIX: en fait
-[01:53:27] VOIX: comment je fonctionne
-[01:53:27] VOIX: d'ailleurs en fait
-[01:53:28] VOIX: une grosse discussion
-[01:53:29] VOIX: en fait
-[01:53:29] VOIX: de même
-[01:53:29] VOIX: du mode de fonctionnement
-[01:53:31] VOIX: de tout ce bazar
-[01:53:32] VOIX: ok
-[01:53:33] VOIX: j'aurais besoin
-[01:53:34] VOIX: que tu m'envoies
-[01:53:35] VOIX: en fait
-[01:53:35] VOIX: si tu as la possibilité
-[01:53:37] VOIX: attends
-[01:53:38] VOIX: parce que
-[01:53:38] VOIX: ça c'est fait
-[01:53:43] VOIX: oui
-[01:53:44] VOIX: il me faudra en fait avoir
-[01:53:46] VOIX: mais ça je vais peut-être
-[01:53:47] VOIX: demander à Jordan
-[01:53:48] VOIX: ou à
-[01:53:49] VOIX: ou à Pauline
-[01:53:50] VOIX: si je pourrais avoir
-[01:53:52] VOIX: en fait une fiche
-[01:53:54] VOIX: attends
-[01:53:55] VOIX: comment ça s'appelle
-[01:53:56] VOIX: j'arrête pas en fait
-[01:53:58] VOIX: de rouler
-[01:53:59] VOIX: et de rouler
-[01:53:59] VOIX: moi aussi je cherche
-[01:54:00] VOIX: un document
-[01:54:01] VOIX: je vais te le montrer aussi
-[01:54:02] VOIX: oui
-[01:54:03] VOIX: en fait
-[01:54:04] VOIX: c'est un format
-[01:54:05] VOIX: en fait
-[01:54:06] VOIX: de synthèse
-[01:54:06] VOIX: PMSI
-[01:54:08] VOIX: idéal
-[01:54:08] VOIX: ce que tu aimerais avoir
-[01:54:17] VOIX: mais tu sais quoi
-[01:54:19] VOIX: mais tu sais quoi
-[01:54:22] VOIX: le format en fait
-[01:54:23] VOIX: de synthèse
-[01:54:24] VOIX: PMSI
-[01:54:24] VOIX: alors
-[01:54:26] VOIX: j'ai besoin
-[01:54:26] VOIX: en fait de savoir
-[01:54:27] VOIX: en fait
-[01:54:27] VOIX: comment en fait
-[01:54:28] VOIX: tu ferais
-[01:54:29] VOIX: en fait
-[01:54:30] VOIX: une synthèse
-[01:54:31] VOIX: en fait
-[01:54:31] VOIX: de PMSI
-[01:54:35] VOIX: tu vois pas du tout
-[01:54:37] VOIX: en fait
-[01:54:38] VOIX: est-ce que
-[01:54:39] VOIX: c'est une synthèse
-[01:54:40] VOIX: parce que j'ai des diapos
-[01:54:41] VOIX: de formation
-[01:54:42] VOIX: PMSI
-[01:54:44] VOIX: où je vais
-[01:54:45] VOIX: expliquer
-[01:54:46] VOIX: c'est quoi le DP
-[01:54:46] VOIX: c'est quoi le DAS
-[01:54:47] VOIX: etc
-[01:54:49] VOIX: ou est-ce que
-[01:54:49] VOIX: c'est un format
-[01:54:52] VOIX: des codes
-[01:54:54] VOIX: un peu comme
-[01:54:54] VOIX: la fiche
-[01:54:55] VOIX: OGC
-[01:54:56] VOIX: qu'est-ce qu'il y a
-[01:54:57] VOIX: dans
-[01:54:57] VOIX: oui
-[01:54:58] VOIX: voilà
-[01:54:58] VOIX: en fait
-[01:54:59] VOIX: c'est plutôt ça
-[01:55:00] VOIX: plutôt la fiche
-[01:55:02] VOIX: d'un séjour
-[01:55:03] VOIX: oui
-[01:55:03] VOIX: la fiche
-[01:55:03] VOIX: d'un séjour
-[01:55:05] VOIX: oui
-[01:55:06] VOIX: mais
-[01:55:07] VOIX: attends
-[01:55:08] VOIX: attends
-[01:55:08] VOIX: parce que
-[01:55:09] VOIX: Guy
-[01:55:09] VOIX: parle en même temps
-[01:55:11] VOIX: la fiche
-[01:55:12] VOIX: d'un séjour
-[01:55:13] VOIX: t'as deux points
-[01:55:14] VOIX: de vue
-[01:55:14] VOIX: t'as le point
-[01:55:15] VOIX: de vue PMSI
-[01:55:16] VOIX: les données
-[01:55:17] VOIX: la variable
-[01:55:18] VOIX: un peu comme
-[01:55:20] VOIX: on va le remontrer
-[01:55:21] VOIX: là
-[01:55:23] VOIX: pardon
-[01:55:24] VOIX: vas-y
-[01:55:24] VOIX: vas-y
-[01:55:25] VOIX: la fiche
-[01:55:26] VOIX: là qu'on a
-[01:55:27] VOIX: vue tout à l'heure
-[01:55:30] VOIX: j'espère que
-[01:55:31] VOIX: je vais me reconnecter
-[01:55:32] VOIX: la fiche
-[01:55:35] VOIX: OGC
-[01:55:36] VOIX: et là
-[01:55:36] VOIX: tu vas voir
-[01:55:38] VOIX: elle est intéressante
-[01:55:39] VOIX: parce que
-[01:55:39] VOIX: pratiquement
-[01:55:41] VOIX: si on revient
-[01:55:42] VOIX: à notre
-[01:55:43] VOIX: 339
-[01:55:49] VOIX: la fiche
-[01:55:50] VOIX: en fait
-[01:55:50] VOIX: t'as
-[01:55:52] VOIX: t'as
-[01:55:52] VOIX: ici
-[01:55:53] VOIX: le nombre
-[01:55:54] VOIX: d'unités médicales
-[01:55:56] VOIX: qu'il y a
-[01:55:57] VOIX: là il y a une seule
-[01:55:58] VOIX: mais on peut avoir
-[01:55:59] VOIX: plusieurs unités médicales
-[01:56:00] VOIX: avec leur propre date d'entrée
-[01:56:02] VOIX: etc
-[01:56:03] VOIX: t'as les données administratives
-[01:56:05] VOIX: du patient
-[01:56:05] VOIX: tu vois quand je te dis
-[01:56:07] VOIX: l'âge
-[01:56:09] VOIX: mode d'entrée
-[01:56:10] VOIX: la durée de séjour
-[01:56:11] VOIX: les dates
-[01:56:14] VOIX: les modes de remboursement
-[01:56:15] VOIX: bon ça c'est des données
-[01:56:16] VOIX: de facturation
-[01:56:18] VOIX: ça on le regarde après
-[01:56:20] VOIX: attends
-[01:56:20] VOIX: je lève le doigt
-[01:56:21] VOIX: t'as vu
-[01:56:21] VOIX: pardon
-[01:56:22] VOIX: je vois pas
-[01:56:23] VOIX: parce que
-[01:56:24] VOIX: je
-[01:56:25] VOIX: ouais
-[01:56:25] VOIX: il est entré
-[01:56:27] VOIX: est-ce que c'est une donnée
-[01:56:28] VOIX: en fait
-[01:56:28] VOIX: très importante
-[01:56:30] VOIX: très importante
-[01:56:31] VOIX: ou moyennement importante
-[01:56:32] VOIX: il est entré
-[01:56:33] VOIX: via domicile
-[01:56:34] VOIX: et sorti
-[01:56:34] VOIX: par domicile
-[01:56:37] VOIX: oui
-[01:56:37] VOIX: est importante
-[01:56:39] VOIX: parce que
-[01:56:42] VOIX: il est importante
-[01:56:43] VOIX: parce qu'on peut faire
-[01:56:44] VOIX: des contrôles là-dessus
-[01:56:45] VOIX: les modes d'entrée
-[01:56:47] VOIX: on a des contrôles aussi
-[01:56:48] VOIX: sur
-[01:56:49] VOIX: est-ce qu'il est passé
-[01:56:50] VOIX: par les urgences
-[01:56:51] VOIX: ou pas
-[01:56:51] VOIX: est-ce que
-[01:56:52] VOIX: il a été transféré
-[01:56:54] VOIX: est-ce que
-[01:56:55] VOIX: c'est pas important
-[01:56:57] VOIX: au même niveau
-[01:56:58] VOIX: que le codage
-[01:56:59] VOIX: mais c'est des choses
-[01:56:59] VOIX: qu'on contrôle
-[01:57:00] VOIX: qu'on regarde
-[01:57:01] VOIX: tout ça
-[01:57:02] VOIX: le team
-[01:57:02] VOIX: va le vérifier
-[01:57:03] VOIX: par exemple
-[01:57:04] VOIX: parfois on corrige ça
-[01:57:05] VOIX: au moment
-[01:57:06] VOIX: à la lecture du dossier
-[01:57:07] VOIX: on se rend compte
-[01:57:08] VOIX: ben finalement
-[01:57:09] VOIX: c'était
-[01:57:09] VOIX: il a été transféré
-[01:57:10] VOIX: depuis un établissement
-[01:57:11] VOIX: on va corriger
-[01:57:13] VOIX: ok
-[01:57:13] VOIX: d'accord
-[01:57:16] VOIX: ou les dates
-[01:57:16] VOIX: tu vois tout à l'heure
-[01:57:17] VOIX: quand je t'ai dit
-[01:57:18] VOIX: le 24 ou le 25
-[01:57:21] VOIX: si on se rend compte
-[01:57:23] VOIX: qu'il y a une erreur
-[01:57:23] VOIX: on peut la corriger
-[01:57:26] VOIX: donc il y a cette vue là
-[01:57:27] VOIX: de dire
-[01:57:28] VOIX: la synthèse du séjour
-[01:57:30] VOIX: d'un point de vue
-[01:57:31] VOIX: codage PMSI
-[01:57:32] VOIX: où je vais avoir
-[01:57:34] VOIX: mes diagnostics
-[01:57:36] VOIX: là typiquement
-[01:57:37] VOIX: mes diagnostics
-[01:57:38] VOIX: si ils disent
-[01:57:39] VOIX: j'ai mes modes d'entrée
-[01:57:41] VOIX: mes dates
-[01:57:41] VOIX: parce qu'on peut
-[01:57:42] VOIX: les modifier aussi
-[01:57:43] VOIX: j'ai mes actes
-[01:57:45] VOIX: et en haut
-[01:57:46] VOIX: tu vois les actes
-[01:57:47] VOIX: est-ce qu'il est classant
-[01:57:48] VOIX: est-ce qu'il y a eu
-[01:57:49] VOIX: de l'anesthésie
-[01:57:50] VOIX: etc
-[01:57:51] VOIX: et en haut
-[01:57:52] VOIX: je vais avoir
-[01:57:54] VOIX: le groupage
-[01:57:55] VOIX: c'est-à-dire
-[01:57:55] VOIX: comment il a été groupé
-[01:57:57] VOIX: le GHF
-[01:57:58] VOIX: le GHS
-[01:57:59] VOIX: la valo
-[01:58:00] VOIX: est-ce qu'il y a des suppléments
-[01:58:01] VOIX: est-ce qu'il y a des bornes
-[01:58:03] VOIX: parce que ça
-[01:58:03] VOIX: là on rentre
-[01:58:04] VOIX: dans la logique
-[01:58:05] VOIX: de valorisation
-[01:58:06] VOIX: du séjour
-[01:58:07] VOIX: qui est encore
-[01:58:07] VOIX: une autre histoire
-[01:58:10] VOIX: est-ce que c'est
-[01:58:11] VOIX: cette logique là
-[01:58:12] VOIX: ou est-ce que
-[01:58:13] VOIX: quand tu dis
-[01:58:14] VOIX: la synthèse du séjour
-[01:58:15] VOIX: c'est
-[01:58:15] VOIX: ce qu'on a vu
-[01:58:17] VOIX: tout à l'heure
-[01:58:17] VOIX: dans Tracker
-[01:58:19] VOIX: j'ai les observations
-[01:58:21] VOIX: médicales
-[01:58:21] VOIX: j'ai le courrier
-[01:58:22] VOIX: j'ai les comptes rendus
-[01:58:23] VOIX: opératoires
-[01:58:24] VOIX: j'ai les comptes rendus
-[01:58:24] VOIX: de biologie
-[01:58:25] VOIX: non c'est plutôt ça
-[01:58:26] VOIX: qu'est-ce que tu attends
-[01:58:27] VOIX: c'était plutôt ça
-[01:58:29] VOIX: en fait
-[01:58:29] VOIX: et avec
-[01:58:30] VOIX: mais
-[01:58:32] VOIX: non non
-[01:58:32] VOIX: c'est plutôt ce que tu viens
-[01:58:33] VOIX: de me montrer
-[01:58:35] VOIX: donc ce que je viens
-[01:58:36] VOIX: de te montrer
-[01:58:37] VOIX: demande à Jordan
-[01:58:38] VOIX: de te montrer
-[01:58:39] VOIX: ça c'est la fiche
-[01:58:40] VOIX: où j'essaie
-[01:58:41] VOIX: mais dans les autres
-[01:58:42] VOIX: logiciels
-[01:58:42] VOIX: Eva
-[01:58:43] VOIX: on a un truc
-[01:58:44] VOIX: qui s'appelle
-[01:58:45] VOIX: fiche
-[01:58:45] VOIX: fiche dossier
-[01:58:47] VOIX: dans Eva
-[01:58:49] VOIX: améliorer le codage
-[01:58:50] VOIX: par exemple
-[01:58:51] VOIX: et c'est ni plus
-[01:58:52] VOIX: ni moins que ça
-[01:58:53] VOIX: on s'est inspiré d'ailleurs
-[01:58:54] VOIX: de cette fiche
-[01:58:55] VOIX: pour créer les étapes
-[01:58:56] VOIX: du contrôle T2A
-[01:58:57] VOIX: et elle
-[01:58:58] VOIX: elle n'a pas
-[01:58:59] VOIX: les étapes
-[01:59:00] VOIX: elle est de l'établissement
-[01:59:02] VOIX: toute seule
-[01:59:02] VOIX: elle est déjà faite
-[01:59:04] VOIX: et elle est
-[01:59:04] VOIX: voilà
-[01:59:05] VOIX: c'est un module à part
-[01:59:06] VOIX: et il peut te montrer
-[01:59:07] VOIX: par exemple
-[01:59:08] VOIX: quand il y a plusieurs unités
-[01:59:09] VOIX: quand on a les suppléments
-[01:59:10] VOIX: quand on a les variables
-[01:59:11] VOIX: gradations
-[01:59:12] VOIX: quand on a le bébé
-[01:59:14] VOIX: la maman
-[01:59:14] VOIX: pour les accouchements
-[01:59:15] VOIX: il y a plusieurs cas
-[01:59:16] VOIX: spécifiques
-[01:59:17] VOIX: si tu veux
-[01:59:17] VOIX: c'est pas juste
-[01:59:18] VOIX: un modèle unique
-[01:59:20] VOIX: parce que là
-[01:59:21] VOIX: on est en train
-[01:59:22] VOIX: de voir
-[01:59:22] VOIX: les cas généraux
-[01:59:23] VOIX: mais si tu rentres
-[01:59:24] VOIX: dans l'obstétrique
-[01:59:26] VOIX: si tu rentres
-[01:59:26] VOIX: dans l'area
-[01:59:28] VOIX: la surveillance continue
-[01:59:29] VOIX: etc
-[01:59:29] VOIX: on va avoir
-[01:59:30] VOIX: des spécificités
-[01:59:31] VOIX: de codage
-[01:59:33] VOIX: de valorisation
-[01:59:33] VOIX: etc
-[01:59:34] VOIX: là on est dans
-[01:59:35] VOIX: le cas de base
-[01:59:36] VOIX: un team
-[01:59:37] VOIX: qui est en train
-[01:59:38] VOIX: d'apprendre
-[01:59:38] VOIX: à coder
-[01:59:40] VOIX: ok
-[01:59:41] VOIX: bon
-[01:59:42] VOIX: la 11
-[01:59:42] VOIX: ça tu m'as
-[01:59:44] VOIX: déjà répondu
-[01:59:44] VOIX: c'est les champs
-[01:59:46] VOIX: en fait
-[01:59:46] VOIX: c'est les motifs
-[01:59:47] VOIX: d'admission
-[01:59:47] VOIX: diagnostic
-[01:59:48] VOIX: en fait
-[01:59:48] VOIX: il y a la date
-[01:59:49] VOIX: en fait
-[01:59:49] VOIX: que j'ai remarqué
-[01:59:51] VOIX: excuse-moi
-[01:59:52] VOIX: je me parle
-[01:59:52] VOIX: en même temps
-[01:59:53] VOIX: et je dis
-[01:59:53] VOIX: en fait
-[01:59:53] VOIX: en même temps
-[01:59:54] VOIX: oui
-[01:59:54] VOIX: est-ce qu'on n'a pas
-[01:59:55] VOIX: ici dans ces fiches
-[01:59:57] VOIX: mais on l'a
-[01:59:58] VOIX: dans les bases
-[01:59:59] VOIX: c'est tout ce qui est lié
-[02:00:01] VOIX: aux règles
-[02:00:03] VOIX: de remboursement
-[02:00:04] VOIX: et d'affiliation
-[02:00:05] VOIX: du patient
-[02:00:05] VOIX: à la caisse
-[02:00:06] VOIX: est-ce qu'il est
-[02:00:07] VOIX: assurance maladie
-[02:00:08] VOIX: est-ce qu'il est
-[02:00:09] VOIX: remboursable
-[02:00:09] VOIX: est-ce qu'il est à 80%
-[02:00:11] VOIX: est-ce qu'il est à 100%
-[02:00:12] VOIX: c'est quelle caisse
-[02:00:15] VOIX: voilà
-[02:00:15] VOIX: c'est toutes les règles
-[02:00:16] VOIX: liées à la facturation
-[02:00:19] VOIX: bon ok
-[02:00:19] VOIX: ok
-[02:00:20] VOIX: c'est noté
-[02:00:22] VOIX: moi il me faut
-[02:00:24] VOIX: en fait
-[02:00:24] VOIX: je rajoute
-[02:00:25] VOIX: parce qu'en fait
-[02:00:26] VOIX: c'est plus
-[02:00:26] VOIX: la fiche
-[02:00:29] VOIX: pour ma synthèse
-[02:00:30] VOIX: ben moi je note
-[02:00:31] VOIX: fiche EVA
-[02:00:32] VOIX: au fait
-[02:00:32] VOIX: oui
-[02:00:33] VOIX: oui ça rejoint
-[02:00:34] VOIX: c'est toujours
-[02:00:35] VOIX: ouais
-[02:00:35] VOIX: et ça
-[02:00:36] VOIX: là je t'ai montré
-[02:00:37] VOIX: la fiche OGC
-[02:00:38] VOIX: mais
-[02:00:39] VOIX: Jordan pourra te montrer
-[02:00:41] VOIX: ou voilà
-[02:00:42] VOIX: ou moi-même
-[02:00:44] VOIX: si je me connecte
-[02:00:47] VOIX: non mais on le
-[02:00:48] VOIX: tu veux que je te la montre
-[02:00:52] VOIX: ou pas
-[02:00:52] VOIX: bon
-[02:00:53] VOIX: ben vas-y
-[02:00:54] VOIX: vas-y
-[02:00:54] VOIX: puisque tu y es
-[02:00:55] VOIX: vas-y
-[02:00:57] VOIX: non j'y suis pas
-[02:00:58] VOIX: tu peux encore connecter
-[02:01:06] VOIX: je sais même pas si
-[02:01:09] VOIX: ah c'est
-[02:01:10] VOIX: c'est
-[02:01:10] VOIX: c'est
-[02:01:10] VOIX: c'est
-[02:01:11] VOIX: le
-[02:01:11] VOIX: le fameux code
-[02:01:12] VOIX: qui est écrit
-[02:01:13] VOIX: en fait
-[02:01:13] VOIX: dans une feuille Excel
-[02:01:14] VOIX: est-ce que
-[02:01:17] VOIX: ouais
-[02:01:17] VOIX: est-ce qu'il va marcher
-[02:01:18] VOIX: ou pas
-[02:01:19] VOIX: ça a marché de suite
-[02:01:20] VOIX: ça a marché
-[02:01:21] VOIX: tu as vu
-[02:01:22] VOIX: la fiche Excel
-[02:01:23] VOIX: de mon cerveau
-[02:01:24] VOIX: pour celui-là
-[02:01:26] VOIX: je me rappelle
-[02:01:27] VOIX: je me rappelle
-[02:01:33] VOIX: je me rappelle
-[02:01:35] VOIX: je m'en fous
-[02:01:35] VOIX: et
-[02:01:36] VOIX: c'est là
-[02:01:37] VOIX: où on va s'amuser
-[02:01:38] VOIX: parce que
-[02:01:39] VOIX: je me rappelle
-[02:01:40] VOIX: plus
-[02:01:40] VOIX: si
-[02:01:41] VOIX: je me rappelle
-[02:01:54] VOIX: le codage
-[02:01:55] VOIX: le codage
-[02:01:56] VOIX: ça c'est
-[02:01:56] VOIX: du MCO
-[02:01:57] VOIX: parce que
-[02:01:57] VOIX: là on est en train
-[02:01:58] VOIX: de travailler
-[02:01:59] VOIX: le champ MCO
-[02:02:00] VOIX: on n'a pas encore
-[02:02:01] VOIX: touché le champ
-[02:02:01] VOIX: et c'est là
-[02:02:04] VOIX: je suis pas sûr
-[02:02:09] VOIX: parce que
-[02:02:10] VOIX: il y a trop
-[02:02:10] VOIX: d'abréviations
-[02:02:11] VOIX: il y a trop
-[02:02:12] VOIX: de machin
-[02:02:12] VOIX: là
-[02:02:13] VOIX: bon
-[02:02:13] VOIX: mais au fur et à mesure
-[02:02:14] VOIX: ce que tu peux faire aussi
-[02:02:16] VOIX: c'est sortir une liste
-[02:02:18] VOIX: d'abréviations
-[02:02:19] VOIX: qu'on peut nous
-[02:02:20] VOIX: te compléter
-[02:02:21] VOIX: oui
-[02:02:22] VOIX: ben en fait
-[02:02:22] VOIX: je les ai
-[02:02:23] VOIX: en fait
-[02:02:23] VOIX: à ces listes-là
-[02:02:24] VOIX: mais c'est que
-[02:02:25] VOIX: je suis
-[02:02:25] VOIX: je n'arrive pas
-[02:02:26] VOIX: à les retenir
-[02:02:27] VOIX: c'est avec
-[02:02:28] VOIX: l'expérience
-[02:02:29] VOIX: avec le
-[02:02:30] VOIX: au fur et à mesure
-[02:02:31] VOIX: il y a des choses
-[02:02:32] VOIX: que j'arrive à comprendre
-[02:02:33] VOIX: et à retenir
-[02:02:33] VOIX: bien sûr
-[02:02:35] VOIX: tu vois ça
-[02:02:36] VOIX: c'est un type
-[02:02:36] VOIX: c'est les contrôles
-[02:02:37] VOIX: qu'on fait
-[02:02:38] VOIX: c'est les contrôles
-[02:02:38] VOIX: qualité interne
-[02:02:41] VOIX: par exemple
-[02:02:41] VOIX: je prends celui-là
-[02:02:43] VOIX: et je vais voir
-[02:02:45] VOIX: un dossier
-[02:02:46] VOIX: tu vois la fiche là
-[02:02:49] VOIX: non ça c'est
-[02:02:51] VOIX: maintenant
-[02:02:51] VOIX: je vais prendre
-[02:02:52] VOIX: un qui est
-[02:02:55] VOIX: c'est un cauchement
-[02:02:57] VOIX: tu vois là
-[02:02:58] VOIX: par exemple
-[02:02:59] VOIX: dans ce dossier
-[02:03:01] VOIX: j'ai
-[02:03:01] VOIX: le groupage
-[02:03:03] VOIX: du séjour
-[02:03:04] VOIX: complet
-[02:03:04] VOIX: avec sa durée globale
-[02:03:06] VOIX: et je vois
-[02:03:07] VOIX: qu'il est passé
-[02:03:08] VOIX: par deux unités
-[02:03:09] VOIX: médicales
-[02:03:10] VOIX: de chirurgie
-[02:03:11] VOIX: l'unité 92
-[02:03:14] VOIX: et l'unité 91
-[02:03:16] VOIX: donc
-[02:03:17] VOIX: il manque codage
-[02:03:18] VOIX: la valeur globale
-[02:03:20] VOIX: elle est liée
-[02:03:20] VOIX: au séjour
-[02:03:21] VOIX: mais le codage
-[02:03:23] VOIX: par exemple
-[02:03:24] VOIX: ce codage là
-[02:03:25] VOIX: le diagnostic
-[02:03:26] VOIX: des actes
-[02:03:27] VOIX: il peut être différent
-[02:03:28] VOIX: tu vois
-[02:03:29] VOIX: il a changé
-[02:03:30] VOIX: 11 diagnostic
-[02:03:31] VOIX: machin
-[02:03:31] VOIX: là si je fais ça
-[02:03:33] VOIX: j'ai 26 diagnostics
-[02:03:35] VOIX: et au fait
-[02:03:36] VOIX: le groupage
-[02:03:37] VOIX: c'est ce qui va
-[02:03:38] VOIX: regrouper
-[02:03:38] VOIX: l'ensemble
-[02:03:40] VOIX: des codes
-[02:03:41] VOIX: qui sont faits
-[02:03:41] VOIX: dans chaque
-[02:03:42] VOIX: unité médicale
-[02:03:43] VOIX: donc ça la logique
-[02:03:44] VOIX: de l'unité médicale
-[02:03:45] VOIX: elle est aussi importante
-[02:03:46] VOIX: ça on n'en a pas
-[02:03:47] VOIX: beaucoup parlé
-[02:03:48] VOIX: mais
-[02:03:49] VOIX: ça
-[02:03:51] VOIX: en psychiatrie
-[02:03:52] VOIX: on va avoir
-[02:03:53] VOIX: beaucoup de découpage
-[02:03:54] VOIX: comme ça
-[02:03:55] VOIX: le codage
-[02:03:55] VOIX: il est lié
-[02:03:56] VOIX: à l'unité médicale
-[02:03:58] VOIX: et après
-[02:03:58] VOIX: c'est le groupeur
-[02:04:00] VOIX: qui va jouer
-[02:04:01] VOIX: d'accord
-[02:04:01] VOIX: mais là tu parles
-[02:04:02] VOIX: d'accord
-[02:04:03] VOIX: ok
-[02:04:04] VOIX: tu parles en fait
-[02:04:06] VOIX: le groupeur
-[02:04:07] VOIX: en fait
-[02:04:08] VOIX: le groupeur
-[02:04:08] VOIX: on est d'accord
-[02:04:09] VOIX: en fait
-[02:04:09] VOIX: c'est essentiellement
-[02:04:10] VOIX: en fait
-[02:04:10] VOIX: c'est ce qui va donner
-[02:04:11] VOIX: une tarification
-[02:04:13] VOIX: oui
-[02:04:13] VOIX: d'accord
-[02:04:14] VOIX: c'est lui qui va te dire
-[02:04:15] VOIX: ce DP
-[02:04:17] VOIX: avec
-[02:04:19] VOIX: ces actes là
-[02:04:20] VOIX: avec les actes classants
-[02:04:21] VOIX: je vais te montrer
-[02:04:22] VOIX: la notion d'actes classants
-[02:04:25] VOIX: je ne sais pas
-[02:04:25] VOIX: si je peux
-[02:04:26] VOIX: je ne peux pas grandir ça
-[02:04:27] VOIX: non mais c'est
-[02:04:28] VOIX: l'actes classants
-[02:04:28] VOIX: tu vois
-[02:04:29] VOIX: on a une étoile
-[02:04:30] VOIX: pour dire
-[02:04:30] VOIX: l'ostéosynthèse
-[02:04:32] VOIX: ça c'est un acte classant
-[02:04:34] VOIX: tu vois
-[02:04:34] VOIX: qu'il a eu de l'anesthésie
-[02:04:37] VOIX: classant classant
-[02:04:38] VOIX: et un acte
-[02:04:40] VOIX: qui n'est pas classant
-[02:04:41] VOIX: ben le scan
-[02:04:42] VOIX: le scan n'est pas classant
-[02:04:44] VOIX: ok
-[02:04:45] VOIX: il ne va pas jouer
-[02:04:45] VOIX: dans le groupeur
-[02:04:47] VOIX: et en fait
-[02:04:48] VOIX: comme j'ai des actes
-[02:04:49] VOIX: chirurgicaux ici
-[02:04:51] VOIX: mon séjour
-[02:04:52] VOIX: il est chirurgical
-[02:04:53] VOIX: c'est pour ça
-[02:04:54] VOIX: qu'il a un C
-[02:04:56] VOIX: 08C
-[02:04:58] VOIX: 08C
-[02:04:58] VOIX: c'est l'ortho
-[02:04:59] VOIX: le C
-[02:05:00] VOIX: c'est la Chir
-[02:05:02] VOIX: et ce 2 là
-[02:05:03] VOIX: le dernier
-[02:05:04] VOIX: il est lié
-[02:05:05] VOIX: au niveau de sévérité
-[02:05:07] VOIX: des diagnostics
-[02:05:08] VOIX: qui sont codés
-[02:05:09] VOIX: tu vois
-[02:05:10] VOIX: les listes
-[02:05:11] VOIX: les diagnostics associés
-[02:05:12] VOIX: il y en a
-[02:05:13] VOIX: qui n'ont pas
-[02:05:14] VOIX: de niveau de sévérité
-[02:05:15] VOIX: ils vont moins
-[02:05:16] VOIX: nous intéresser
-[02:05:17] VOIX: par exemple
-[02:05:17] VOIX: dans le contrôle sécu
-[02:05:18] VOIX: ils ne vont même pas
-[02:05:19] VOIX: les regarder
-[02:05:20] VOIX: parce qu'ils n'ont
-[02:05:21] VOIX: pas d'impact
-[02:05:22] VOIX: ils vont regarder
-[02:05:24] VOIX: beaucoup
-[02:05:24] VOIX: les diagnostics
-[02:05:25] VOIX: ça aussi
-[02:05:26] VOIX: c'est des référentiels
-[02:05:27] VOIX: les diagnostics
-[02:05:28] VOIX: qui ont des
-[02:05:29] VOIX: niveau 2
-[02:05:30] VOIX: quand tu parles
-[02:05:30] VOIX: de comorbidité
-[02:05:31] VOIX: c'est ça
-[02:05:31] VOIX: les CMA
-[02:05:32] VOIX: c'est les comorbidités
-[02:05:33] VOIX: c'est les diagnostics
-[02:05:34] VOIX: associés avec CMA
-[02:05:37] VOIX: 1, 2, 3 ou 4
-[02:05:38] VOIX: sur cette partie là
-[02:05:40] VOIX: je le verrai
-[02:05:40] VOIX: un peu plus tard
-[02:05:41] VOIX: parce que
-[02:05:43] VOIX: avant d'arriver là
-[02:05:44] VOIX: il faut que j'intègre
-[02:05:46] VOIX: tout ce qu'on a vu là
-[02:05:47] VOIX: aujourd'hui
-[02:05:48] VOIX: et puis surtout
-[02:05:49] VOIX: il faut que je vois
-[02:05:50] VOIX: avec Jordan
-[02:05:50] VOIX: parce qu'en fait
-[02:05:51] VOIX: sur cette partie là
-[02:05:52] VOIX: j'ai pris
-[02:05:54] VOIX: le groupeur officiel
-[02:05:59] VOIX: qui est vendu
-[02:06:01] VOIX: et comme je ne l'ai pas
-[02:06:02] VOIX: j'ai utilisé
-[02:06:03] VOIX: de le faire
-[02:06:05] VOIX: et c'était juste
-[02:06:06] VOIX: pour te montrer
-[02:06:07] VOIX: aussi la fiche
-[02:06:08] VOIX: si tu veux
-[02:06:08] VOIX: tu m'as parlé
-[02:06:09] VOIX: de la fiche
-[02:06:10] VOIX: on s'est inspiré
-[02:06:11] VOIX: de celle là
-[02:06:12] VOIX: pour travailler
-[02:06:12] VOIX: celle de l'OGC
-[02:06:14] VOIX: on retrouve
-[02:06:15] VOIX: tu vois la même logique
-[02:06:17] VOIX: avec les diagnostics
-[02:06:19] VOIX: leur niveau de sévérité
-[02:06:21] VOIX: etc
-[02:06:22] VOIX: et les actes
-[02:06:23] VOIX: avec les dates
-[02:06:24] VOIX: avec qu'est-ce qui sont classants
-[02:06:26] VOIX: est-ce qu'on a
-[02:06:27] VOIX: de l'anesthésie
-[02:06:27] VOIX: etc
-[02:06:28] VOIX: et on peut répartir
-[02:06:29] VOIX: par
-[02:06:30] VOIX: par
-[02:06:31] VOIX: unité médicale
-[02:06:32] VOIX: et juste un petit mot
-[02:06:33] VOIX: pour que tu vois
-[02:06:34] VOIX: que ça
-[02:06:35] VOIX: c'est du fait
-[02:06:37] VOIX: de l'identifiant patient
-[02:06:38] VOIX: qu'on peut dire
-[02:06:40] VOIX: que c'est le même patient
-[02:06:41] VOIX: qui a eu
-[02:06:42] VOIX: 3-6 jours
-[02:06:44] VOIX: tu sais
-[02:06:45] VOIX: quand on va travailler
-[02:06:46] VOIX: le parcours
-[02:06:49] VOIX: c'est ça en fait
-[02:06:50] VOIX: qu'il faut arriver
-[02:06:51] VOIX: à reconstituer
-[02:06:52] VOIX: pour dire
-[02:06:53] VOIX: parcours
-[02:06:54] VOIX: de douleur chronique
-[02:06:56] VOIX: là typiquement
-[02:06:58] VOIX: j'ai un dossier
-[02:06:59] VOIX: avec fracture
-[02:07:01] VOIX: du tibia
-[02:07:02] VOIX: donc une fracture
-[02:07:03] VOIX: il est revenu
-[02:07:05] VOIX: 4 mois après
-[02:07:06] VOIX: il a un problème
-[02:07:08] VOIX: du rein
-[02:07:08] VOIX: d'accord
-[02:07:09] VOIX: et il est revenu
-[02:07:11] VOIX: 2 mois après
-[02:07:12] VOIX: il a enlevé
-[02:07:14] VOIX: la plaque
-[02:07:15] VOIX: donc sur les 3 séjours
-[02:07:17] VOIX: il y a 2 qui sont
-[02:07:18] VOIX: en lien
-[02:07:19] VOIX: avec l'ortho
-[02:07:20] VOIX: parce que là
-[02:07:21] VOIX: il est revenu
-[02:07:22] VOIX: pour enlever la plaque
-[02:07:23] VOIX: alors qu'il faut
-[02:07:24] VOIX: que j'arrive à dire
-[02:07:25] VOIX: le rein
-[02:07:26] VOIX: quand il est venu
-[02:07:27] VOIX: pour le rein
-[02:07:29] VOIX: c'était pas
-[02:07:31] VOIX: tu vois
-[02:07:32] VOIX: il a fait
-[02:07:32] VOIX: un acte rénal
-[02:07:33] VOIX: ça ne fait pas partir
-[02:07:34] VOIX: du parcours ortho
-[02:07:36] VOIX: même si
-[02:07:37] VOIX: c'est un séjour
-[02:07:38] VOIX: du patient
-[02:07:38] VOIX: tu vois
-[02:07:39] VOIX: c'est ça
-[02:07:40] VOIX: que quand on va
-[02:07:41] VOIX: travailler les parcours
-[02:07:42] VOIX: il faudra aussi
-[02:07:43] VOIX: qu'on arrive à raisonner
-[02:07:44] VOIX: à relier
-[02:07:45] VOIX: etc
-[02:07:46] VOIX: ça c'est encore
-[02:07:47] VOIX: un autre niveau
-[02:07:49] VOIX: ok
-[02:07:51] VOIX: tu penses que
-[02:07:52] VOIX: l'IA
-[02:07:52] VOIX: va nous aider
-[02:07:53] VOIX: pour tout ça
-[02:07:56] VOIX: disons
-[02:07:57] VOIX: oui
-[02:07:57] VOIX: mais par contre
-[02:07:58] VOIX: en fait
-[02:07:59] VOIX: si tu veux
-[02:07:59] VOIX: il faut que
-[02:08:01] VOIX: qu'on arrive
-[02:08:02] VOIX: en fait
-[02:08:02] VOIX: que j'arrive
-[02:08:03] VOIX: moi
-[02:08:04] VOIX: une grosse partie
-[02:08:05] VOIX: de tout ce que tu as dit
-[02:08:05] VOIX: je l'ai déjà fait
-[02:08:07] VOIX: simplement
-[02:08:07] VOIX: en fait
-[02:08:07] VOIX: moi
-[02:08:08] VOIX: il faut que j'arrive
-[02:08:08] VOIX: à attraper
-[02:08:09] VOIX: en fait
-[02:08:09] VOIX: les trucs métiers
-[02:08:10] VOIX: et il faut que je les fasse
-[02:08:11] VOIX: en fait
-[02:08:12] VOIX: de façon séquentielle
-[02:08:13] VOIX: c'est en fait
-[02:08:13] VOIX: on va travailler
-[02:08:14] VOIX: en fait
-[02:08:15] VOIX: sur la psychiatrie
-[02:08:16] VOIX: par exemple
-[02:08:17] VOIX: ou ce genre de choses là
-[02:08:18] VOIX: on va
-[02:08:19] VOIX: on va étudier également
-[02:08:20] VOIX: en fait
-[02:08:20] VOIX: les relations
-[02:08:21] VOIX: qu'il y a entre les règles
-[02:08:22] VOIX: mais c'est
-[02:08:23] VOIX: c'est ce que tu m'as expliqué
-[02:08:24] VOIX: en fait tout à l'heure
-[02:08:25] VOIX: mais se concentrer
-[02:08:26] VOIX: en fait
-[02:08:27] VOIX: un par un
-[02:08:27] VOIX: parce qu'il y en a
-[02:08:28] VOIX: beaucoup
-[02:08:28] VOIX: en fait
-[02:08:29] VOIX: c'est pour ça que je te dis
-[02:08:30] VOIX: il y a beaucoup
-[02:08:30] VOIX: parce qu'en fonction
-[02:08:31] VOIX: de ton mode de réflexion
-[02:08:34] VOIX: change tout le temps
-[02:08:35] VOIX: en fonction en fait
-[02:08:36] VOIX: de certains nombres
-[02:08:37] VOIX: de critères
-[02:08:38] VOIX: qu'il faut que j'arrive
-[02:08:39] VOIX: à choper
-[02:08:39] VOIX: il y en a plein
-[02:08:40] VOIX: en fait
-[02:08:41] VOIX: que j'ai attrapé
-[02:08:41] VOIX: mais il y en a
-[02:08:42] VOIX: plein en fait
-[02:08:43] VOIX: qui m'échappent
-[02:08:44] VOIX: en fait
-[02:08:44] VOIX: pour le moment
-[02:08:44] VOIX: pourquoi
-[02:08:45] VOIX: parce que ça me fait
-[02:08:46] VOIX: beaucoup d'informations
-[02:08:47] VOIX: en fait
-[02:08:47] VOIX: en même temps
-[02:08:48] VOIX: donc là
-[02:08:48] VOIX: ce que je vais faire
-[02:08:50] VOIX: en fait
-[02:08:50] VOIX: quand on va arrêter
-[02:08:51] VOIX: c'est que je vais prendre
-[02:08:52] VOIX: en fait
-[02:08:53] VOIX: tout ce que tu m'as raconté
-[02:08:54] VOIX: je vais le mettre
-[02:08:55] VOIX: dans la machine
-[02:08:55] VOIX: je vais lui demander
-[02:08:56] VOIX: de me faire une synthèse
-[02:08:58] VOIX: de tout ce que l'on a vu
-[02:08:59] VOIX: et à chaque fois
-[02:09:01] VOIX: en fait
-[02:09:01] VOIX: que tu m'as montré
-[02:09:02] VOIX: en fait des choses
-[02:09:03] VOIX: j'ai pris des photos
-[02:09:04] VOIX: et donc en fait
-[02:09:05] VOIX: je vais demander
-[02:09:05] VOIX: en fait
-[02:09:06] VOIX: de faire un rapprochement
-[02:09:07] VOIX: entre les photos
-[02:09:08] VOIX: entre
-[02:09:09] VOIX: voilà
-[02:09:09] VOIX: mais pas toi
-[02:09:10] VOIX: en fait
-[02:09:10] VOIX: en photo
-[02:09:10] VOIX: la capture d'écran
-[02:09:14] VOIX: tu peux
-[02:09:14] VOIX: tu peux
-[02:09:15] VOIX: attends
-[02:09:16] VOIX: j'ai encore
-[02:09:17] VOIX: une ou deux questions
-[02:09:18] VOIX: mais pas plus
-[02:09:19] VOIX: c'est fini
-[02:09:19] VOIX: d'accord
-[02:09:22] VOIX: oui alors
-[02:09:25] VOIX: j'avais noté
-[02:09:27] VOIX: ça
-[02:09:27] VOIX: on va pas
-[02:09:28] VOIX: quelqu'un
-[02:09:29] VOIX: oui
-[02:09:30] VOIX: quels sont
-[02:09:31] VOIX: ouais
-[02:09:34] VOIX: il y a des cas
-[02:09:35] VOIX: alors
-[02:09:36] VOIX: on va pas
-[02:09:36] VOIX: le développer
-[02:09:37] VOIX: maintenant
-[02:09:38] VOIX: mais c'est
-[02:09:38] VOIX: quels sont
-[02:09:39] VOIX: en fait
-[02:09:39] VOIX: les moyens
-[02:09:40] VOIX: puisque tu parlais
-[02:09:40] VOIX: en fait
-[02:09:41] VOIX: d'IA
-[02:09:41] VOIX: quels sont
-[02:09:42] VOIX: en fait
-[02:09:42] VOIX: les cas
-[02:09:43] VOIX: où
-[02:09:44] VOIX: obligatoirement
-[02:09:44] VOIX: en fait
-[02:09:45] VOIX: tu veux garder
-[02:09:45] VOIX: la main
-[02:09:45] VOIX: en fait
-[02:09:46] VOIX: sur l'étude
-[02:09:46] VOIX: en fait
-[02:09:47] VOIX: du dossier
-[02:09:48] VOIX: ou toi
-[02:09:49] VOIX: en tant que team
-[02:09:51] VOIX: aucun
-[02:09:53] VOIX: bon
-[02:09:53] VOIX: j'en étais
-[02:09:53] VOIX: à peu près sûr
-[02:09:54] VOIX: mais
-[02:09:56] VOIX: le boulot
-[02:09:56] VOIX: le boulot
-[02:09:57] VOIX: peut être
-[02:09:57] VOIX: refait
-[02:09:57] VOIX: sans que
-[02:09:58] VOIX: je garde
-[02:09:59] VOIX: la main
-[02:09:59] VOIX: alors là
-[02:10:01] VOIX: c'est le top
-[02:10:02] VOIX: du top
-[02:10:02] VOIX: il y en a
-[02:10:03] VOIX: non mais
-[02:10:03] VOIX: alors
-[02:10:04] VOIX: pourquoi en fait
-[02:10:05] VOIX: je te dis ça
-[02:10:06] VOIX: c'est parce que
-[02:10:06] VOIX: c'est parce que
-[02:10:08] VOIX: si tu as en fait
-[02:10:09] VOIX: alors attends
-[02:10:09] VOIX: parce que ça fait partie
-[02:10:10] VOIX: de mes questions
-[02:10:11] VOIX: en fait
-[02:10:11] VOIX: dernière question
-[02:10:15] VOIX: ouais
-[02:10:15] VOIX: c'est l'indicateur
-[02:10:16] VOIX: parce qu'en fait
-[02:10:17] VOIX: ça va me donner
-[02:10:17] VOIX: en fait
-[02:10:18] VOIX: parce que moi
-[02:10:19] VOIX: quand je fais
-[02:10:19] VOIX: en fait
-[02:10:20] VOIX: le traitement
-[02:10:20] VOIX: en fait
-[02:10:21] VOIX: des documents
-[02:10:22] VOIX: je sors en fait
-[02:10:23] VOIX: un indicateur
-[02:10:25] VOIX: ou un scoring
-[02:10:26] VOIX: tu sais
-[02:10:26] VOIX: on avait parlé de ça
-[02:10:27] VOIX: c'est à dire que
-[02:10:28] VOIX: si je suis pas sûr
-[02:10:29] VOIX: en fait
-[02:10:29] VOIX: à tant de pourcents
-[02:10:30] VOIX: je mets de côté
-[02:10:32] VOIX: maintenant
-[02:10:32] VOIX: au fur et à mesure
-[02:10:33] VOIX: où j'avance moi
-[02:10:34] VOIX: j'arrive à avoir
-[02:10:35] VOIX: en fait
-[02:10:35] VOIX: des comparaisons
-[02:10:37] VOIX: elles sont pas
-[02:10:38] VOIX: nécessairement
-[02:10:38] VOIX: super bonnes
-[02:10:39] VOIX: pour le moment
-[02:10:39] VOIX: j'attends d'avoir
-[02:10:40] VOIX: en fait
-[02:10:40] VOIX: la suite logique
-[02:10:41] VOIX: dont je t'avais parlé
-[02:10:42] VOIX: donc le
-[02:10:45] VOIX: le compte rendu
-[02:10:46] VOIX: en fait
-[02:10:46] VOIX: le compte rendu
-[02:10:47] VOIX: est arrivé
-[02:10:48] VOIX: en fait
-[02:10:48] VOIX: jusqu'au codage
-[02:10:49] VOIX: qui a été vérifié
-[02:10:51] VOIX: contrôlé par la suite
-[02:10:52] VOIX: si j'ai ça
-[02:10:53] VOIX: en fait
-[02:10:53] VOIX: ça va être top
-[02:10:54] VOIX: mais je fais du scoring
-[02:10:56] VOIX: et si tu veux
-[02:10:57] VOIX: il y en a beaucoup
-[02:10:58] VOIX: en fait
-[02:10:58] VOIX: beaucoup
-[02:10:59] VOIX: ou pas beaucoup
-[02:10:59] VOIX: j'en sais
-[02:11:00] VOIX: pas grand chose
-[02:11:01] VOIX: là pour le moment
-[02:11:01] VOIX: mais que je vais écarter
-[02:11:03] VOIX: mais au fur et à mesure
-[02:11:04] VOIX: en fait
-[02:11:04] VOIX: mon scoring
-[02:11:04] VOIX: va diminuer
-[02:11:05] VOIX: c'est en fait
-[02:11:06] VOIX: au lieu d'avoir
-[02:11:07] VOIX: en fait
-[02:11:07] VOIX: que 80%
-[02:11:08] VOIX: de résultats bons
-[02:11:09] VOIX: j'en suis à 75%
-[02:11:11] VOIX: pour le moment
-[02:11:12] VOIX: ben je vais arriver
-[02:11:13] VOIX: près de 100%
-[02:11:15] VOIX: mais il y a des dossiers
-[02:11:16] VOIX: en fait
-[02:11:17] VOIX: typiquement
-[02:11:18] VOIX: sur lequel
-[02:11:18] VOIX: en fait
-[02:11:19] VOIX: il va vouloir
-[02:11:19] VOIX: prendre la main
-[02:11:20] VOIX: pour des tonnes de raisons
-[02:11:21] VOIX: pour l'instant
-[02:11:22] VOIX: par exemple
-[02:11:22] VOIX: quand il va y avoir
-[02:11:24] VOIX: le groupeur
-[02:11:25] VOIX: et qu'il va y avoir
-[02:11:26] VOIX: un dossier
-[02:11:27] VOIX: qui va être évalué
-[02:11:28] VOIX: alors s'il est évalué
-[02:11:29] VOIX: à 50 euros
-[02:11:30] VOIX: ça peut aller
-[02:11:30] VOIX: mais s'il est évalué
-[02:11:32] VOIX: en fait
-[02:11:32] VOIX: à 150 000 euros
-[02:11:33] VOIX: ou 200 000 euros
-[02:11:34] VOIX: je sais pas en fait
-[02:11:34] VOIX: je dis des bêtises
-[02:11:36] VOIX: peut-être que ce dossier
-[02:11:37] VOIX: là en particulier
-[02:11:38] VOIX: tu vas préférer
-[02:11:39] VOIX: en fait
-[02:11:39] VOIX: le travailler
-[02:11:40] VOIX: tout seul
-[02:11:40] VOIX: pour ne pas avoir
-[02:11:41] VOIX: d'embêtement
-[02:11:42] VOIX: par la suite
-[02:11:43] VOIX: ou le vérifier
-[02:11:45] VOIX: ou le contrôler
-[02:11:47] VOIX: ou le valider
-[02:11:48] VOIX: parce que
-[02:11:49] VOIX: nous le progrès
-[02:11:51] VOIX: si tout va bien
-[02:11:52] VOIX: c'est au delà
-[02:11:54] VOIX: de proposer des codes
-[02:11:55] VOIX: c'est les intégrer
-[02:11:56] VOIX: dans le dossier
-[02:11:58] VOIX: et peut-être
-[02:11:59] VOIX: valider
-[02:12:00] VOIX: ou pas le dossier
-[02:12:00] VOIX: voilà
-[02:12:01] VOIX: ça c'est
-[02:12:03] VOIX: une priorisation
-[02:12:05] VOIX: pour dire
-[02:12:06] VOIX: certains dossiers
-[02:12:07] VOIX: seront obligatoirement
-[02:12:08] VOIX: à valider
-[02:12:09] VOIX: par l'humain
-[02:12:09] VOIX: tu fais toujours
-[02:12:10] VOIX: ta proposition initiale
-[02:12:11] VOIX: oui
-[02:12:12] VOIX: ben c'est
-[02:12:12] VOIX: voilà
-[02:12:13] VOIX: en fait
-[02:12:13] VOIX: l'étape que tu viens
-[02:12:14] VOIX: de donner
-[02:12:14] VOIX: en fait
-[02:12:15] VOIX: j'ai fait un raccourci
-[02:12:16] VOIX: en fait
-[02:12:17] VOIX: en 100 ans
-[02:12:18] VOIX: mais
-[02:12:18] VOIX: voilà
-[02:12:19] VOIX: il y en a peut-être
-[02:12:19] VOIX: certaines
-[02:12:20] VOIX: donc ok
-[02:12:21] VOIX: merci
-[02:12:23] VOIX: attends
-[02:12:27] VOIX: ok
-[02:12:28] VOIX: ben écoute
-[02:12:28] VOIX: ben là
-[02:12:29] VOIX: on va arrêter
-[02:12:30] VOIX: pour aujourd'hui
-[02:12:31] VOIX: en fait
-[02:12:32] VOIX: sur ce sujet
-[02:12:33] VOIX: parce que
-[02:12:33] VOIX: c'est bon
-[02:12:37] VOIX: non c'est
-[02:12:38] VOIX: c'est dense
-[02:12:39] VOIX: et puis en plus
-[02:12:39] VOIX: comme je suis
-[02:12:40] VOIX: en plein dedans
-[02:12:41] VOIX: il y a plein de choses
-[02:12:42] VOIX: en fait
-[02:12:42] VOIX: que tu m'as dit
-[02:12:43] VOIX: où je suis assez content
-[02:12:44] VOIX: en fait
-[02:12:44] VOIX: parce que je me dis
-[02:12:45] VOIX: là j'ai bien chopé le truc
-[02:12:46] VOIX: et j'ai bien
-[02:12:47] VOIX: en fait
-[02:12:48] VOIX: il y a d'autres choses
-[02:12:49] VOIX: où je me dis
-[02:12:50] VOIX: ah merde
-[02:12:52] VOIX: j'oublie
-[02:12:53] VOIX: mais bon
-[02:12:54] VOIX: après en fait
-[02:12:54] VOIX: c'est justement
-[02:12:55] VOIX: là
-[02:12:55] VOIX: on est
-[02:12:56] VOIX: on va être
-[02:12:58] VOIX: sur des itérations
-[02:12:58] VOIX: c'est à dire
-[02:12:59] VOIX: aujourd'hui
-[02:12:59] VOIX: on a pris un peu de temps
-[02:13:00] VOIX: après en fait
-[02:13:01] VOIX: on va s'échanger
-[02:13:02] VOIX: en fait des mails
-[02:13:03] VOIX: mais tu vas voir
-[02:13:04] VOIX: je vais essayer
-[02:13:04] VOIX: de faire des trucs
-[02:13:05] VOIX: ultra courts
-[02:13:08] VOIX: je finis
-[02:13:08] VOIX: je vais essayer
-[02:13:09] VOIX: en fait de faire
-[02:13:10] VOIX: en fait
-[02:13:10] VOIX: ce qu'on appelle
-[02:13:10] VOIX: des itérations
-[02:13:11] VOIX: mais courtes
-[02:13:12] VOIX: c'est jette moi
-[02:13:13] VOIX: un coup d'oeil
-[02:13:13] VOIX: en fait sur ça
-[02:13:14] VOIX: et toi en fait
-[02:13:15] VOIX: en deux secondes
-[02:13:15] VOIX: tu vas le faire
-[02:13:16] VOIX: le truc que j'aurais
-[02:13:17] VOIX: mis deux heures
-[02:13:17] VOIX: à faire
-[02:13:19] VOIX: et toi tu vas me dire
-[02:13:20] VOIX: c'est bon
-[02:13:20] VOIX: ou c'est pas bon
-[02:13:21] VOIX: parce que là
-[02:13:21] VOIX: je sais que
-[02:13:22] VOIX: voilà
-[02:13:23] VOIX: est-ce que tu penses
-[02:13:27] VOIX: que si je te donne
-[02:13:28] VOIX: par exemple
-[02:13:29] VOIX: le chapitre
-[02:13:30] VOIX: un ou deux chapitres
-[02:13:32] VOIX: du guide méthodologique
-[02:13:33] VOIX: le 5
-[02:13:35] VOIX: le 6
-[02:13:36] VOIX: tu sais
-[02:13:37] VOIX: je vais repartager
-[02:13:38] VOIX: juste pour que tu vois
-[02:13:39] VOIX: de quoi je parle
-[02:13:41] VOIX: et indépendamment
-[02:13:42] VOIX: d'où c'est passé
-[02:13:44] VOIX: c'est une question
-[02:13:45] VOIX: je n'en sais rien
-[02:13:46] VOIX: si c'est utile ou pas
-[02:13:48] VOIX: parce que
-[02:13:50] VOIX: non ça c'est 3
-[02:13:51] VOIX: 4
-[02:13:52] VOIX: à partir de là
-[02:13:53] VOIX: le chapitre 5
-[02:13:55] VOIX: parce qu'il définit
-[02:13:57] VOIX: non non
-[02:13:58] VOIX: même ça
-[02:13:58] VOIX: c'est quoi un DP
-[02:13:59] VOIX: c'est quoi un DAS
-[02:14:02] VOIX: les règles
-[02:14:03] VOIX: de certains codages
-[02:14:04] VOIX: le chapitre 5
-[02:14:06] VOIX: et le chapitre 6
-[02:14:08] VOIX: c'est celui des guides
-[02:14:10] VOIX: des situations cliniques
-[02:14:11] VOIX: au moins
-[02:14:13] VOIX: le 4
-[02:14:14] VOIX: et le 6
-[02:14:15] VOIX: est-ce que tu penses
-[02:14:17] VOIX: que s'il y a réfléchi
-[02:14:19] VOIX: sur un raisonnement
-[02:14:20] VOIX: sur une synthèse
-[02:14:23] VOIX: sur des
-[02:14:24] VOIX: une autre manière
-[02:14:25] VOIX: de
-[02:14:27] VOIX: de
-[02:14:29] VOIX: de
-[02:14:29] VOIX: de
-[02:14:29] VOIX: nommer
-[02:14:30] VOIX: ou de
-[02:14:31] VOIX: de traiter les règles
-[02:14:33] VOIX: c'est quelque chose
-[02:14:34] VOIX: qui peut être utile
-[02:14:35] VOIX: c'est une question
-[02:14:36] VOIX: je n'en sais rien moi
-[02:14:37] VOIX: mais
-[02:14:37] VOIX: c'est une question
-[02:14:39] VOIX: bah écoute
-[02:14:39] VOIX: comme moi
-[02:14:40] VOIX: mes questions
-[02:14:40] VOIX: elles n'étaient pas si bêtes
-[02:14:42] VOIX: la tienne
-[02:14:42] VOIX: elle n'est pas bête non plus
-[02:14:44] VOIX: en fait
-[02:14:44] VOIX: oui c'est utile
-[02:14:45] VOIX: c'est utile
-[02:14:46] VOIX: mais
-[02:14:47] VOIX: ce document là
-[02:14:48] VOIX: moi je l'ai déjà
-[02:14:48] VOIX: enfin
-[02:14:49] VOIX: les règles
-[02:14:50] VOIX: que tu m'as données
-[02:14:50] VOIX: sont déjà intégrées
-[02:14:52] VOIX: en fait dans le système
-[02:14:53] VOIX: mais
-[02:14:53] VOIX: moi ce qui me manque
-[02:14:54] VOIX: c'est
-[02:14:55] VOIX: je reviens toujours là dessus
-[02:14:56] VOIX: mais je vais les avoir
-[02:14:57] VOIX: en fait
-[02:14:57] VOIX: demain
-[02:14:57] VOIX: je ne vais pas embêter
-[02:14:58] VOIX: Jordan et Pauline
-[02:15:00] VOIX: en fait
-[02:15:00] VOIX: parce que je sais
-[02:15:00] VOIX: qu'ils ont beaucoup de boulot
-[02:15:01] VOIX: mais moi en fait
-[02:15:02] VOIX: à partir du moment
-[02:15:03] VOIX: où je vais avoir
-[02:15:03] VOIX: en fait
-[02:15:04] VOIX: je ne sais pas
-[02:15:05] VOIX: 300
-[02:15:05] VOIX: 400 en fait
-[02:15:06] VOIX: documents
-[02:15:06] VOIX: avec la chaîne complète
-[02:15:08] VOIX: le compte rendu
-[02:15:09] VOIX: le codage
-[02:15:10] VOIX: etc
-[02:15:10] VOIX: moi dès que j'ai ça
-[02:15:12] VOIX: après je ne casse plus
-[02:15:13] VOIX: les pieds à personne
-[02:15:14] VOIX: pendant un petit moment
-[02:15:15] VOIX: et après en fait
-[02:15:16] VOIX: on fait des tests
-[02:15:19] VOIX: oui mais
-[02:15:20] VOIX: Dominique
-[02:15:21] VOIX: tu l'as ça
-[02:15:22] VOIX: tu l'as
-[02:15:24] VOIX: ah bon
-[02:15:24] VOIX: et bien oui
-[02:15:26] VOIX: on a 700 dossiers
-[02:15:28] VOIX: du CHCB
-[02:15:29] VOIX: non
-[02:15:29] VOIX: j'ai le compte rendu
-[02:15:30] VOIX: j'ai le tracker
-[02:15:32] VOIX: non
-[02:15:33] VOIX: t'as vu ce qu'on a
-[02:15:34] VOIX: ouvert tout à l'heure
-[02:15:35] VOIX: le 339
-[02:15:36] VOIX: oui
-[02:15:37] VOIX: le tracker
-[02:15:38] VOIX: t'avais
-[02:15:39] VOIX: les résultats de biologie
-[02:15:41] VOIX: t'avais les comptes rendus
-[02:15:42] VOIX: des examens complémentaires
-[02:15:44] VOIX: d'imagerie
-[02:15:45] VOIX: t'avais les observations
-[02:15:46] VOIX: médicales
-[02:15:47] VOIX: t'avais les notes
-[02:15:48] VOIX: des IDE
-[02:15:49] VOIX: donc t'avais tout le dossier
-[02:15:50] VOIX: les trackers
-[02:15:51] VOIX: ils sont pas que
-[02:15:52] VOIX: les comptes rendus
-[02:15:53] VOIX: les trackers
-[02:15:54] VOIX: c'est
-[02:15:54] VOIX: la majorité
-[02:15:55] VOIX: je le sais
-[02:15:56] VOIX: je le sais
-[02:15:58] VOIX: mais j'ai pas
-[02:15:59] VOIX: toute la chaîne
-[02:16:00] VOIX: il me manque en fait
-[02:16:01] VOIX: ce que t'appelles
-[02:16:03] VOIX: le primo codage
-[02:16:03] VOIX: ou le codage
-[02:16:04] VOIX: en fait
-[02:16:04] VOIX: qui a été fait
-[02:16:05] VOIX: ça c'est facile
-[02:16:06] VOIX: on a le fichier
-[02:16:07] VOIX: PM ici
-[02:16:07] VOIX: on a tout
-[02:16:09] VOIX: on a tout
-[02:16:10] VOIX: on a les documents
-[02:16:11] VOIX: parce que
-[02:16:12] VOIX: si tu le veux
-[02:16:13] VOIX: je sais pas
-[02:16:14] VOIX: comment tu le veux
-[02:16:15] VOIX: parce que
-[02:16:16] VOIX: il y a trois manières
-[02:16:18] VOIX: de l'avoir
-[02:16:19] VOIX: il y a
-[02:16:20] VOIX: Jordan par exemple
-[02:16:22] VOIX: il a travaillé
-[02:16:23] VOIX: sur les fichiers
-[02:16:24] VOIX: bruts
-[02:16:24] VOIX: il sort
-[02:16:26] VOIX: pour Guy
-[02:16:27] VOIX: des fichiers
-[02:16:29] VOIX: plats
-[02:16:29] VOIX: en Excel
-[02:16:30] VOIX: que Guy travaille
-[02:16:31] VOIX: pour me sortir
-[02:16:32] VOIX: les rapports
-[02:16:33] VOIX: d'activité
-[02:16:33] VOIX: donc on peut
-[02:16:35] VOIX: très bien
-[02:16:35] VOIX: demander à Jordan
-[02:16:36] VOIX: je vais regarder
-[02:16:37] VOIX: d'ailleurs
-[02:16:38] VOIX: le fichier
-[02:16:38] VOIX: avec Guy
-[02:16:38] VOIX: si tu as
-[02:16:40] VOIX: numéro de dossier
-[02:16:41] VOIX: par numéro de dossier
-[02:16:42] VOIX: pour les OGC
-[02:16:43] VOIX: parce que c'est ça
-[02:16:44] VOIX: qu'on va demander
-[02:16:45] VOIX: à Jordan
-[02:16:45] VOIX: prendre le druide
-[02:16:47] VOIX: de CHCB
-[02:16:48] VOIX: coller le numéro
-[02:16:50] VOIX: OGC
-[02:16:50] VOIX: le numéro de dossier
-[02:16:51] VOIX: me dire
-[02:16:52] VOIX: c'est quoi mon DP
-[02:16:53] VOIX: c'est quoi mon DAS
-[02:16:54] VOIX: voilà
-[02:16:55] VOIX: en fait
-[02:16:56] VOIX: j'ai besoin
-[02:16:56] VOIX: de ça
-[02:16:57] VOIX: peu importe
-[02:16:59] VOIX: le format
-[02:16:59] VOIX: en fait
-[02:16:59] VOIX: j'ai besoin
-[02:17:00] VOIX: de ça
-[02:17:00] VOIX: comme ça
-[02:17:01] VOIX: en fait
-[02:17:01] VOIX: j'ai toute la chaîne
-[02:17:03] VOIX: parce qu'en fait
-[02:17:03] VOIX: si tu veux
-[02:17:04] VOIX: pour entraîner
-[02:17:04] VOIX: une IA
-[02:17:05] VOIX: en fait
-[02:17:05] VOIX: ce dont je t'avais
-[02:17:06] VOIX: parlé
-[02:17:06] VOIX: en fait
-[02:17:07] VOIX: on a besoin
-[02:17:08] VOIX: d'avoir
-[02:17:08] VOIX: toute la chaîne
-[02:17:10] VOIX: et en plus
-[02:17:11] VOIX: là
-[02:17:12] VOIX: parce que
-[02:17:12] VOIX: pour moi
-[02:17:13] VOIX: en fait
-[02:17:13] VOIX: je dis
-[02:17:13] VOIX: on a de la chance
-[02:17:14] VOIX: parce qu'en fait
-[02:17:15] VOIX: vous avez tout
-[02:17:17] VOIX: ah oui
-[02:17:17] VOIX: donc
-[02:17:20] VOIX: c'est bien
-[02:17:22] VOIX: attends
-[02:17:23] VOIX: je vais te montrer
-[02:17:23] VOIX: juste un fichier
-[02:17:26] VOIX: si j'arrive
-[02:17:27] VOIX: à l'ouvrir
-[02:17:30] VOIX: euh
-[02:17:33] VOIX: recettes
-[02:17:33] VOIX: c'est
-[02:17:34] VOIX: jour
-[02:17:39] VOIX: non
-[02:17:40] VOIX: j'arrive pas
-[02:17:40] VOIX: à l'ouvrir
-[02:17:41] VOIX: non
-[02:17:42] VOIX: j'arrive pas
-[02:17:43] VOIX: à l'ouvrir
-[02:17:43] VOIX: donc
-[02:17:45] VOIX: voilà
-[02:17:45] VOIX: j'ai besoin
-[02:17:46] VOIX: ça c'est
-[02:17:47] VOIX: demain
-[02:17:48] VOIX: je laisserai
-[02:17:49] VOIX: le message
-[02:17:49] VOIX: à Guy
-[02:17:50] VOIX: je sais si tu seras
-[02:17:51] VOIX: avec eux
-[02:17:52] VOIX: pour la réunion
-[02:17:53] VOIX: qu'ils ont prévu
-[02:17:54] VOIX: Guy et Jordan
-[02:17:54] VOIX: oui
-[02:17:57] VOIX: donc en tout cas
-[02:17:58] VOIX: c'est facile
-[02:17:58] VOIX: ça
-[02:17:59] VOIX: si tu veux
-[02:18:00] VOIX: oui c'est facile
-[02:18:01] VOIX: mais je l'ai pas
-[02:18:02] VOIX: donc
-[02:18:02] VOIX: mais après
-[02:18:02] VOIX: en fait
-[02:18:03] VOIX: voilà
-[02:18:03] VOIX: demain
-[02:18:04] VOIX: je vais demander
-[02:18:04] VOIX: à Jordan
-[02:18:05] VOIX: non
-[02:18:05] VOIX: parce que
-[02:18:07] VOIX: il n'oublie pas
-[02:18:08] VOIX: que dans
-[02:18:08] VOIX: les 100
-[02:18:10] VOIX: les 700 dossiers
-[02:18:11] VOIX: du CHCB
-[02:18:12] VOIX: il y a
-[02:18:13] VOIX: au moins
-[02:18:14] VOIX: je pense
-[02:18:14] VOIX: 150
-[02:18:16] VOIX: où le médecin contrôleur
-[02:18:17] VOIX: il est ok
-[02:18:17] VOIX: avec l'établissement
-[02:18:18] VOIX: donc il n'y a pas
-[02:18:19] VOIX: de changement
-[02:18:20] VOIX: et c'est
-[02:18:21] VOIX: ceux-là
-[02:18:21] VOIX: que tu peux dire
-[02:18:23] VOIX: j'ai pas
-[02:18:24] VOIX: les allers-retours
-[02:18:24] VOIX: etc
-[02:18:25] VOIX: je peux déjà
-[02:18:26] VOIX: commencer par
-[02:18:27] VOIX: ceux-là
-[02:18:27] VOIX: pour appliquer
-[02:18:28] VOIX: les règles
-[02:18:29] VOIX: ceux qui n'ont
-[02:18:29] VOIX: pas été modifiés
-[02:18:31] VOIX: mais je vais réfléchir
-[02:18:32] VOIX: à un format
-[02:18:34] VOIX: nous aussi
-[02:18:35] VOIX: on travaille
-[02:18:35] VOIX: sur ça
-[02:18:36] VOIX: moi le format
-[02:18:37] VOIX: en fait
-[02:18:37] VOIX: si c'est du CSV
-[02:18:38] VOIX: ou de l'Excel
-[02:18:39] VOIX: ou même un PDF
-[02:18:40] VOIX: en fait
-[02:18:40] VOIX: moi ça me va
-[02:18:41] VOIX: je n'ai pas
-[02:18:42] VOIX: de soucis
-[02:18:43] VOIX: de formatage
-[02:18:43] VOIX: d'accord
-[02:18:44] VOIX: il n'y a pas
-[02:18:44] VOIX: il n'y a pas
-[02:18:45] VOIX: de soucis
-[02:18:47] VOIX: non je n'ai pas
-[02:18:47] VOIX: de soucis
-[02:18:48] VOIX: de formatage
-[02:18:48] VOIX: d'ailleurs
-[02:18:49] VOIX: en fait
-[02:18:49] VOIX: en ce qui concerne
-[02:18:51] VOIX: les
-[02:19:04] VOIX: parce qu'en fait
-[02:19:05] VOIX: le document
-[02:19:07] VOIX: en fait
-[02:19:07] VOIX: que Harid
-[02:19:08] VOIX: a fait
-[02:19:08] VOIX: avait quelques
-[02:19:09] VOIX: petits trous
-[02:19:09] VOIX: dans la raquette
-[02:19:10] VOIX: c'est les
-[02:19:11] VOIX: i
-[02:19:11] VOIX: qui confond
-[02:19:12] VOIX: avec les
-[02:19:13] VOIX: 1
-[02:19:13] VOIX: ou ce genre
-[02:19:14] VOIX: de choses
-[02:19:14] VOIX: là
-[02:19:14] VOIX: donc j'ai
-[02:19:15] VOIX: corrigé
-[02:19:15] VOIX: le programme
-[02:19:16] VOIX: en entier
-[02:19:17] VOIX: et le programme
-[02:19:17] VOIX: que j'ai fait
-[02:19:20] VOIX: prend en fait
-[02:19:20] VOIX: n'importe quel
-[02:19:21] VOIX: type de document
-[02:19:22] VOIX: et le réorganise
-[02:19:23] VOIX: complètement
-[02:19:23] VOIX: et voilà
-[02:19:25] VOIX: Harid
-[02:19:25] VOIX: tu sais
-[02:19:26] VOIX: qu'il est absent
-[02:19:26] VOIX: pendant un mois
-[02:19:27] VOIX: et bien oui
-[02:19:28] VOIX: je sais
-[02:19:28] VOIX: il revient
-[02:19:29] VOIX: un mois
-[02:19:30] VOIX: et il repart
-[02:19:30] VOIX: 3-4 mois
-[02:19:32] VOIX: donc c'est vrai
-[02:19:32] VOIX: qu'il faut que
-[02:19:34] VOIX: bon ça c'est un sujet
-[02:19:35] VOIX: à part entière
-[02:19:36] VOIX: mais on en reparlera
-[02:19:37] VOIX: tranquillement
-[02:19:39] VOIX: parce que je pense
-[02:19:40] VOIX: que pendant
-[02:19:40] VOIX: 5 à 6 mois
-[02:19:41] VOIX: je ne sais pas
-[02:19:42] VOIX: s'il va continuer
-[02:19:43] VOIX: à faire des choses
-[02:19:44] VOIX: pour nous
-[02:19:45] VOIX: alors je te le dis
-[02:19:46] VOIX: pour avoir discuté
-[02:19:47] VOIX: avec lui
-[02:19:48] VOIX: longuement
-[02:19:48] VOIX: en fait il a dit
-[02:19:51] VOIX: oui mais bon
-[02:19:52] VOIX: là où je vais
-[02:19:53] VOIX: je vais travailler
-[02:19:54] VOIX: en fait je vais arrêter
-[02:19:55] VOIX: de travailler
-[02:19:55] VOIX: à 15h ou 16h
-[02:19:57] VOIX: et après en fait
-[02:19:58] VOIX: parce qu'il aime bien
-[02:19:59] VOIX: travailler
-[02:19:59] VOIX: parce que moi
-[02:20:00] VOIX: je l'ai en fait
-[02:20:01] VOIX: on échange beaucoup
-[02:20:02] VOIX: je lui donne
-[02:20:03] VOIX: en fait des trucs
-[02:20:04] VOIX: et astuces
-[02:20:04] VOIX: en fait des choses
-[02:20:05] VOIX: comme ça
-[02:20:06] VOIX: donc il m'a dit
-[02:20:08] VOIX: en fait
-[02:20:08] VOIX: je veux continuer
-[02:20:09] VOIX: je le mets en copie
-[02:20:10] VOIX: en fait de tout
-[02:20:11] VOIX: ce que je fais
-[02:20:11] VOIX: de la manière
-[02:20:12] VOIX: à ce qu'il ait
-[02:20:13] VOIX: les infos
-[02:20:14] VOIX: d'accord
-[02:20:15] VOIX: donc il continue
-[02:20:16] VOIX: quand même
-[02:20:16] VOIX: à avancer
-[02:20:17] VOIX: sur 2-3 petits sujets
-[02:20:18] VOIX: oui
-[02:20:19] VOIX: mais je vais essayer
-[02:20:20] VOIX: de l'appeler
-[02:20:21] VOIX: peut-être pas là
-[02:20:21] VOIX: mais en tout cas
-[02:20:23] VOIX: pour caler un peu
-[02:20:25] VOIX: là où on peut
-[02:20:26] VOIX: encore compter
-[02:20:27] VOIX: sur lui
-[02:20:27] VOIX: ou pas
-[02:20:28] VOIX: tu vois
-[02:20:29] VOIX: parce que c'est
-[02:20:29] VOIX: important quand même
-[02:20:31] VOIX: il a dit
-[02:20:31] VOIX: c'est qu'il ne voulait
-[02:20:32] VOIX: pas
-[02:20:32] VOIX: surtout pas perdre
-[02:20:33] VOIX: notre fil
-[02:20:33] VOIX: de continuer
-[02:20:35] VOIX: dans mon sens
-[02:20:36] VOIX: la logique
-[02:20:37] VOIX: quand même
-[02:20:38] VOIX: il sera ailleurs
-[02:20:39] VOIX: ok
-[02:20:40] VOIX: bon on peut clôturer
-[02:20:42] VOIX: nous clôturons
-[02:20:43] VOIX: nous clôturons
-[02:20:44] VOIX: bon en tout cas
-[02:20:45] VOIX: merci beaucoup
-[02:20:46] VOIX: pour toutes ces explications
-[02:20:47] VOIX: dans quelques minutes
-[02:20:48] VOIX: je envoie en fait
-[02:20:49] VOIX: mes questions
-[02:20:50] VOIX: d'accord
-[02:20:51] VOIX: si tu as envie
-[02:20:52] VOIX: en fait de répondre
-[02:20:54] VOIX: aux questions
-[02:20:55] VOIX: de t'amuser
-[02:20:56] VOIX: mais en fait
-[02:20:57] VOIX: je ne veux pas
-[02:20:57] VOIX: te prendre le temps
-[02:20:58] VOIX: je sais qu'en ce moment
-[02:20:59] VOIX: en fait c'est la course
-[02:20:59] VOIX: et machin
-[02:21:00] VOIX: donc on le verra après
-[02:21:02] VOIX: moi j'ai noté en fait
-[02:21:03] VOIX: tout ce que j'avais à noter
-[02:21:05] VOIX: et je reviendrai
-[02:21:06] VOIX: de toute façon
-[02:21:06] VOIX: en fait obligatoirement
-[02:21:07] VOIX: vers toi
-[02:21:08] VOIX: parce que
-[02:21:08] VOIX: au fur et à mesure
-[02:21:09] VOIX: je vais
-[02:21:09] VOIX: là
-[02:21:10] VOIX: au fur et à mesure
-[02:21:11] VOIX: j'avance
-[02:21:11] VOIX: et après
-[02:21:12] VOIX: on va passer en fait
-[02:21:13] VOIX: sur des tests
-[02:21:14] VOIX: je vais faire des corrections
-[02:21:15] VOIX: tu vas me critiquer
-[02:21:16] VOIX: je vais faire des corrections
-[02:21:18] VOIX: tu vas me re-re-critiquer
-[02:21:21] VOIX: mais c'est normal en fait
-[02:21:23] VOIX: c'est voilà
-[02:21:23] VOIX: si il n'y a pas de critique
-[02:21:25] VOIX: là je vais m'inquiéter
-[02:21:28] VOIX: n'inquiète pas
-[02:21:29] VOIX: on va
-[02:21:31] VOIX: on va faire les choses
-[02:21:32] VOIX: les unes après les autres
-[02:21:34] VOIX: allez
-[02:21:34] VOIX: mais bonne soirée
-[02:21:35] VOIX: bonsoir Guy
-[02:21:38] VOIX: à demain
-[02:21:39] VOIX: à demain
-[02:21:42] VOIX: et bonne soirée
-[02:21:43] VOIX: à toi aussi
-[02:21:44] VOIX: allez ciao ciao
-[02:21:44] VOIX: bonne soirée
-[02:21:45] VOIX: salut
-[02:22:13] VOIX: mais non
-[02:22:15] VOIX: bah écoute
-[02:22:29] VOIX: ah ma doublure
-[02:22:47] VOIX: coucou
-[02:22:49] VOIX: ça va ?
-[02:22:55] VOIX: tous les jours
-[02:22:56] VOIX: dans les plans
-[02:22:58] VOIX: ils changent
-[02:22:59] VOIX: mon cas aussi
-[02:23:00] VOIX: notre emplacement
-[02:23:01] VOIX: c'est plus vers moi
-[02:23:04] VOIX: c'est pénible
-[02:23:09] VOIX: attendez moi
-[02:23:09] VOIX: j'ai vu quelque chose
-[02:23:11] VOIX: comme vous parlez
-[02:23:13] VOIX: il y a quoi
-[02:23:13] VOIX: ça va ?
-[02:23:13] VOIX: ça va ?
-[02:23:13] VOIX: non
-[02:23:13] VOIX: pas
-[02:23:14] VOIX: pas
-[02:23:15] VOIX: comprends
-[02:23:19] VOIX: Merci.
-[02:23:51] VOIX: Merci.
-[02:24:23] VOIX: Merci.
-[02:24:48] VOIX: Merci.
-[02:25:23] VOIX: Merci.
-[02:25:53] VOIX: Merci.
-[02:26:01] VOIX: Merci.
-[02:26:25] VOIX: Merci.
-[02:26:55] VOIX: Merci.
-[02:27:18] VOIX: Merci.
-[02:27:54] VOIX: Merci.
-[02:28:00] VOIX: Merci.
-[02:28:19] VOIX: Merci.
-[02:28:23] VOIX: Merci.
-[02:28:50] VOIX: Merci.
-[02:28:53] VOIX: Merci.
-[02:28:54] VOIX: Merci.
-[02:29:01] VOIX: Merci.
-[02:29:05] VOIX: Merci.
-[02:29:08] VOIX: Merci.
-[02:29:09] VOIX: Merci.
-[02:29:09] VOIX: Moi, si je me souviens, si je me dansais dans un truc, ça m'amènerait.
-[02:29:12] VOIX: Moi, le truc, je me disais, là, passé, au final, si, je fais quelque chose, mais je ne fais pas travailler.
-[02:29:23] VOIX: Et là, si je me fais la même chose, je sentais que ça a su le pire.
-[02:29:32] VOIX: Je ne sais pas si je me disais, mais je ne sentais pas le pire.
-[02:29:35] VOIX: Je ne sais pas si je me disais, mais je ne sais pas.
-[02:29:47] VOIX: ...
-[02:29:48] VOIX: ...
-[02:29:54] VOIX: ...
-[02:29:54] VOIX: ...
-[02:29:54] VOIX: ...
-[02:29:56] VOIX: ...
-[02:29:56] VOIX: ...
-[02:29:58] VOIX: ...
-[02:30:00] VOIX: ...
-[02:30:01] VOIX: ...
-[02:30:02] VOIX: Et moi, c'est ce que j'ai été dit.
-[02:30:04] VOIX: Et s'ils me prennent, tu crois que tu vas travailler chez eux ?
-[02:30:08] VOIX: Bah, s'ils te prennent, ça va être difficile.
-[02:30:11] VOIX: Je crois que tu vas travailler chez eux.
-[02:30:17] VOIX: Donc...
-[02:30:17] VOIX: Oui, c'est pas comme moi que je vais travailler chez eux.
-[02:30:23] VOIX: Je comprends, mais quoi ?
-[02:30:24] VOIX: Moi, j'ai l'impression que c'est la science.
-[02:30:29] VOIX: Oui, c'est longtemps que tu vas travailler chez eux.
-[02:30:33] VOIX: Encore là, les mythes, elle est allée samedi, on est appelé,
-[02:30:37] VOIX: donc ils font des photos.
-[02:30:38] VOIX: Et rien de voir les photos, ça a été tourné.
-[02:30:42] VOIX: Parce que...
-[02:30:43] VOIX: Non, mais...
-[02:30:43] VOIX: C'est bon, en fait, c'est bien aussi.
-[02:30:46] VOIX: Donc, si tu veux, c'est déjà possible.
-[02:30:49] VOIX: Après, en fait, tu vois, en fait, si c'était prise, c'est pas prise.
-[02:30:52] VOIX: Si c'était prise, c'est tant mieux, parce que c'est pas le besoin.
-[02:30:55] VOIX: C'est bon.
-[02:30:55] VOIX: Il est utile, peut-être que...
-[02:30:57] VOIX: Il y a un, peut-être qu'il y a un, peut-être qu'il y a un,
-[02:30:59] VOIX: mais il empêche un peu de continuer avec la...
-[02:31:01] VOIX: Sur cette partie-là, en même temps.
-[02:31:07] VOIX: C'est justement...
-[02:31:08] VOIX: C'est justement...
-[02:31:09] VOIX: C'est justement...
-[02:31:10] VOIX: C'est tout cas, peut-être qu'il faut travailler dans ta vie,
-[02:31:13] VOIX: mais c'est pas le plus...
-[02:31:15] VOIX: C'est sûrement, ça va.
-[02:31:20] VOIX: C'est un peu plus simple.
-[02:31:24] VOIX: C'est un peu plus simple.
-[02:31:25] VOIX: C'est un peu plus simple.
-[02:31:26] VOIX: C'est un peu plus simple.
-[02:31:38] VOIX: C'est-à-dire, en fait,
-[02:31:48] VOIX: il y a un peu plus simple.
-[02:31:57] VOIX: C'est un peu plus simple.
-[02:32:09] VOIX: C'est un peu plus simple.
-[02:32:11] VOIX: C'est un peu plus simple.
-[02:32:13] VOIX: Bon.
-[02:32:14] VOIX: Je vais m'amener la cabane,
-[02:32:15] VOIX: et...
-[02:32:16] VOIX: ...
\ No newline at end of file
diff --git a/sans titre_summary_v2.md b/sans titre_summary_v2.md
deleted file mode 100644
index 18f8990..0000000
--- a/sans titre_summary_v2.md
+++ /dev/null
@@ -1,73 +0,0 @@
-### 1. Cadre général du « codage CIM‑10 » dans le cadre d’un **contrôle T2A**
-*(vision d’un médecin DIM – expert en codage et en gestion des contrôles)*
-
-| Phase | Objectif | Principales actions | Outils / livrables |
-|-------|----------|----------------------|--------------------|
-| **Avant le contrôle** | Préparer le dossier afin que toutes les données nécessaires soient disponibles | 1. Réception de la **notification ARS** (lettre de contrôle).
2. Identifier le **patient / séjour** (nom, prénom, NIR, unité médicale, date d’entrée et de sortie, mode d’entrée).
3. Récupérer :
• le **CRH** (Compte‑Rendu Hospitalier),
• les **actes** classants (chirurgicaux, anesthésie, etc.) et non‑classants (imagerie, biologie, etc.),
• les **examens** (scanner, IRM, échographies),
• les **bilans biologiques** (hémocultures, Bilan rénal, Bilan hépatique, etc.),
• les **notes infirmières / IDE / kiné**,
• les **documents d‑administration** (admission, transfert, sortie).
4. Vérifier la présence du **numéro OGC** (fichier RSS‑DEF/DRUIDE) : il permet d’associer chaque unité médicale au dossier OGC.
5. Créer le **fichier de suivi** (tracker / Excel ou CSV) contenant tous les éléments ci‑dessus, classés par unité médicale. | - Fichier RSS‑DEF (ou export DRUIDE) – numéros OGC
- EVA (ou système interne) – tracker
- Guide méthodologique CIM‑10/ T2A (chapitres 5‑6, règles M2, D1‑D3, etc.) |
-| **Pendant le contrôle** | Analyser le dossier et déterminer le **diagnostic principal (DP)** et les **diagnostics associés (DAS)** conformément aux règles de la **CIM‑10** et du **tarif T2A** | 1. Lire **chronologiquement** le parcours patient : motif d’entrée → urgences → prise en charge → interventions → évolution.
2. Identifier le **DP** : le problème de santé qui a justifié l’hospitalisation.
- Si le DP est **une pathologie clairement identifiée** (ex. pancréatite, colique hépatique), il devient le DP.
- Si seul un **symptôme** persiste et aucune cause n’a été trouvée (ex. douleur abdominale sans cause retrouvée), le symptôme reste le DP (règle D1).
3. Lister les **DAS** : comorbidités (CMA), maladies chroniques, antécédents pertinents qui ont été **traités ou suivis** pendant le séjour.
- Ne coder comme DAS que les diagnostics **effectivement pris en charge** (ex. diabète de type 2 traité, anémie traitée).
4. Appliquer les **règles de hiérarchisation** (ex. R65 = sepsis lorsqu’une fièvre > 38 °C ou fréquence cardiaque > 90 bpm).
5. Vérifier les **actes classants** : ils orientent le **GHM** (Groupe Homogène de Malades) et le **groupage tarifaire**.
- Ex. une cholécystectomie → GHM 08C (chirurgical).
- Les actes non‑classants (scan, échographie) n’influencent pas le groupage, mais justifient le DP/DAS.
6. Comparer le **DP/DAS** proposé avec le **fichier d’interdiction/acceptation** (Excel référentiel CIM‑10 / T2A).
- Détecter les codes interdits (ex. K802 vs K805) ; choisir le code autorisé.
7. Émettre un **argumentaire** : expliquer le choix du DP/DAS, citer les preuves (CRH, compte‑rendu opératoire, résultats d’imagerie, prescriptions).
8. Soumettre le **fichier d’UCR** (Unité de Contrôle et de Retour) au contrôleur. | - Guide méthodologique (chapitres 5‑6)
- Référentiel CIM‑10 / T2A (Excel)
- Outil IA (lecture automatisée) – facultatif, à valider par le médecin |
-| **Après le contrôle** | Répondre aux **observations du médecin contrôleur**, réviser le codage si nécessaire, finaliser la **fiche OGC** (numéro OGC + DP/DAS) | 1. Réception du **rapport du contrôleur** (commentaires, refus de certains codes).
2. Re‑examiner les pièces du dossier pointées (ex. rapport de scanner, prescription).
3. Corriger le DP/DAS ou les actes, en justifiant les changements dans l’**argumentaire**.
4. Mettre à jour la **fiche OGC** (nouveau DP/DAS, nouveau groupe‑tarif).
5. Valider la version finale avec le **MRC** (médecin responsable du contrôle) et le **DIM**. | - Fiche OGC mise à jour
- Rapport final du contrôle (archivé) |
-
----
-
-## 2. Points d’attention obligatoires (check‑list)
-
-| # | Point d’attention | Pourquoi c’est crucial |
-|---|-------------------|------------------------|
-| **1. Identification du patient** | Nom, prénom, NIR, date de naissance, unité(s) médicale(s) et numéro OGC. | Garantit que le codage porte sur le bon séjour. |
-| **2. Dates d’entrée et de sortie** | Vérifier la **date d’admission** (souvent le jour où le patient passe aux urgences) et la **date de sortie** exacte. | Influence le calcul de la durée, le groupe‑tarif et la pertinence du DP. |
-| **3. Mode d’entrée / sortie** | Urgences, réanimation, domicile, transfert. | Certains contrôles (ex. transfert inter‑établissements) sont soumis à des règles spécifiques. |
-| **4. Exhaustivité du dossier** | Tous les documents : CRH, compte‑rendu opératoire, rapports d’imagerie, bilans biologiques, notes IDE/kyné, prescriptions. | Sans la totalité, le DP peut être mal identifié (symptôme vs pathologie). |
-| **5. Présence du numéro OGC** | Extraction du fichier RSS‑DEF/DRUIDE → OGC 1, OGC 2, … | Nécessaire pour créer/mettre à jour la **fiche OGC** et lier le séjour au groupe‑tarif. |
-| **6. Distinction DP / DAS** | DP = motif d’hospitalisation ; DAS = diagnostics réellement traités. | Le DP entre dans le groupe‑tarif, les DAS sont facturés en supplément seulement s’ils sont pris en charge. |
-| **7. Règle “symptôme vs cause”** | Si aucune cause n’est identifiée, le symptôme reste DP (ex. douleur abdominale non expliquée). | Évite de coder un diagnostic absent du dossier (refus du contrôleur). |
-| **8. Sévérité (niveau 2, 3, 4)** | Appliquer les règles de gravité (ex. R65 = sepsis, R70 = déshydratation). | Influence le **groupe de gravité** et le montant de la prise en charge. |
-| **9. Actes classants** | Identifier les actes qui déterminent le **GHM** (ex. cholécystectomie → 08C). | Si l’acte classant manque, le groupe‑tarif sera erroné. |
-| **10. Actes non‑classants** | Conservés pour justifier le DP/DAS mais ne participent pas au groupage. | Useful for argumentation ; must still être mentionnés dans le dossier. |
-| **11. Respect du référentiel de codes** | Utiliser le fichier Excel **CIM‑10/T2A** fourni (liste des codes acceptés, interdits). | Empêche le choix de codes “interdits” (ex. K802 vs K805). |
-| **12. Comorbidités (CMA)** | Coder les pathologies chroniques **si elles sont effectivement prises en charge** pendant le séjour. | Elles sont prises en compte dans le calcul du **GHS** et du **coefficient de sévérité**. |
-| **13. Vérification des transferts** | Si le patient a été transféré d’un autre établissement, en note la prise en charge et l’inscrire dans le OGC. | Le contrôle peut remettre en cause le DP si le transfert n’est pas justifié. |
-| **14. Argumentaire écrit** | Chaque code doit être justifié (CRH, rapport d’anatomopathologie, résultats d’imagerie, prescription). | Le contrôleur se base sur cet argumentaire pour valider ou refuser. |
-| **15. Validation par le DIM** | Avant l’envoi, le **DIM** vérifie la conformité du DP, des DAS, des actes et du fichier OGC. | Dernière couche de contrôle interne avant la soumission. |
-| **16. Gestion des corrections** | Si le contrôleur propose un changement, modifier le DP/DAS/OGC **et** mettre à jour l’argumentaire, puis re‑soumettre. | Permet de clôturer le contrôle sans pénalité financière. |
-| **17. Suivi des dossiers “hors‑norme”** | Identifier les dossiers où le DP reste un symptôme non résolu ou où le montant à payer est très élevé ; les faire valider manuellement. | Ces dossiers requièrent souvent l’intervention du médecin contrôleur pour éviter un « refus total ». |
-| **18. Historisation** | Conserver chaque version de la fiche OGC, du tracker et du rapport d’argumentation. | Traçabilité exigée lors de contrôles subséquents ou d’audits. |
-| **19. Format de remise** | CSV / Excel : numéro dossier, OGC, DP, DAS, actes classants/non‑classants, dates, mode d’entrée, groupe‑tarif, justification. | Le format attendu par le service de contrôle (ex. Blue‑File/ShareFile). |
-| **20. Respect des délais** | Soumission du dossier complet **avant la date limite** indiquée dans la notification T2A. | Tout retard entraîne des pénalités automatiques. |
-
----
-
-## 3. Synthèse pratique – Modèle de fiche de contrôle T2A
-
-| Colonnes | Contenu attendu |
-|----------|----------------|
-| **N° dossier** | Identifiant interne de l’établissement. |
-| **Numéro OGC** | OGC‑1, OGC‑2 … (extrait de RSS‑DEF/DRUIDE). |
-| **Date d’entrée** | JJ/MM/AAAA + mode (urgences, SD, etc.). |
-| **Date de sortie** | JJ/MM/AAAA + mode (domicile, transfert). |
-| **DP** | Code CIM‑10 **principal** (ex. K85.0 pancréatite aiguë). |
-| **Motif d’hospitalisation** | Texte libre (ex. douleur abdominale, colique hépatique). |
-| **DAS (CMA)** | Codes CIM‑10 associés, **seuls** s’ils ont été traités (ex. E11.9 diabète, Z92 antécédent). |
-| **Actes classants** | Code CCAM + date (ex. HAJFA cholécystectomie). |
-| **Actes non‑classants** | Code CCAM + date (ex. ZEFA scanner). |
-| **Groupe‑tarif (GHM)** | Code GHM résultant du regroupement (ex. 08C). |
-| **Sevérité** | Niveau 1/2/3/4 ou R65 etc. |
-| **Justification** | Référence précise (CRH n°, compte‑rendu opératoire, rapport d’imagerie, ordonnance). |
-| **Commentaire contrôleur** | (à remplir après réception). |
-| **Statut** | Soumis / accepté / corrigé / clôturé. |
-
-*Cette fiche peut être exportée en CSV/Excel et importée dans le système de **Blue‑File/ShareFile** ou directement dans **EVA**.*
-
----
-
-## 4. Conclusion
-
-Le **codage CIM‑10 dans le cadre d’un contrôle T2A** repose sur une méthode rigoureuse :
-
-1. **Collecter** l’ensemble des pièces du dossier (CRH, OGC, actes, bilans).
-2. **Analyser** le parcours clinique pour identifier le **diagnostic principal** (DP) et les **diagnostics associés** (DAS).
-3. **Appliquer** les **règles de hiérarchisation**, de **sévérité** et de **compatibilité** du référentiel (Excel).
-4. **Justifier** chaque code par des documents probants et rédiger un **argumentaire** clair.
-5. **Vérifier** le tout avec le **DIM**, puis le soumettre via la **fiche OGC**.
-6. **Traiter** les remarques du contrôleur en révisant le DP/DAS, en documentant les changements et en re‑soumettant.
-
-Respecter scrupuleusement les **points d’attention obligatoires** listés ci‑dessus garantit une conformité maximale, minimise les rejets et optimise la valorisation financière du séjour.
\ No newline at end of file
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py
deleted file mode 100644
index e69de29..0000000
diff --git a/unsloth_compiled_cache/UnslothBCOTrainer.py b/unsloth_compiled_cache/UnslothBCOTrainer.py
deleted file mode 100644
index d390393..0000000
--- a/unsloth_compiled_cache/UnslothBCOTrainer.py
+++ /dev/null
@@ -1,2180 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, BaseTrainer, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, autocast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_joblib_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, joblib, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, warnings, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, TrainerCallback, TrainingArguments, Union, autocast, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_joblib_available, is_peft_available, is_sklearn_available, is_wandb_available, joblib, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothBCOConfig(BCOConfig):
- """
-
- Configuration class for the [`BCOTrainer`].
-
- This class includes only the parameters that are specific to BCO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
- to use the default data collator.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
- max_completion_length (`int`, *optional*):
- Maximum length of the completion. This argument is required if you want to use the default data collator
- and your model is an encoder-decoder.
- beta (`float`, *optional*, defaults to `0.1`):
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
- reference model.
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
- Label pad token id. This argument is required if you want to use the default data collator.
- padding_value (`int`, *optional*):
- Padding value to use. If `None`, the padding value of the tokenizer is used.
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
- This argument is required if you want to use the default data collator.
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model and reference model.
- generate_during_eval (`bool`, *optional*, defaults to `False`):
- If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
- during evaluation.
- is_encoder_decoder (`bool`, *optional*):
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
- you need to specify if the model returned by the callable is an encoder-decoder model.
- precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
- Whether to precompute reference model log probabilities for training and evaluation datasets. This is
- useful when training without the reference model to reduce the total GPU memory needed.
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
- string.
- ref_model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
- from a string.
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
- prompt_sample_size (`int`, *optional*, defaults to `1024`):
- Number of prompts that are fed to density ratio classifier.
- min_density_ratio (`float`, *optional*, defaults to `0.5`):
- Minimum value of the density ratio. The estimated density ratio is clamped to this value.
- max_density_ratio (`float`, *optional*, defaults to `10.0`):
- Maximum value of the density ratio. The estimated density ratio is clamped to this value.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- max_length = 1024,
- max_prompt_length = 512,
- max_completion_length = None,
- beta = 0.1,
- label_pad_token_id = -100,
- padding_value = None,
- truncation_mode = 'keep_end',
- disable_dropout = True,
- generate_during_eval = False,
- is_encoder_decoder = None,
- precompute_ref_log_probs = False,
- model_init_kwargs = None,
- ref_model_init_kwargs = None,
- dataset_num_proc = None,
- prompt_sample_size = 1024,
- min_density_ratio = 0.5,
- max_density_ratio = 10.0,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- max_length = max_length,
- max_prompt_length = max_prompt_length,
- max_completion_length = max_completion_length,
- beta = beta,
- label_pad_token_id = label_pad_token_id,
- padding_value = padding_value,
- truncation_mode = truncation_mode,
- disable_dropout = disable_dropout,
- generate_during_eval = generate_during_eval,
- is_encoder_decoder = is_encoder_decoder,
- precompute_ref_log_probs = precompute_ref_log_probs,
- model_init_kwargs = model_init_kwargs,
- ref_model_init_kwargs = ref_model_init_kwargs,
- dataset_num_proc = dataset_num_proc,
- prompt_sample_size = prompt_sample_size,
- min_density_ratio = min_density_ratio,
- max_density_ratio = max_density_ratio,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothBCOTrainer(BaseTrainer):
- r""""""
-
- _tag_names = ["trl", "bco"]
- _name = "BCO"
- _paper = {
- "title": "Binary Classifier Optimization for Large Language Model Alignment",
- "id": "2404.04656",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{jung2024binary,
- title = {{Binary Classifier Optimization for Large Language Model Alignment}},
- author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
- year = 2024,
- eprint = {arXiv:2404.04656}
- }"""),
- }
-
- def __init__(
- self,
- model: Union[PreTrainedModel, nn.Module, str] = None,
- ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
- args: BCOConfig = None,
- train_dataset: Optional[Dataset] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- data_collator: Optional[DataCollator] = None,
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional[dict] = None,
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
- model_adapter_name: Optional[str] = None,
- ref_adapter_name: Optional[str] = None,
- embedding_func: Optional[Callable] = None,
- embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
- ):
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()):
- raise ImportError(
- "BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`."
- )
-
- if type(args) is TrainingArguments:
- raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
-
- if not isinstance(model, str) and model is not None and ref_model is model:
- raise ValueError(
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
- "same as `model`, you must mass a copy of it, or `None` if you use peft."
- )
-
- if args.model_init_kwargs is None:
- model_init_kwargs = {}
- elif not isinstance(model, str):
- raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
- else:
- model_init_kwargs = args.model_init_kwargs
- dtype = model_init_kwargs.get("dtype")
- if dtype is not None:
- # Convert to `torch.dtype` if an str is passed
- if isinstance(dtype, str) and dtype != "auto":
- dtype = getattr(torch, dtype)
- if dtype != "auto" and not isinstance(dtype, torch.dtype):
- raise ValueError(
- f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
- )
- model_init_kwargs["dtype"] = dtype
-
- if args.ref_model_init_kwargs is None:
- ref_model_init_kwargs = {}
- elif not isinstance(ref_model, str):
- raise ValueError(
- "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
- )
- else:
- ref_model_init_kwargs = args.ref_model_init_kwargs
- dtype = ref_model_init_kwargs.get("dtype")
- if dtype is not None:
- # Convert to `torch.dtype` if an str is passed
- if isinstance(dtype, str) and dtype != "auto":
- dtype = getattr(torch, dtype)
- if dtype != "auto" and not isinstance(dtype, torch.dtype):
- raise ValueError(
- f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
- )
- ref_model_init_kwargs["dtype"] = dtype
-
- if isinstance(model, str):
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
-
- if isinstance(ref_model, str):
- ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
-
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
- # has been called in order to properly call autocast if needed.
- self._peft_has_been_casted_to_bf16 = False
-
- if not is_peft_available() and peft_config is not None:
- raise ValueError(
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
- )
- elif is_peft_available() and peft_config is not None:
- # if model is a peft model and we have a peft_config, we merge and unload it first
- if isinstance(model, PeftModel):
- model = model.merge_and_unload()
-
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
- _support_gc_kwargs = hasattr(
- args, "gradient_checkpointing_kwargs"
- ) and "gradient_checkpointing_kwargs" in list(
- inspect.signature(prepare_model_for_kbit_training).parameters
- )
-
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
- if _support_gc_kwargs:
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- # get peft model with the given config
- model = model
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(model)
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
- self._peft_has_been_casted_to_bf16 = True
-
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
- # to explicitly have `requires_grad=True`, otherwise training will either silently
- # fail or completely fail.
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
- raise ValueError(
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve."
- )
-
- if model is not None:
- self.is_encoder_decoder = model.config.is_encoder_decoder
- elif args.is_encoder_decoder is None:
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
- else:
- self.is_encoder_decoder = args.is_encoder_decoder
-
- self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
- self.model_adapter_name = model_adapter_name
- self.ref_adapter_name = ref_adapter_name
-
- if ref_model:
- self.ref_model = ref_model
- elif self.is_peft_model or args.precompute_ref_log_probs:
- # The `model` with adapters turned off will be used as the reference model
- self.ref_model = None
- else:
- self.ref_model = create_reference_model(model)
-
- if processing_class is None:
- raise ValueError(
- "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
- )
- if args.max_length is None:
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
- "It will be set to `512` by default, but you should do it yourself in the future.",
- )
- max_length = 512
- if args.max_length is not None:
- max_length = args.max_length
-
- if args.max_prompt_length is None:
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
- "It will be set to `128` by default, but you should do it yourself in the future.",
- )
- max_prompt_length = 128
- if args.max_prompt_length is not None:
- max_prompt_length = args.max_prompt_length
-
- max_completion_length = None
- if args.max_completion_length is None and self.is_encoder_decoder:
- logger.warning(
- "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
- " it will be set to `128` by default, but you should do it yourself in the future.",
- )
- max_completion_length = 128
- if args.max_completion_length is not None and self.is_encoder_decoder:
- max_completion_length = args.max_completion_length
-
- if data_collator is None:
- data_collator = DPODataCollatorWithPadding(
- pad_token_id=processing_class.pad_token_id,
- label_pad_token_id=args.label_pad_token_id,
- is_encoder_decoder=self.is_encoder_decoder,
- )
-
- if args.remove_unused_columns:
- args.remove_unused_columns = False
- # warn users
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
- " we have set it for you, but you should do it yourself in the future.",
- )
-
- self.use_dpo_data_collator = True
- else:
- self.use_dpo_data_collator = False
-
- # Disable dropout in the model and reference model
- if args.disable_dropout:
- disable_dropout_in_model(model)
- if self.ref_model is not None:
- disable_dropout_in_model(self.ref_model)
-
- self.max_length = max_length
- self.generate_during_eval = args.generate_during_eval
- self.label_pad_token_id = args.label_pad_token_id
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
- self.max_prompt_length = max_prompt_length
- self.truncation_mode = args.truncation_mode
- self.max_completion_length = max_completion_length
- self.precompute_ref_log_probs = args.precompute_ref_log_probs
-
- # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
- # keep track of first called to avoid computation of future calls
- self._precomputed_train_ref_log_probs = False
- self._precomputed_eval_ref_log_probs = False
-
- # metric
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
-
- # BCO parameter
- self.beta = args.beta
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
- logger.warning(
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
- "loss.",
- )
-
- # Underlying Distribution Matching argument
- self.embedding_func = embedding_func
- self.embedding_tokenizer = embedding_tokenizer
-
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
- # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
- # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
- # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
- # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
- # issued.
- model.warnings_issued["estimate_tokens"] = True
-
- with PartialState().main_process_first():
- # Extract the prompt if needed
- train_dataset = train_dataset.map(
- maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
- )
- # Unpair the dataset if needed
- train_dataset = maybe_unpair_preference_dataset(
- train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
- )
- # Apply the chat template if needed
- train_dataset = train_dataset.map(
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
- )
- if eval_dataset is not None:
- # Extract the prompt if needed
- eval_dataset = eval_dataset.map(
- maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
- )
- # Unpair the dataset if needed
- eval_dataset = maybe_unpair_preference_dataset(
- eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
- )
- eval_dataset = eval_dataset.map(
- maybe_apply_chat_template,
- fn_kwargs={"tokenizer": processing_class},
- num_proc=args.dataset_num_proc,
- )
-
- # Tokenize and prepare the training datasets
- train_dataset = train_dataset.map(
- _tokenize,
- batched=True,
- fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
- num_proc=args.dataset_num_proc,
- desc="Tokenizing train dataset",
- )
-
- # Prepare the datasets
- fn_kwargs = {
- "prefix": "",
- "is_encoder_decoder": self.is_encoder_decoder,
- "tokenizer": processing_class,
- "max_length": self.max_length,
- "truncation_mode": self.truncation_mode,
- "label_pad_token_id": self.label_pad_token_id,
- "max_prompt_length": self.max_prompt_length,
- "max_completion_length": self.max_completion_length,
- }
- train_dataset = train_dataset.map(
- _process_tokens,
- fn_kwargs=fn_kwargs,
- num_proc=args.dataset_num_proc,
- desc="Processing tokenized train dataset",
- )
-
- if eval_dataset is not None:
- # Tokenize
- eval_dataset = eval_dataset.map(
- _tokenize,
- fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
- batched=True,
- num_proc=args.dataset_num_proc,
- desc="Tokenizing eval dataset",
- )
-
- # Process
- fn_kwargs = {
- "prefix": "",
- "is_encoder_decoder": self.is_encoder_decoder,
- "tokenizer": processing_class,
- "max_length": self.max_length,
- "truncation_mode": self.truncation_mode,
- "label_pad_token_id": self.label_pad_token_id,
- "max_prompt_length": self.max_prompt_length,
- "max_completion_length": self.max_completion_length,
- }
- eval_dataset = eval_dataset.map(
- _process_tokens,
- fn_kwargs=fn_kwargs,
- num_proc=args.dataset_num_proc,
- desc="Processing tokenized eval dataset",
- )
-
- desirable = train_dataset.filter(
- lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
- )
- undesirable = train_dataset.filter(
- lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
- )
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- model_init=model_init,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
- # self.model_accepts_loss_kwargs to False to enable scaling.
- self.model_accepts_loss_kwargs = False
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- if not hasattr(self, "accelerator"):
- raise AttributeError(
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
- )
-
- # Deepspeed Zero-3 does not support precompute_ref_log_probs
- if self.is_deepspeed_enabled:
- if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
- raise ValueError(
- "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
- )
-
- if self.ref_model is None:
- if not (self.is_peft_model or self.precompute_ref_log_probs):
- raise ValueError(
- "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
- )
- else:
- if self.is_deepspeed_enabled:
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
- else:
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-
- self.running = RunningMoments(accelerator=self.accelerator)
-
- if self.embedding_func is None or args.resume_from_checkpoint:
- return
-
- chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
- rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
-
- embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
- labels = torch.cat(
- (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
- )
-
- self.clf = LogisticRegression(class_weight="balanced").fit(
- embeddings.cpu().float().numpy(), labels.cpu().numpy()
- )
- chosen_mean = self.clf.score(
- chosen_embeddings.cpu().float().numpy(), torch.ones_like(chosen_embeddings[:, 0]).cpu().numpy()
- )
- rejected_mean = self.clf.score(
- rejected_embeddings.cpu().float().numpy(), torch.zeros_like(rejected_embeddings[:, 0]).cpu().numpy()
- )
- logger.info(f"UDM classifier training scores: chosen: {chosen_mean}, rejected: {rejected_mean}")
-
- @property
- def match_underlying_distribution(self):
- return self.embedding_func is not None and self.embedding_tokenizer is not None
-
- def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
- """
- Calculates the probability if the given prompt embedding is from desirable dataset. This function calculates
- the probability in the process and ensemble across processes.
- """
- dtype = prompt_embeddings.dtype
- device = prompt_embeddings.device
- rank = self.accelerator.process_index
-
- padded_prompt_embeddings = self.accelerator.pad_across_processes(
- prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
- )
- sample_size = padded_prompt_embeddings.shape[0]
- nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
- prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
-
- # cannot predict for all empty values
- if prompt_embeddings.shape[0] == 0:
- return torch.tensor([], device=device, dtype=dtype)
-
- prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
- prob = torch.as_tensor(prob, dtype=dtype, device=device)
- prob = self.accelerator.reduce(prob, reduction="mean")
-
- prob = prob[sample_size * rank : sample_size * (rank + 1)]
- prob = prob[nonzero]
-
- return prob
-
- def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
- """
- Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id and applies self.embedding_func
- """
- input_ids = torch.where(
- input_ids == self.processing_class.pad_token_id,
- self.embedding_tokenizer.pad_token_id,
- input_ids,
- )
-
- with torch.no_grad():
- embeddings = self.embedding_func(
- input_ids=input_ids,
- attention_mask=attention_mask,
- )
-
- return embeddings
-
- def _get_prompt_embeddings(
- self, batch: dict[str, Union[list, torch.LongTensor]]
- ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
- """Extract embeddings from frozen embedding model"""
-
- if not self.match_underlying_distribution:
- return None, None
-
- embeddings = self._vectorize_prompt(
- input_ids=batch["embedding_input_ids"],
- attention_mask=batch["embedding_attention_mask"],
- )
-
- labels = torch.tensor(batch["label"], dtype=torch.bool, device=embeddings.device)
- chosen_idx = torch.where(labels)[0]
- rejected_idx = torch.where(~labels)[0]
-
- chosen_embeddings = embeddings[chosen_idx, ...]
- rejected_embeddings = embeddings[rejected_idx, ...]
-
- return (chosen_embeddings, rejected_embeddings)
-
- def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
- """
- Sample instances from dataset and get prompt embeddings. Used for density ratio classifier training.
- """
- n_samples = min(len(dataset), sample_size)
- rand_indices = np.random.choice(len(dataset), size=(n_samples,))
-
- embedding_dataset = dataset.select(rand_indices)
-
- dataloader_params = {
- "batch_size": self.args.per_device_train_batch_size,
- "collate_fn": self.data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "shuffle": False,
- }
-
- # prepare dataloader
- data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
-
- with torch.no_grad():
- all_embeddings = torch.empty(0)
- for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
- embeddings = self._vectorize_prompt(
- input_ids=padded_batch["embedding_input_ids"],
- attention_mask=padded_batch["embedding_attention_mask"],
- )
- embeddings = self.accelerator.gather_for_metrics(embeddings)
- all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
-
- return all_embeddings
-
- def _save_optimizer_and_scheduler(self, output_dir):
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- super()._save_optimizer_and_scheduler(output_dir)
-
- if self.accelerator.is_main_process:
- # When saving optimizer and scheduler to checkpoint, save also the running delta object.
- self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
-
- if self.match_underlying_distribution:
- joblib.dump(self.clf, os.path.join(output_dir, CLF_NAME), compress=True)
-
- def _load_optimizer_and_scheduler(self, checkpoint):
- if checkpoint is None:
- logger.warning_once(f"Missing Checkpoint {checkpoint}")
- return
-
- super()._load_optimizer_and_scheduler(checkpoint)
-
- # when loading optimizer and scheduler from checkpoint, also load the running delta object.
- running_file = os.path.join(checkpoint, RUNNING_NAME)
- if os.path.isfile(running_file):
- self.running = RunningMoments.load_from_json(self.accelerator, running_file)
-
- if self.match_underlying_distribution:
- clf_file = os.path.join(checkpoint, CLF_NAME)
- if os.path.isfile(clf_file):
- self.clf = joblib.load(clf_file)
-
- @contextmanager
- def null_ref_context(self):
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
- with (
- self.accelerator.unwrap_model(self.model).disable_adapter()
- if self.is_peft_model and not self.ref_adapter_name
- else nullcontext()
- ):
- if self.ref_adapter_name:
- self.model.set_adapter(self.ref_adapter_name)
- yield
- if self.ref_adapter_name:
- self.model.set_adapter(self.model_adapter_name or "default")
-
- def get_train_dataloader(self) -> DataLoader:
- """
- Returns the training [`~torch.utils.data.DataLoader`].
-
- Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
- """
-
- if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
- dataloader_params = {
- "batch_size": self.args.per_device_train_batch_size,
- "collate_fn": self.data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "shuffle": False,
- }
-
- # prepare dataloader
- data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
- reference_completion_logps = []
-
- for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
- reference_completion_logp = self.compute_reference_log_probs(padded_batch)
-
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
- reference_completion_logps.append(reference_completion_logp.cpu())
-
- self.train_dataset = self.train_dataset.add_column(
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
- )
-
- self._precomputed_train_ref_log_probs = True
-
- return super().get_train_dataloader()
-
- def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
- """
- Returns the evaluation [`~torch.utils.data.DataLoader`].
-
- Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
-
- Args:
- eval_dataset (`torch.utils.data.Dataset`, *optional*):
- If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
- by the `model.forward()` method are automatically removed. It must implement `__len__`.
- """
- if eval_dataset is None and self.eval_dataset is None:
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
- eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
- if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
- dataloader_params = {
- "batch_size": self.args.per_device_eval_batch_size,
- "collate_fn": self.data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "shuffle": False,
- }
-
- # prepare dataloader
- data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
-
- reference_completion_logps = []
-
- for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
- reference_completion_logp = self.compute_reference_log_probs(padded_batch)
-
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
- reference_completion_logps.append(reference_completion_logp.cpu())
-
- eval_dataset = eval_dataset.add_column(
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
- )
-
- # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
- if self.eval_dataset is not None:
- self.eval_dataset = eval_dataset
- self._precomputed_eval_ref_log_probs = True
-
- return super().get_eval_dataloader(eval_dataset=eval_dataset)
-
- def compute_reference_log_probs(self, padded_batch: dict) -> dict:
- """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
- with torch.no_grad():
- if self.ref_model is None:
- with self.null_ref_context():
- if self.is_encoder_decoder:
- completion_logits = self.model(
- padded_batch["prompt_input_ids"],
- attention_mask=padded_batch["prompt_attention_mask"],
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
- labels=padded_batch["completion_labels"],
- ).logits
-
- else:
- completion_logits = self.model(
- padded_batch["completion_input_ids"],
- attention_mask=padded_batch["completion_attention_mask"],
- ).logits
-
- else:
- if self.is_encoder_decoder:
- completion_logits = self.ref_model(
- padded_batch["prompt_input_ids"],
- attention_mask=padded_batch["prompt_attention_mask"],
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
- labels=padded_batch["completion_labels"],
- ).logits
-
- else:
- completion_logits = self.ref_model(
- padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
- ).logits
-
- completion_logps = self.get_batch_logps(
- completion_logits,
- padded_batch["completion_labels"],
- average_log_prob=False,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
-
- return completion_logps
-
- @staticmethod
- def get_batch_logps(
- logits: torch.FloatTensor,
- labels: torch.LongTensor,
- average_log_prob: bool = False,
- label_pad_token_id: int = -100,
- is_encoder_decoder: bool = False,
- ) -> torch.FloatTensor:
- """Compute the log probabilities of the given labels under the given logits.
-
- Args:
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
- labels:
- Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
- ignored. Shape: (batch_size, sequence_length)
- average_log_prob:
- If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
- log probabilities of the (non-masked) tokens.
- label_pad_token_id:
- The label value to ignore when computing log probabilities.
- is_encoder_decoder:
- Whether the model is an encoder-decoder model. If True, the labels are not shifted, and the logits are
- assumed to already be aligned with the labels. If False, the labels are shifted to the right by one
- position, and the logits are assumed to be aligned with the shifted labels.
-
- Returns:
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
- given logits.
- """
- if logits.shape[:-1] != labels.shape:
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
-
- if not is_encoder_decoder:
- labels = labels[:, 1:].clone()
- logits = logits[:, :-1, :]
- else:
- # Fixes end-dec RuntimeError
- labels = labels.clone()
-
- loss_mask = labels != label_pad_token_id
-
- # dummy token; we'll ignore the losses on these tokens later
- labels[labels == label_pad_token_id] = 0
-
- per_token_logps = selective_log_softmax(logits, labels)
-
- if average_log_prob:
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
- else:
- return (per_token_logps * loss_mask).sum(-1)
-
- def forward(
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- model_kwargs = (
- {
- "labels": batch["completion_labels"],
- "decoder_input_ids": batch.get("completion_decoder_input_ids"),
- }
- if self.is_encoder_decoder
- else {}
- )
- if self.aux_loss_enabled:
- model_kwargs["output_router_logits"] = True
-
- outputs = model(
- batch["completion_input_ids"],
- attention_mask=batch["completion_attention_mask"],
- **model_kwargs,
- )
- completion_logits = outputs.logits
-
- completion_logps = self.get_batch_logps(
- completion_logits,
- batch["completion_labels"],
- average_log_prob=False,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
-
- if completion_logps.shape[0] != len(batch["label"]):
- raise ValueError(
- "There is a mismatch between the number of examples in this batch and the number of "
- "examples for which an output sequence was predicted."
- )
-
- chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
- rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
-
- chosen_logps = completion_logps[chosen_idx, ...]
- rejected_logps = completion_logps[rejected_idx, ...]
-
- chosen_logits = completion_logits[chosen_idx, ...]
- rejected_logits = completion_logits[rejected_idx, ...]
-
- if self.aux_loss_enabled:
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
- else:
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
-
- def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
- prob_desirable = self._get_chosen_prob(rejected_embeddings)
- min_ratio = self.args.min_density_ratio
- max_ratio = self.args.max_density_ratio
-
- weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
-
- return weight
-
- def bco_loss(
- self,
- policy_chosen_logps: torch.FloatTensor,
- policy_rejected_logps: torch.FloatTensor,
- reference_chosen_logps: torch.FloatTensor,
- reference_rejected_logps: torch.FloatTensor,
- chosen_embeddings: Optional[torch.FloatTensor],
- rejected_embeddings: Optional[torch.FloatTensor],
- do_train: bool = True,
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- """Compute the BCO loss for a batch of policy and reference model log probabilities.
-
- Args:
- policy_chosen_logps:
- Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
- policy_rejected_logps:
- Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
- reference_chosen_logps:
- Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
- reference_rejected_logps:
- Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in
- batch_size,)
- chosen_embeddings: embeddings of desirable prompts
- rejected_embeddings: embeddings of undesirable prompts
- do_train: whether to update the running delta value. Default is True.
-
- Returns:
- A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta). The losses tensor contains the
- BCO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards
- for the chosen and rejected responses, respectively. The delta value contains the moving average of all
- implicit rewards.
- """
-
- chosen_logratios = policy_chosen_logps - reference_chosen_logps
- chosen_rewards = self.beta * chosen_logratios
-
- rejected_logratios = policy_rejected_logps - reference_rejected_logps
- rejected_rewards = self.beta * rejected_logratios
-
- if do_train:
- self.running.update(torch.cat((chosen_rewards, rejected_rewards), 0).detach())
- delta = torch.as_tensor(self.running.mean, device=chosen_rewards.device)
-
- chosen_losses = -F.logsigmoid(chosen_rewards - delta)
- rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
-
- if self.match_underlying_distribution:
- chosen_weight = torch.ones_like(chosen_losses)
- rejected_weight = self._get_udm_weight(rejected_embeddings)
-
- losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
- else:
- losses = torch.cat((chosen_losses, rejected_losses), dim=0)
-
- return losses, chosen_rewards, rejected_rewards, delta
-
- def get_batch_loss_metrics(
- self,
- model,
- batch: dict[str, Union[list, torch.LongTensor]],
- do_train: bool = True,
- ):
- """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
- metrics = {}
- batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
-
- forward_output = self.forward(model, batch)
- (
- policy_chosen_logps,
- policy_rejected_logps,
- policy_chosen_logits,
- policy_rejected_logits,
- ) = forward_output[:4]
- if self.aux_loss_enabled:
- aux_loss = forward_output[4]
-
- # if reference_logps in batch use them, otherwise use the reference model
- if "reference_logps" in batch:
- chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
- rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
-
- reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
- reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
- else:
- with torch.no_grad():
- if self.ref_model is None:
- with self.null_ref_context():
- (
- reference_chosen_logps,
- reference_rejected_logps,
- _,
- _,
- ) = self.forward(self.model, batch)[:4]
- else:
- (
- reference_chosen_logps,
- reference_rejected_logps,
- _,
- _,
- ) = self.forward(self.ref_model, batch)[:4]
-
- chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
-
- losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
- policy_chosen_logps,
- policy_rejected_logps,
- reference_chosen_logps,
- reference_rejected_logps,
- chosen_embeddings,
- rejected_embeddings,
- do_train=do_train,
- )
- metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
-
- num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
- num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
-
- all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
- all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
-
- if all_num_chosen > 0:
- metrics["rewards/chosen_sum"] = (
- self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
- )
- metrics["logps/chosen_sum"] = (
- self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
- )
- metrics["logits/chosen_sum"] = (
- self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
- )
- metrics["count/chosen"] = all_num_chosen
-
- if all_num_rejected > 0:
- metrics["rewards/rejected_sum"] = (
- self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
- )
- metrics["logps/rejected_sum"] = (
- self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
- )
- metrics["logits/rejected_sum"] = (
- self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
- )
- metrics["count/rejected"] = all_num_rejected
-
- loss = losses.nanmean()
- if self.aux_loss_enabled:
- loss += self.aux_loss_coef * aux_loss
-
- return loss, metrics
-
- def compute_loss(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- return_outputs=False,
- num_items_in_batch=None,
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
- compute_loss_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with compute_loss_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs)
-
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
- loss = loss.to(self.args.device)
- # force log the metrics
- if self.accelerator.is_main_process:
- self.store_metrics(metrics, train_eval="train")
-
- if return_outputs:
- return (loss, metrics)
- return loss
-
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
- for key, value in metrics.items():
- self._stored_metrics[train_eval][key].append(value)
-
- def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
- if dataset is None:
- dataset = self.train_dataset
- if dataset is None or not has_length(dataset):
- return None
- return SequentialSampler(dataset)
-
- def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
- """Generate samples from the model and reference model for the given batch of inputs."""
-
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
- # the torch amp context manager as some hidden states are silently casted to full precision.
- generate_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
- with generate_context_manager:
- policy_output = model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
-
- # if reference_output in batch use that otherwise use the reference model
- if "reference_output" in batch:
- reference_output = batch["reference_output"]
- else:
- if self.ref_model is None:
- with self.null_ref_context():
- reference_output = self.model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
- else:
- reference_output = self.ref_model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
-
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
-
- reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
- reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
-
- return policy_output_decoded, reference_output_decoded
-
- def prediction_step(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- prediction_loss_only: bool,
- ignore_keys: Optional[list[str]] = None,
- ):
- if ignore_keys is None:
- if hasattr(model, "config"):
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
- else:
- ignore_keys = []
-
- prediction_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
- with torch.no_grad(), prediction_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs, do_train=False)
-
- # force log the metrics
- if self.accelerator.is_main_process:
- self.store_metrics(metrics, train_eval="eval")
-
- if prediction_loss_only:
- return (loss.detach(), None, None)
-
- # logits for the chosen and rejected samples from model
- logits_dict = {}
- if "logits/chosen_sum" in metrics:
- logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
- if "logits/rejected_sum" in metrics:
- logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
- logits = torch.tensor(logits, device=self.accelerator.device)
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
- return (loss.detach(), logits, labels)
-
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: Optional[bool] = None,
- ignore_keys: Optional[list[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
- `Trainer.evaluate()` and `Trainer.predict()`.
-
- Works both with or without labels.
- """
-
- # Sample and save to game log if requested (for one batch to save time)
- if self.generate_during_eval:
- # Generate random indices within the range of the total number of samples
- num_samples = len(dataloader.dataset)
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
-
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
- random_batch_dataset = dataloader.dataset.select(random_indices)
- random_batch = self.data_collator(random_batch_dataset)
- random_batch = self._prepare_inputs(random_batch)
-
- target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device)
- target_indices = torch.where(~target_labels)[0]
- target_batch = {
- "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
- "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
- "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
- }
- policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
-
- table = pd.DataFrame(
- columns=["Prompt", "Policy", "Ref Model"],
- data=[
- [prompt, pol[len(prompt) :], ref[len(prompt) :]]
- for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
- ],
- )
- if "wandb" in self.args.report_to:
- wandb.log({"game_log": wandb.Table(data=table)})
-
- if "comet_ml" in self.args.report_to:
- log_table_to_comet_experiment(
- name="game_log.csv",
- table=table,
- )
-
- # Base evaluation
- initial_output = super().evaluation_loop(
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
- )
-
- return initial_output
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- """
- Log `logs` on the various objects watching training, including stored metrics.
-
- Args:
- logs (`dict[str, float]`):
- The values to log.
- start_time (`float`, *optional*):
- Start time of the training.
- """
- # logs either has 'loss' or 'eval_loss'
- train_eval = "train" if "loss" in logs else "eval"
- # train metrics should have no prefix, eval should have 'eval_'
- prefix = "eval_" if train_eval == "eval" else ""
- # accumulate average metrics from sums and lengths
- for split in ["chosen", "rejected"]:
- if f"count/{split}" in self._stored_metrics[train_eval]:
- count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
- for metric in ["rewards", "logps", "logits"]:
- logs[f"{prefix}{metric}/{split}"] = (
- torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
- / count_sum
- )
- # delete obsolete metric
- del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
- del self._stored_metrics[train_eval][f"count/{split}"]
- # calculate reward margin
- if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
- logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
- # Add averaged stored metrics to logs
- for key, metrics in self._stored_metrics[train_eval].items():
- logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
- del self._stored_metrics[train_eval]
- return super().log(logs, start_time)
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothBCOTrainer(_UnslothBCOTrainer):
- """
-
- Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
-
- Args:
- model ([`~transformers.PreTrainedModel`]):
- The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
- ref_model ([`PreTrainedModelWrapper`]):
- Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
- and loss. If no reference model is provided, the trainer will create a reference model with the same
- architecture as the model to be optimized.
- args ([`BCOConfig`]):
- The arguments to use for training.
- train_dataset ([`~datasets.Dataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`]):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- data_collator ([`~transformers.DataCollator`], *optional*):
- The data collator to use for training. If None is specified, the default data collator
- ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
- sequences in the batch, given a dataset of paired sequences.
- model_init (`Callable[[], transformers.PreTrainedModel]`):
- The model initializer to use for training. If None is specified, the default model initializer will be
- used.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
- peft_config (`dict`, defaults to `None`):
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
- a PEFT model.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
- metric values.
- model_adapter_name (`str`, defaults to `None`):
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
- ref_adapter_name (`str`, defaults to `None`):
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
-
- """
- def __init__(
- self,
- model = None,
- ref_model = None,
- args = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- data_collator = None,
- model_init = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- compute_metrics = None,
- model_adapter_name = None,
- ref_adapter_name = None,
- embedding_func = None,
- embedding_tokenizer = None,
- **kwargs
- ):
- if args is None: args = UnslothBCOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('bco_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- ref_model = ref_model,
- args = args,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- data_collator = data_collator,
- model_init = model_init,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,
- compute_metrics = compute_metrics,
- model_adapter_name = model_adapter_name,
- ref_adapter_name = ref_adapter_name,
- embedding_func = embedding_func,
- embedding_tokenizer = embedding_tokenizer,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothCPOTrainer.py b/unsloth_compiled_cache/UnslothCPOTrainer.py
deleted file mode 100644
index eee10fe..0000000
--- a/unsloth_compiled_cache/UnslothCPOTrainer.py
+++ /dev/null
@@ -1,1960 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothCPOConfig(CPOConfig):
- """
-
- Configuration class for the [`CPOTrainer`].
-
- This class includes only the parameters that are specific to CPO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
- to use the default data collator.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
- max_completion_length (`int`, *optional*):
- Maximum length of the completion. This argument is required if you want to use the default data collator
- and your model is an encoder-decoder.
- beta (`float`, *optional*, defaults to `0.1`):
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
- reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
- the [paper](https://huggingface.co/papers/2310.12036).
- label_smoothing (`float`, *optional*, defaults to `0.0`):
- Label smoothing factor. This argument is required if you want to use the default data collator.
- loss_type (`str`, *optional*, defaults to `"sigmoid"`):
- Type of loss to use. Possible values are:
-
- - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
- - `"hinge"`: hinge loss on the normalized likelihood from the
- [SLiC](https://huggingface.co/papers/2305.10425) paper.
- - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
- - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
- - `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This
- automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`.
-
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model.
- cpo_alpha (`float`, *optional*, defaults to `1.0`):
- Weight of the BC regularizer in CPO training.
- simpo_gamma (`float`, *optional*, defaults to `0.5`):
- Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
- alpha (`float`, *optional*, defaults to `0.0`):
- Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses
- standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha))
- / alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all
- loss types.
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
- Label pad token id. This argument is required if you want to use the default data collator.
- padding_value (`int`, *optional*):
- Padding value to use. If `None`, the padding value of the tokenizer is used.
- truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
- This argument is required if you want to use the default data collator.
- generate_during_eval (`bool`, *optional*, defaults to `False`):
- If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
- is_encoder_decoder (`bool`, *optional*):
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
- you need to specify if the model returned by the callable is an encoder-decoder model.
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
- string.
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- max_length = 1024,
- max_prompt_length = 512,
- max_completion_length = None,
- beta = 0.1,
- label_smoothing = 0.0,
- loss_type = 'sigmoid',
- disable_dropout = True,
- cpo_alpha = 1.0,
- simpo_gamma = 0.5,
- alpha = 0.0,
- label_pad_token_id = -100,
- padding_value = None,
- truncation_mode = 'keep_end',
- generate_during_eval = False,
- is_encoder_decoder = None,
- model_init_kwargs = None,
- dataset_num_proc = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- max_length = max_length,
- max_prompt_length = max_prompt_length,
- max_completion_length = max_completion_length,
- beta = beta,
- label_smoothing = label_smoothing,
- loss_type = loss_type,
- disable_dropout = disable_dropout,
- cpo_alpha = cpo_alpha,
- simpo_gamma = simpo_gamma,
- alpha = alpha,
- label_pad_token_id = label_pad_token_id,
- padding_value = padding_value,
- truncation_mode = truncation_mode,
- generate_during_eval = generate_during_eval,
- is_encoder_decoder = is_encoder_decoder,
- model_init_kwargs = model_init_kwargs,
- dataset_num_proc = dataset_num_proc,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothCPOTrainer(BaseTrainer):
- r""""""
-
- _tag_names = ["trl", "cpo"]
- _name = "CPO"
- _paper = {
- "title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
- "id": "2401.08417",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @inproceedings{xu2024contrastive,
- title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
- author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
- year = 2024,
- booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
- publisher = {OpenReview.net},
- url = {https://openreview.net/forum?id=51iwkioZpn}
- }"""),
- }
-
- def __init__(
- self,
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
- args: Optional[CPOConfig] = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Dataset] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional[dict] = None,
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
- ):
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if args.model_init_kwargs is None:
- model_init_kwargs = {}
- elif not isinstance(model, str):
- raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
- else:
- model_init_kwargs = args.model_init_kwargs
- dtype = model_init_kwargs.get("dtype")
- if dtype is not None:
- # Convert to `torch.dtype` if an str is passed
- if isinstance(dtype, str) and dtype != "auto":
- dtype = getattr(torch, dtype)
- if dtype != "auto" and not isinstance(dtype, torch.dtype):
- raise ValueError(
- f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
- )
- model_init_kwargs["dtype"] = dtype
-
- if isinstance(model, str):
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
-
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
- # has been called in order to properly call autocast if needed.
- self._peft_has_been_casted_to_bf16 = False
-
- if not is_peft_available() and peft_config is not None:
- raise ValueError(
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
- )
- elif is_peft_available() and peft_config is not None:
- # if model is a peft model and we have a peft_config, we merge and unload it first
- if isinstance(model, PeftModel):
- model = model.merge_and_unload()
-
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
- _support_gc_kwargs = hasattr(
- args, "gradient_checkpointing_kwargs"
- ) and "gradient_checkpointing_kwargs" in list(
- inspect.signature(prepare_model_for_kbit_training).parameters
- )
-
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
- if _support_gc_kwargs:
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- # get peft model with the given config
- model = model
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(model)
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
- self._peft_has_been_casted_to_bf16 = True
-
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
- # to explicitly have `requires_grad=True`, otherwise training will either silently
- # fail or completely fail.
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
- raise ValueError(
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve."
- )
-
- if model is not None:
- self.is_encoder_decoder = model.config.is_encoder_decoder
- elif args.is_encoder_decoder is None:
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
- else:
- self.is_encoder_decoder = args.is_encoder_decoder
-
- if self.is_encoder_decoder:
- self.decoder_start_token_id = model.config.decoder_start_token_id
- self.pad_token_id = model.config.pad_token_id
-
- if processing_class is None:
- raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
- if args.max_length is None:
- logger.warning(
- "`max_length` is not set in the CPOConfig's init"
- " it will default to `512` by default, but you should do it yourself in the future.",
- )
- max_length = 512
- else:
- max_length = args.max_length
- if args.max_prompt_length is None:
- logger.warning(
- "`max_prompt_length` is not set in the CPOConfig's init"
- " it will default to `128` by default, but you should do it yourself in the future.",
- )
- max_prompt_length = 128
- else:
- max_prompt_length = args.max_prompt_length
-
- if not max_prompt_length < max_length:
- raise ValueError(
- f"max_prompt_length ({max_prompt_length}) should be strictly less than max_length ({max_length})."
- )
-
- if args.max_completion_length is None and self.is_encoder_decoder:
- logger.warning(
- "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
- " it will default to `128` by default, but you should do it yourself in the future.",
- )
- max_completion_length = 128
- else:
- max_completion_length = args.max_completion_length
-
- if data_collator is None:
- data_collator = DPODataCollatorWithPadding(
- pad_token_id=processing_class.pad_token_id,
- label_pad_token_id=args.label_pad_token_id,
- is_encoder_decoder=self.is_encoder_decoder,
- )
-
- if args.remove_unused_columns:
- args.remove_unused_columns = False
- # warn users
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
- " we have set it for you, but you should do it yourself in the future.",
- )
-
- self.use_dpo_data_collator = True
- else:
- self.use_dpo_data_collator = False
-
- # Disable dropout in the model
- if args.disable_dropout:
- disable_dropout_in_model(model)
-
- self.max_length = max_length
- self.generate_during_eval = args.generate_during_eval
- self.label_pad_token_id = args.label_pad_token_id
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
- self.max_prompt_length = max_prompt_length
- self.truncation_mode = args.truncation_mode
- self.max_completion_length = max_completion_length
- self.processing_class = processing_class
-
- if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
- logger.warning(
- f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
- "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
- )
- if args.loss_type == "kto_pair":
- raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
-
- self.beta = args.beta
- self.label_smoothing = args.label_smoothing
- self.loss_type = args.loss_type
- self.cpo_alpha = args.cpo_alpha
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
- logger.warning(
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
- "loss.",
- )
-
- if args.loss_type == "simpo":
- self.simpo_gamma = args.simpo_gamma
-
- # AlphaPO parameter for reward shaping
- self.alpha = args.alpha
-
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
-
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
- # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
- # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
- # that the warning has already been issued.
- model.warnings_issued["estimate_tokens"] = True
-
- # Compute that only on the main process for faster data processing.
- # see: https://github.com/huggingface/trl/pull/1255
- with PartialState().main_process_first():
- # Extract the prompt if needed, and apply the chat template if needed
- train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
- train_dataset = train_dataset.map(
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
- )
- if eval_dataset is not None:
- eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
- eval_dataset = eval_dataset.map(
- maybe_apply_chat_template,
- fn_kwargs={"tokenizer": processing_class},
- num_proc=args.dataset_num_proc,
- )
-
- # tokenize the dataset
- train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
- if eval_dataset is not None:
- eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- model_init=model_init,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
- # self.model_accepts_loss_kwargs to False to enable scaling.
- self.model_accepts_loss_kwargs = False
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- if not hasattr(self, "accelerator"):
- raise AttributeError(
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
- )
-
- def build_tokenized_answer(self, prompt, answer):
- """
- Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
- b)[len(enc(a)):]`. Reference:
- https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
- """
-
- full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
- prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
-
- answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
- answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
-
- # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
- full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
-
- # Prepare input tokens for token by token comparison
- full_input_ids = np.array(full_tokenized["input_ids"])
-
- if len(full_input_ids) != len(full_concat_input_ids):
- raise ValueError("Prompt input ids and answer input ids should have the same length.")
-
- # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
- # can be merged together when tokenizing prompt+answer. This could result
- # on the last token from the prompt being different when tokenized on its own
- # vs when done as prompt+answer.
- response_token_ids_start_idx = len(prompt_input_ids)
-
- # If tokenized prompt is different than both prompt+answer, then it means the
- # last token has changed due to merging.
- if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
- response_token_ids_start_idx -= 1
-
- prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
- prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
-
- if len(prompt_input_ids) != len(prompt_attention_mask):
- raise ValueError("Prompt input ids and attention mask should have the same length.")
-
- answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
- answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
-
- return dict(
- prompt_input_ids=prompt_input_ids,
- prompt_attention_mask=prompt_attention_mask,
- input_ids=answer_input_ids,
- attention_mask=answer_attention_mask,
- )
-
- def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
- """Tokenize a single row from a CPO specific dataset.
-
- At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
- chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
- we truncate the chosen/rejected.
-
- We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
- of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.
- """
- batch = {}
- prompt = feature["prompt"]
- chosen = feature["chosen"]
- rejected = feature["rejected"]
-
- if not self.is_encoder_decoder:
- # Check issues below for more details
- # 1. https://github.com/huggingface/trl/issues/907
- # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
- # 3. https://github.com/LianjiaTech/BELLE/issues/337
-
- if not isinstance(prompt, str):
- raise ValueError(f"prompt should be an str but got {type(prompt)}")
- prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
- prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
-
- if not isinstance(chosen, str):
- raise ValueError(f"chosen should be an str but got {type(chosen)}")
- chosen_tokens = self.build_tokenized_answer(prompt, chosen)
-
- if not isinstance(rejected, str):
- raise ValueError(f"rejected should be an str but got {type(rejected)}")
- rejected_tokens = self.build_tokenized_answer(prompt, rejected)
-
- # Last prompt token might get merged by tokenizer and
- # it should not be included for generation if that happens
- prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
-
- chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
- rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
- prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
-
- for k, v in prompt_tokens.items():
- prompt_tokens[k] = v[:prompt_len_input_ids]
-
- # Make sure prompts only have one different token at most an
- # and length only differs by 1 at most
- num_diff_tokens = sum(
- a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
- )
- num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
- if num_diff_tokens > 1 or num_diff_len > 1:
- raise ValueError(
- "Chosen and rejected prompt_input_ids might only differ on the "
- "last token due to tokenizer merge ops."
- )
-
- # add BOS token to head of prompt. Avoid adding if it's already there
- prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
- self.processing_class.bos_token_id,
- prompt_len_input_ids,
- prompt_tokens,
- chosen_prompt_len_input_ids,
- chosen_tokens,
- rejected_prompt_len_input_ids,
- rejected_tokens,
- )
-
- # add EOS token to end of answer. Avoid adding if it's already there
- chosen_tokens, rejected_tokens = add_eos_token_if_needed(
- self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
- )
-
- longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
-
- # if combined sequence is too long, truncate the prompt
- for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
- if self.truncation_mode == "keep_start":
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
- answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
- elif self.truncation_mode == "keep_end":
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
- answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
- else:
- raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
-
- # if that's still too long, truncate the response
- for answer_tokens in [chosen_tokens, rejected_tokens]:
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
- for k in ["input_ids", "attention_mask"]:
- answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
-
- # Create labels
- chosen_sequence_tokens = {
- k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
- }
- rejected_sequence_tokens = {
- k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
- }
- chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
- chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
- self.label_pad_token_id
- ] * len(chosen_tokens["prompt_input_ids"])
- rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
- rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
- self.label_pad_token_id
- ] * len(rejected_tokens["prompt_input_ids"])
-
- for k, toks in {
- "chosen_": chosen_sequence_tokens,
- "rejected_": rejected_sequence_tokens,
- "": prompt_tokens,
- }.items():
- for type_key, tokens in toks.items():
- if type_key == "token_type_ids":
- continue
- batch[f"{k}{type_key}"] = tokens
-
- else:
- chosen_tokens = self.processing_class(
- chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
- )
- rejected_tokens = self.processing_class(
- rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
- )
- prompt_tokens = self.processing_class(
- prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
- )
-
- batch["chosen_labels"] = chosen_tokens["input_ids"]
- batch["rejected_labels"] = rejected_tokens["input_ids"]
- batch["prompt_input_ids"] = prompt_tokens["input_ids"]
- batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
-
- if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
- batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
- labels=torch.tensor(batch["rejected_labels"])
- )
- batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
- labels=torch.tensor(batch["chosen_labels"])
- )
-
- return batch
-
- @staticmethod
- def concatenated_inputs(
- batch: dict[str, Union[list, torch.LongTensor]],
- is_encoder_decoder: bool = False,
- label_pad_token_id: int = -100,
- padding_value: int = 0,
- device: Optional[torch.device] = None,
- ) -> dict[str, torch.LongTensor]:
- """Concatenate the chosen and rejected inputs into a single tensor.
-
- Args:
- batch:
- A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
- of shape (batch_size, sequence_length).
- is_encoder_decoder:
- Whether the model is an encoder-decoder model.
- label_pad_token_id:
- The label pad token id.
- padding_value:
- The padding value to use for the concatenated inputs_ids.
- device:
- The device for the concatenated inputs.
-
- Returns:
- A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
- """
- concatenated_batch = {}
-
- if is_encoder_decoder:
- max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
- else:
- max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
-
- for k in batch:
- if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
- if "labels" in k or is_encoder_decoder:
- pad_value = label_pad_token_id
- elif k.endswith("_input_ids"):
- pad_value = padding_value
- elif k.endswith("_attention_mask"):
- pad_value = 0
- concatenated_key = k.replace("chosen", "concatenated")
- concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
- for k in batch:
- if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
- if "labels" in k or is_encoder_decoder:
- pad_value = label_pad_token_id
- elif k.endswith("_input_ids"):
- pad_value = padding_value
- elif k.endswith("_attention_mask"):
- pad_value = 0
- concatenated_key = k.replace("rejected", "concatenated")
- concatenated_batch[concatenated_key] = torch.cat(
- (
- concatenated_batch[concatenated_key],
- pad_to_length(batch[k], max_length, pad_value=pad_value),
- ),
- dim=0,
- ).to(device=device)
-
- if is_encoder_decoder:
- concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
- concatenated_batch["concatenated_attention_mask"] = (
- batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
- )
-
- return concatenated_batch
-
- def cpo_loss(
- self,
- policy_chosen_logps: torch.FloatTensor,
- policy_rejected_logps: torch.FloatTensor,
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- """Compute the CPO loss for a batch of policy and reference model log probabilities.
-
- Args:
- policy_chosen_logps:
- Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
- policy_rejected_logps:
- Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-
- Returns:
- A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the CPO
- loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
- the chosen and rejected responses, respectively.
- """
- # Apply AlphaPO reward transformation if alpha != 0
- if self.alpha != 0.0:
- # Compute probabilities
- chosen_probs = torch.exp(policy_chosen_logps)
- rejected_probs = torch.exp(policy_rejected_logps)
-
- # Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha
- policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha
- policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha
-
- logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device)
- else:
- # Standard log probability rewards when alpha = 0
- logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
-
- # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
- # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
- # calculates a conservative CPO loss.
-
- if self.loss_type == "simpo":
- gamma_logratios = self.simpo_gamma / self.beta
- logits = logits - gamma_logratios
- # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
- losses = (
- -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- - F.logsigmoid(-self.beta * logits) * self.label_smoothing
- )
- elif self.loss_type == "sigmoid":
- # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
- losses = (
- -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- - F.logsigmoid(-self.beta * logits) * self.label_smoothing
- )
- elif self.loss_type == "hinge":
- losses = torch.relu(1 - self.beta * logits)
- elif self.loss_type == "ipo":
- # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
- losses = (logits - 1 / (2 * self.beta)) ** 2
- else:
- raise ValueError(
- f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
- )
-
- # Calculate rewards for logging
- if self.alpha != 0.0:
- # When using AlphaPO transformation, use the transformed rewards
- chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach()
- rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach()
- else:
- # Standard log probability rewards
- chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
- rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
-
- return losses, chosen_rewards, rejected_rewards
-
- @staticmethod
- def get_batch_logps(
- logits: torch.FloatTensor,
- labels: torch.LongTensor,
- average_log_prob: bool = False,
- label_pad_token_id: int = -100,
- is_encoder_decoder: bool = False,
- ) -> torch.FloatTensor:
- """Compute the log probabilities of the given labels under the given logits.
-
- Args:
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
- labels:
- Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
- ignored. Shape: (batch_size, sequence_length)
- average_log_prob:
- If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
- log probabilities of the (non-masked) tokens.
- label_pad_token_id: The label pad token id.
- is_encoder_decoder: Whether the model is an encoder-decoder model.
-
- Returns:
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
- given logits.
- """
- if logits.shape[:-1] != labels.shape:
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
-
- if not is_encoder_decoder:
- labels = labels[:, 1:].clone()
- logits = logits[:, :-1, :]
- loss_mask = labels != label_pad_token_id
-
- # dummy token; we'll ignore the losses on these tokens later
- labels[labels == label_pad_token_id] = 0
-
- per_token_logps = selective_log_softmax(logits, labels)
-
- if average_log_prob:
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
- else:
- return (per_token_logps * loss_mask).sum(-1)
-
- def concatenated_forward(
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
-
- We do this to avoid doing two forward passes, because it's faster for FSDP.
- """
- concatenated_batch = self.concatenated_inputs(
- batch,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- padding_value=self.padding_value,
- device=self.accelerator.device,
- )
- len_chosen = batch["chosen_labels"].shape[0]
-
- model_kwargs = (
- {
- "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
- }
- if self.is_encoder_decoder
- else {}
- )
-
- if self.aux_loss_enabled:
- model_kwargs["output_router_logits"] = True
-
- outputs = model(
- concatenated_batch["concatenated_input_ids"],
- attention_mask=concatenated_batch["concatenated_attention_mask"],
- use_cache=False,
- **model_kwargs,
- )
- all_logits = outputs.logits
-
- def cross_entropy_loss(logits, labels):
- if not self.is_encoder_decoder:
- # Shift so that tokens < n predict n
- logits = logits[..., :-1, :].contiguous()
- labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- logits = logits.view(-1, logits.shape[-1])
- labels = labels.view(-1)
- # Enable model parallelism
- labels = labels.to(logits.device)
- loss = loss_fct(logits, labels)
- return loss
-
- labels = concatenated_batch["concatenated_labels"].clone()
-
- if self.cpo_alpha == 0:
- nll_loss = torch.tensor(0.0).to(self.accelerator.device)
- else:
- nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
-
- all_logps = self.get_batch_logps(
- all_logits,
- concatenated_batch["concatenated_labels"],
- average_log_prob=self.loss_type in ["ipo", "simpo"],
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
-
- chosen_logps = all_logps[:len_chosen]
- rejected_logps = all_logps[len_chosen:]
-
- chosen_logits = all_logits[:len_chosen]
- rejected_logits = all_logits[len_chosen:]
-
- if self.aux_loss_enabled:
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
-
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
-
- def get_batch_loss_metrics(
- self,
- model,
- batch: dict[str, Union[list, torch.LongTensor]],
- train_eval: Literal["train", "eval"] = "train",
- ):
- """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
- metrics = {}
-
- forward_output = self.concatenated_forward(model, batch)
- (
- policy_chosen_logps,
- policy_rejected_logps,
- policy_chosen_logits,
- policy_rejected_logits,
- policy_nll_loss,
- ) = forward_output[:5]
- if self.aux_loss_enabled:
- aux_loss = forward_output[5]
-
- losses, chosen_rewards, rejected_rewards = self.cpo_loss(
- policy_chosen_logps,
- policy_rejected_logps,
- )
-
- loss = losses.mean() + self.cpo_alpha * policy_nll_loss
- reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
- prefix = "eval_" if train_eval == "eval" else ""
- metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
- metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
- metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
- metrics[f"{prefix}rewards/margins"] = (
- self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
- )
- metrics[f"{prefix}logps/rejected"] = (
- self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
- )
- metrics[f"{prefix}logps/chosen"] = (
- self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
- )
- metrics[f"{prefix}logits/rejected"] = (
- self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item()
- )
- metrics[f"{prefix}logits/chosen"] = (
- self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item()
- )
- metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
-
- if self.aux_loss_enabled:
- loss += self.aux_loss_coef * aux_loss
-
- return loss, metrics
-
- def compute_loss(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- return_outputs=False,
- num_items_in_batch=None,
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
- compute_loss_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with compute_loss_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
-
- # force log the metrics
- self.store_metrics(metrics, train_eval="train")
-
- if return_outputs:
- return (loss, metrics)
- return loss
-
- def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
- """Generate samples from the model and reference model for the given batch of inputs."""
-
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
- # the torch amp context manager as some hidden states are silently casted to full precision.
- generate_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with generate_context_manager:
- policy_output = model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
-
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
-
- return policy_output_decoded
-
- def prediction_step(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- prediction_loss_only: bool,
- ignore_keys: Optional[list[str]] = None,
- ):
- if ignore_keys is None:
- if hasattr(model, "config"):
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
- else:
- ignore_keys = []
-
- prediction_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with torch.no_grad(), prediction_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
-
- # force log the metrics
- self.store_metrics(metrics, train_eval="eval")
-
- if prediction_loss_only:
- return (loss.detach(), None, None)
-
- # logits for the chosen and rejected samples from model
- logits_dict = {
- "eval_logits/chosen": metrics["eval_logits/chosen"],
- "eval_logits/rejected": metrics["eval_logits/rejected"],
- }
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
- logits = torch.tensor(logits, device=self.accelerator.device)
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
- return (loss.detach(), logits, labels)
-
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
- for key, value in metrics.items():
- self._stored_metrics[train_eval][key].append(value)
-
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: Optional[bool] = None,
- ignore_keys: Optional[list[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
- `Trainer.evaluate()` and `Trainer.predict()`.
-
- Works both with or without labels.
- """
-
- # Sample and save to game log if requested (for one batch to save time)
- if self.generate_during_eval:
- # Generate random indices within the range of the total number of samples
- num_samples = len(dataloader.dataset)
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
-
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
- random_batch_dataset = dataloader.dataset.select(random_indices)
- random_batch = self.data_collator(random_batch_dataset)
- random_batch = self._prepare_inputs(random_batch)
-
- policy_output_decoded = self.generate_from_model(self.model, random_batch)
-
- table = pd.DataFrame(
- columns=["Prompt", "Policy"],
- data=[
- [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
- ],
- )
- if "wandb" in self.args.report_to:
- wandb.log({"game_log": wandb.Table(data=table)})
-
- if "comet_ml" in self.args.report_to:
- log_table_to_comet_experiment(
- name="game_log.csv",
- table=table,
- )
-
- # Base evaluation
- initial_output = super().evaluation_loop(
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
- )
-
- return initial_output
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- """
- Log `logs` on the various objects watching training, including stored metrics.
-
- Args:
- logs (`dict[str, float]`):
- The values to log.
- start_time (`float`, *optional*):
- Start time of the training.
- """
- # logs either has 'loss' or 'eval_loss'
- train_eval = "train" if "loss" in logs else "eval"
- # Add averaged stored metrics to logs
- for key, metrics in self._stored_metrics[train_eval].items():
- logs[key] = torch.tensor(metrics).mean().item()
- del self._stored_metrics[train_eval]
- return super().log(logs, start_time)
-
- def _shift_right(self, input_ids):
- if self.decoder_start_token_id is None:
- raise ValueError(
- "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
- )
-
- # shift inputs to the right
- if is_torch_fx_proxy(input_ids):
- # Item assignment is not supported natively for proxies.
- shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
- shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
- else:
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
- shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
- shifted_input_ids[..., 0] = self.decoder_start_token_id
-
- if self.pad_token_id is None:
- raise ValueError("model.config.pad_token_id has to be defined.")
- # replace possible -100 values in labels by `pad_token_id`
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
-
- return shifted_input_ids
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothCPOTrainer(_UnslothCPOTrainer):
- """
-
- Initialize CPOTrainer.
-
- Args:
- model ([`~transformers.PreTrainedModel`]):
- The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
- args ([`CPOConfig`]):
- The CPO config arguments to use for training.
- data_collator ([`~transformers.DataCollator`]):
- The data collator to use for training. If None is specified, the default data collator
- ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
- sequences in the batch, given a dataset of paired sequences.
- train_dataset ([`~datasets.Dataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`]):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- model_init (`Callable[[], transformers.PreTrainedModel]`):
- The model initializer to use for training. If None is specified, the default model initializer will be
- used.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
- peft_config (`dict`, defaults to `None`):
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
- a PEFT model.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
- metric values.
-
- """
- def __init__(
- self,
- model = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- model_init = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- compute_metrics = None,
- **kwargs
- ):
- if args is None: args = UnslothCPOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('cpo_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- model_init = model_init,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,
- compute_metrics = compute_metrics,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothDPOTrainer.py b/unsloth_compiled_cache/UnslothDPOTrainer.py
deleted file mode 100644
index fc4dd5d..0000000
--- a/unsloth_compiled_cache/UnslothDPOTrainer.py
+++ /dev/null
@@ -1,2898 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.dpo_trainer import (Any, AutoProcessor, BaseImageProcessor, BaseTrainer, Callable, DPOConfig, DPOTrainer, DataCollator, DataCollatorForPreference, DataLoader, Dataset, EvalLoopOutput, F, FDivergenceConstants, FDivergenceType, FeatureExtractionMixin, IterableDataset, Literal, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, Optional, PartialState, Path, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, SyncRefModelCallback, TrainerCallback, Union, autocast, cap_exp, contextmanager, create_model_from_path, create_reference_model, dataclass, defaultdict, disable_dropout_in_model, empty_cache, flush_left, flush_right, get_peft_model, inspect, is_comet_available, is_liger_kernel_available, is_mlflow_available, is_peft_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, nullcontext, pad, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_fsdp, prepare_model_for_kbit_training, random, selective_log_softmax, shift_tokens_right, textwrap, torch, tqdm, warnings, Any, AutoProcessor, BaseImageProcessor, Callable, DPOConfig, DPOTrainer, DataCollator, DataCollatorForPreference, Dataset, EvalLoopOutput, F, FDivergenceConstants, FeatureExtractionMixin, IterableDataset, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, SyncRefModelCallback, TrainerCallback, Union, create_model_from_path, create_reference_model, defaultdict, disable_dropout_in_model, is_comet_available, is_liger_kernel_available, is_mlflow_available, is_peft_available, is_wandb_available, logger, nn, pad, prepare_deepspeed, prepare_fsdp, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothDPOConfig(DPOConfig):
- """
-
- Configuration class for the [`DPOTrainer`].
-
- This class includes only the parameters that are specific to DPO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- > Parameters that control the model and reference model
-
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the
- [`DPOTrainer`] is provided as a string.
- ref_model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the
- [`DPOTrainer`] is provided as a string.
- model_adapter_name (`str`, *optional*):
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
- ref_adapter_name (`str`, *optional*):
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
- force_use_ref_model (`bool`, *optional*, defaults to `False`):
- If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set
- this flag to `True`.
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model and reference model.
- use_logits_to_keep (`bool`, *optional*, defaults to `False`):
- If `True`, only a specified number of logits are computed in the forward pass. This can be useful for
- saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios
- when working with very long prompts where labels are ignored (-100).
-
- > Parameters that control the data preprocessing
-
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
- pad_token (`str`, *optional*):
- Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
- it falls back to `processing_class.eos_token`.
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
- Padding value to use for labels.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt.
- max_completion_length (`int`, *optional*):
- Maximum length of the completion.
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the full sequence (prompt + completion).
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
- Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
- `"keep_start"`.
- padding_free (`bool`, *optional*, defaults to `False`):
- Whether to perform forward passes without padding by flattening all sequences in the batch into a single
- continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
- supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened
- batch structure.
- precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
- Whether to precompute the log probabilities from the reference model. Setting this to `True` allows
- training without needing the reference model during training, which can help reduce GPU memory usage. If
- set to `False` (default), the reference model will be used during training to compute log probabilities
- on-the-fly.
- precompute_ref_batch_size (`int`, *optional*):
- Batch size to use when precomputing reference model log probabilities. This can be set higher than the
- training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
- training and `per_device_eval_batch_size` for evaluation.
- tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
- List of tools (callable functions) that will be accessible to the model. If the template does not support
- function calling, this argument will have no effect.
-
- > Parameters that control the training
-
- loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`):
- Type of loss to use. Possible values are:
-
- - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
- - `"hinge"`: hinge loss on the normalized likelihood from the
- [SLiC](https://huggingface.co/papers/2305.10425) paper.
- - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
- - `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper.
- - `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper.
- - `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust
- DPO](https://huggingface.co/papers/2403.00409) paper.
- - `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper.
- - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675)
- paper.
- - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
- - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882)
- paper.
- - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the
- [DiscoPOP](https://huggingface.co/papers/2406.08414) paper.
- - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss).
-
- Multiple loss types can be combined using comma separation (e.g., `["sigmoid", "bco_pair", "sft"]` for
- [MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify
- corresponding weights for each loss type.
-
- use_liger_loss (`bool`, *optional*, defaults to `False`):
- Whether to use Liger loss.
- base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
- Name of the attribute in the model that contains the base model. This is used to get the base model from
- the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
- beta (`float`, *optional*, defaults to `0.1`):
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
- reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
- the [paper](https://huggingface.co/papers/2310.12036).
- f_divergence_type ([`FDivergenceType`] or `str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
- Type of f-divergence regularization function to compute divergence between policy and reference model.
- f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
- α coefficient in the α-divergence u^-α regularization function for DPO loss.
- reference_free (`bool`, *optional*, defaults to `False`):
- Whether to ignore the provided reference model and implicitly use a reference model that assigns equal
- probability to all responses.
- label_smoothing (`float`, *optional*, defaults to `0.0`):
- Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and [Robust
- DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
- use_weighting (`bool`, *optional*, defaults to `False`):
- Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827).
- rpo_alpha (`float`, *optional*):
- α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the
- weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
- DPO loss. The paper recommends `rpo_alpha=1.0`.
- ld_alpha (`float`, *optional*):
- α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting
- of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose
- part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between
- `0.0` and `1.0`.
- discopop_tau (`float`, *optional*, defaults to `0.05`):
- τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
- the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
- loss_weights (`list[float]`, *optional*):
- List of loss weights for multi-loss combinations. Used when combining multiple loss types. Example: `[0.8,
- 0.2, 1.0]` for [MPO](https://huggingface.co/papers/2411.10442). If not provided, defaults to equal weights
- (`1.0`) for all loss types.
- sync_ref_model (`bool`, *optional*, defaults to `False`):
- Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
- the `ref_model_mixup_alpha` parameter. This synchronization originates from the
- [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
- ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
- α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
- between the current policy and the previous reference policy during updates. The reference policy is
- updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
- must set `sync_ref_model=True`.
- ref_model_sync_steps (`int`, *optional*, defaults to `512`):
- τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
- frequently the current policy is synchronized with the reference policy. To use this parameter, you must
- set `sync_ref_model=True`.
-
- > Parameters that control the logging
-
- generate_during_eval (`bool`, *optional*, defaults to `False`):
- Whether to generate and log completions from both the model and the reference model to W&B or Comet during
- evaluation.
-
- > Deprecated parameters
-
- padding_value:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `pad_token` (`str`) instead.
-
-
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- model_init_kwargs = None,
- ref_model_init_kwargs = None,
- model_adapter_name = None,
- ref_adapter_name = None,
- force_use_ref_model = False,
- disable_dropout = True,
- use_logits_to_keep = False,
- dataset_num_proc = None,
- pad_token = None,
- label_pad_token_id = -100,
- max_prompt_length = 512,
- max_completion_length = None,
- max_length = 1024,
- truncation_mode = 'keep_end',
- padding_free = False,
- precompute_ref_log_probs = False,
- precompute_ref_batch_size = None,
- tools = None,
- use_liger_loss = False,
- base_model_attribute_name = 'model',
- beta = 0.1,
- f_alpha_divergence_coef = 1.0,
- reference_free = False,
- label_smoothing = 0.0,
- use_weighting = False,
- rpo_alpha = None,
- ld_alpha = None,
- discopop_tau = 0.05,
- loss_weights = None,
- sync_ref_model = False,
- ref_model_mixup_alpha = 0.6,
- ref_model_sync_steps = 512,
- generate_during_eval = False,
- padding_value = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- model_init_kwargs = model_init_kwargs,
- ref_model_init_kwargs = ref_model_init_kwargs,
- model_adapter_name = model_adapter_name,
- ref_adapter_name = ref_adapter_name,
- force_use_ref_model = force_use_ref_model,
- disable_dropout = disable_dropout,
- use_logits_to_keep = use_logits_to_keep,
- dataset_num_proc = dataset_num_proc,
- pad_token = pad_token,
- label_pad_token_id = label_pad_token_id,
- max_prompt_length = max_prompt_length,
- max_completion_length = max_completion_length,
- max_length = max_length,
- truncation_mode = truncation_mode,
- padding_free = padding_free,
- precompute_ref_log_probs = precompute_ref_log_probs,
- precompute_ref_batch_size = precompute_ref_batch_size,
- tools = tools,
- use_liger_loss = use_liger_loss,
- base_model_attribute_name = base_model_attribute_name,
- beta = beta,
- f_alpha_divergence_coef = f_alpha_divergence_coef,
- reference_free = reference_free,
- label_smoothing = label_smoothing,
- use_weighting = use_weighting,
- rpo_alpha = rpo_alpha,
- ld_alpha = ld_alpha,
- discopop_tau = discopop_tau,
- loss_weights = loss_weights,
- sync_ref_model = sync_ref_model,
- ref_model_mixup_alpha = ref_model_mixup_alpha,
- ref_model_sync_steps = ref_model_sync_steps,
- generate_during_eval = generate_during_eval,
- padding_value = padding_value,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothDPOTrainer(BaseTrainer):
- """"""
-
- _tag_names = ["trl", "dpo"]
- _name = "DPO"
- _paper = {
- "title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
- "id": "2305.18290",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @inproceedings{rafailov2023direct,
- title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
- author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
- year = 2023,
- booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
- url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
- editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
- }"""),
- }
-
- def __init__(
- self,
- model: Union[str, nn.Module, PreTrainedModel],
- ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
- args: Optional[DPOConfig] = None,
- data_collator: Optional[DataCollator] = None, # type: ignore
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
- optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional["PeftConfig"] = None,
- ):
- # Args
- if args is None:
- model_name = model if isinstance(model, str) else model.config._name_or_path
- model_name = model_name.split("/")[-1]
- args = DPOConfig(f"{model_name}-DPO")
-
- # Model and reference model
- if isinstance(model, str):
- model = create_model_from_path(model, **args.model_init_kwargs or {})
- else:
- if args.model_init_kwargs is not None:
- logger.warning(
- "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
- "The `model_init_kwargs` will be ignored."
- )
- model_id = model.config._name_or_path
- if isinstance(ref_model, str):
- ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {})
- else:
- if args.ref_model_init_kwargs is not None:
- logger.warning(
- "You passed `ref_model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
- "The `ref_model_init_kwargs` will be ignored."
- )
- if ref_model is model:
- raise ValueError(
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
- "same as `model`, you can simply omit the `ref_model` argument and it will be created for you."
- )
-
- # Processing class
- if processing_class is None:
- processing_class = AutoProcessor.from_pretrained(model_id)
-
- # Handle pad token for processors or tokenizers
- if isinstance(processing_class, ProcessorMixin):
- tokenizer = processing_class.tokenizer
- self._is_vlm = True
- elif isinstance(processing_class, PreTrainedTokenizerBase):
- tokenizer = processing_class
- self._is_vlm = False
- else:
- raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
-
- # Get the pad token: if not provided, use the one from the processing class or the eos token
- # if the processing class does not have a pad token.
- if args.padding_value is not None: # deprecated, will be removed in 0.26.0.
- warnings.warn(
- "The `padding_value` argument is deprecated and will be removed in version 0.26.0. Please use "
- "`pad_token` (str) instead."
- )
- self.pad_token_id = args.padding_value
- else:
- pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
- self.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
- if self.pad_token_id is None:
- raise ValueError(
- f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
- f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
- "in the vocabulary before using it as a padding token."
- )
-
- # PEFT configuration and model wrapping
- model = self._prepare_peft_model(model, ref_model, peft_config, args)
-
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available() or is_mlflow_available()):
- raise ValueError(
- "`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed."
- " Please install `wandb`, `mlflow` or `comet-ml` to resolve."
- )
-
- self.is_encoder_decoder = model.config.is_encoder_decoder
- self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()
- self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
- self.model_adapter_name = args.model_adapter_name
- self.ref_adapter_name = args.ref_adapter_name
- self.reference_free = args.reference_free
-
- if ref_model:
- self.ref_model = ref_model
- elif self.is_peft_model or args.precompute_ref_log_probs:
- # The `model` with adapters turned off will be used as the reference model
- self.ref_model = None
- else:
- self.ref_model = create_reference_model(model)
-
- # Disable dropout in the model and reference model
- if args.disable_dropout:
- disable_dropout_in_model(model)
- if self.ref_model is not None:
- disable_dropout_in_model(self.ref_model)
-
- # Liger kernel
- if args.use_liger_loss:
- if not is_liger_kernel_available():
- raise ImportError(
- "You set `use_liger_loss=True` but the liger kernel is not available. "
- "Please install liger-kernel first: `pip install liger-kernel`"
- )
- if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]:
- raise ValueError(
- "You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. "
- "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel."
- )
- self.dpo_loss_fn = LigerFusedLinearDPOLoss(
- ignore_index=args.label_pad_token_id,
- beta=args.beta,
- use_ref_model=not args.reference_free,
- average_log_prob=False,
- loss_type=args.loss_type,
- )
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the
- # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
- # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
- # that the warning has already been issued.
- model.warnings_issued["estimate_tokens"] = True
-
- # Data collator
- if data_collator is None:
- data_collator = DataCollatorForPreference(pad_token_id=self.pad_token_id)
-
- self.generate_during_eval = args.generate_during_eval
- self.label_pad_token_id = args.label_pad_token_id
- self.max_prompt_length = args.max_prompt_length
- self.max_completion_length = args.max_completion_length
- self.max_length = args.max_length
- self.truncation_mode = args.truncation_mode
- self.precompute_ref_log_probs = args.precompute_ref_log_probs
- self.use_logits_to_keep = args.use_logits_to_keep
-
- if args.padding_free:
- if model.config._attn_implementation != "flash_attention_2":
- logger.warning(
- "Padding-free training is enabled, but the attention implementation is not set to "
- "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
- "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
- "other implementations may lead to unexpected behavior. To ensure compatibility, set "
- "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
- "attention mechanism can handle flattened sequences."
- )
- self.padding_free = args.padding_free
-
- # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
- # keep track of first called to avoid computation of future calls
- self._precomputed_train_ref_log_probs = False
- self._precomputed_eval_ref_log_probs = False
-
- self.beta = args.beta
- self.label_smoothing = args.label_smoothing
- self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type]
- self.loss_weights = args.loss_weights
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
- self.use_weighting = args.use_weighting
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
- logger.warning(
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
- "loss.",
- )
- for loss_type in self.loss_type:
- if (
- loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"]
- and args.label_smoothing > 0
- ):
- logger.warning(
- f"You are using the {loss_type} loss type that does not support label smoothing. The "
- "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this "
- "warning.",
- )
- if loss_type == "kto_pair":
- raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")
-
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
- self.f_divergence_type = args.f_divergence_type
- self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
- self.dataset_num_proc = args.dataset_num_proc
-
- # Dataset preparation
- train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
- if eval_dataset is not None:
- if isinstance(eval_dataset, dict):
- eval_dataset = {
- key: self._prepare_dataset(dataset, processing_class, args, key)
- for key, dataset in eval_dataset.items()
- }
- else:
- eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
- # self.model_accepts_loss_kwargs to False to enable scaling.
- self.model_accepts_loss_kwargs = False
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- if not hasattr(self, "accelerator"):
- raise AttributeError(
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
- )
-
- # Deepspeed Zero-3 does not support precompute_ref_log_probs
- if self.is_deepspeed_enabled:
- if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
- raise ValueError(
- "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
- )
-
- if self.ref_model is None:
- if not (self.is_peft_model or self.precompute_ref_log_probs):
- raise ValueError(
- "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
- )
- if args.sync_ref_model:
- raise ValueError(
- "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
- )
- else:
- if self.is_deepspeed_enabled:
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
- elif self.is_fsdp_enabled:
- self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
- else:
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-
- if args.sync_ref_model:
- if self.precompute_ref_log_probs:
- raise ValueError(
- "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
- )
-
- self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
-
- if "bco_pair" in self.loss_type:
- self.running = RunningMoments(self.accelerator)
-
- @property
- def padding_value(self):
- warnings.warn(
- "The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use "
- "`pad_token_id` instead.",
- )
- return self.pad_token_id
-
- @padding_value.setter
- def padding_value(self, value):
- warnings.warn(
- "The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use "
- "`pad_token_id` instead.",
- )
- self.pad_token_id = value
-
- def _prepare_peft_model(
- self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig
- ) -> PreTrainedModel:
- """Prepares a model for PEFT training."""
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
- # has been called in order to properly call autocast if needed.
- self._peft_has_been_casted_to_bf16 = False
-
- if not is_peft_available() and peft_config is not None:
- raise ValueError(
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
- )
- elif is_peft_available() and peft_config is not None:
- # if model is a peft model and we have a peft_config, we merge and unload it first
- if isinstance(model, PeftModel):
- model = model.merge_and_unload()
-
- if ref_model is not None and not args.force_use_ref_model:
- raise ValueError(
- "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
- " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
- " if you want to use a different ref_model."
- )
-
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
- _support_gc_kwargs = hasattr(
- args, "gradient_checkpointing_kwargs"
- ) and "gradient_checkpointing_kwargs" in list(
- inspect.signature(prepare_model_for_kbit_training).parameters
- )
-
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
- if _support_gc_kwargs:
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-
- else:
- model = self._prepare_gradient_checkpointing(model, args)
-
- # get peft model with the given config
- model = get_peft_model(model, peft_config)
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(model)
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
- self._peft_has_been_casted_to_bf16 = True
-
- else:
- model = self._prepare_gradient_checkpointing(model, args)
-
- return model
-
- def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig):
- """Prepare the gradienting checkpointing for the model."""
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
- # to explicitly have `requires_grad=True`, otherwise training will either silently
- # fail or completely fail.
- if args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- return model
-
- def _prepare_dataset(
- self,
- dataset: Union[Dataset, IterableDataset],
- processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
- args: DPOConfig,
- dataset_name: str,
- ) -> Union[Dataset, IterableDataset]:
- # Build the kwargs for the `map` function
- map_kwargs = {}
- if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size
- map_kwargs["num_proc"] = args.dataset_num_proc
- map_kwargs["writer_batch_size"] = 10
-
- with PartialState().main_process_first():
- # Extract prompt if needed
- if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
- map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
- dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
-
- # Apply the chat template if needed
- if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
- map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
- dataset = dataset.map(
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs
- )
-
- # Tokenize the dataset
- if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
- map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
-
- dataset = dataset.map(
- self.tokenize_row if not self.is_vision_model else self.process_row,
- remove_columns=["chosen", "rejected"],
- fn_kwargs={
- "processing_class": processing_class,
- "max_prompt_length": args.max_prompt_length,
- "max_completion_length": args.max_completion_length,
- # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
- "add_special_tokens": False,
- },
- **map_kwargs,
- )
-
- return dataset
-
- @staticmethod
- def tokenize_row(
- features: dict[str, str],
- processing_class: PreTrainedTokenizerBase,
- max_prompt_length: Optional[int] = None,
- max_completion_length: Optional[int] = None,
- add_special_tokens: bool = True,
- ) -> dict[str, list[int]]:
- """
- Tokenize a row of the dataset.
-
- Args:
- features (`dict[str, str]`):
- Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`]):
- Processing class used to process the data.
- max_prompt_length (`int` or `None`):
- Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated.
- max_completion_length (`int` or `None`):
- Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
- add_special_tokens (`bool`):
- Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`,
- the prompt sequence will have a bos token prepended and an eos token appended. In any case, the
- completion sequences will have an eos token appended.
-
- Returns:
- `dict[str, list[int]]`:
- Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and
- `"rejected_input_ids".
-
- Example:
- ```python
- >>> from transformers import GPT2Tokenizer
-
- >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
- >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
- >>> DPOTrainer.tokenize_row(
- ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False
- ... )
- {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
- ```
- """
- tokenizer = processing_class # the processing class is a tokenizer
- prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
- chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
- rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]
-
- # Add special tokens (typically for encoder-decoder models)
- if add_special_tokens:
- if tokenizer.bos_token_id is not None:
- prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
- if tokenizer.eos_token_id is not None:
- prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
- chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
- rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
-
- # Truncate prompt and completion sequences
- if max_prompt_length is not None:
- prompt_input_ids = prompt_input_ids[-max_prompt_length:]
- if max_completion_length is not None:
- chosen_input_ids = chosen_input_ids[:max_completion_length]
- rejected_input_ids = rejected_input_ids[:max_completion_length]
-
- return {
- "prompt_input_ids": prompt_input_ids,
- "chosen_input_ids": chosen_input_ids,
- "rejected_input_ids": rejected_input_ids,
- }
-
- @staticmethod
- def process_row(
- features: dict[str, str],
- processing_class: PreTrainedTokenizerBase,
- max_prompt_length: Optional[int] = None,
- max_completion_length: Optional[int] = None,
- add_special_tokens: bool = True,
- ) -> dict[str, list[int]]:
- """
- Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information.
- """
- processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor
- processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False)
-
- prompt_input_ids = processed_features["input_ids"][0]
- pixel_values = processed_features["pixel_values"][0]
- chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
- rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]
-
- # Add special tokens (typically for encoder-decoder models)
- if add_special_tokens:
- if tokenizer.bos_token_id is not None:
- prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
- if tokenizer.eos_token_id is not None:
- prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
- chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
- rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
-
- # Truncate prompt and completion sequences
- if max_prompt_length is not None:
- prompt_input_ids = prompt_input_ids[-max_prompt_length:]
- if max_completion_length is not None:
- chosen_input_ids = chosen_input_ids[:max_completion_length]
- rejected_input_ids = rejected_input_ids[:max_completion_length]
-
- output = {
- "prompt_input_ids": prompt_input_ids,
- "pixel_values": pixel_values,
- "chosen_input_ids": chosen_input_ids,
- "rejected_input_ids": rejected_input_ids,
- }
-
- if "pixel_attention_mask" in processed_features:
- output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
- if "image_sizes" in processed_features:
- output["image_sizes"] = processed_features["image_sizes"][0]
- if "token_type_ids" in processed_features:
- output["token_type_ids"] = processed_features["token_type_ids"][0]
-
- return output
-
- def _set_signature_columns_if_needed(self):
- # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
- # By default, this method sets `self._signature_columns` to the model's expected inputs.
- # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
- # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override.
- if self._signature_columns is None:
- self._signature_columns = [
- "prompt_input_ids",
- "chosen_input_ids",
- "rejected_input_ids",
- "image_sizes",
- "token_type_ids",
- "ref_chosen_logps",
- "ref_rejected_logps",
- ]
-
- def get_train_dataloader(self) -> DataLoader:
- """
- Returns the training [`~torch.utils.data.DataLoader`].
-
- Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
- """
-
- if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
- batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size
- dataloader_params = {
- "batch_size": batch_size,
- "collate_fn": self.data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "shuffle": False,
- }
-
- # prepare dataloader
- data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
-
- ref_chosen_logps = []
- ref_rejected_logps = []
- for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
- ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch)
- ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics(
- (ref_chosen_logp, ref_rejected_logp)
- )
- ref_chosen_logps.append(ref_chosen_logp.cpu())
- ref_rejected_logps.append(ref_rejected_logp.cpu())
-
- # Unnecessary cache clearing to avoid OOM
- empty_cache()
- self.accelerator.free_memory()
-
- all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy()
- all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy()
-
- self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps)
- self.train_dataset = self.train_dataset.add_column(
- name="ref_rejected_logps", column=all_ref_rejected_logps
- )
-
- self._precomputed_train_ref_log_probs = True
-
- return super().get_train_dataloader()
-
- def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
- """
- Returns the evaluation [`~torch.utils.data.DataLoader`].
-
- Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
-
- Args:
- eval_dataset (`torch.utils.data.Dataset`, *optional*):
- If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
- by the `model.forward()` method are automatically removed. It must implement `__len__`.
- """
- if eval_dataset is None and self.eval_dataset is None:
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
- eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
- if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
- batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size
- dataloader_params = {
- "batch_size": batch_size,
- "collate_fn": self.data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "shuffle": False,
- }
-
- # prepare dataloader
- data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
-
- ref_chosen_logps = []
- ref_rejected_logps = []
- for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
- ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch)
- ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics(
- (ref_chosen_logp, ref_rejected_logp)
- )
- ref_chosen_logps.append(ref_chosen_logp.cpu())
- ref_rejected_logps.append(ref_rejected_logp.cpu())
-
- all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy()
- all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy()
-
- eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps)
- eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps)
-
- # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs
- if self.eval_dataset is not None:
- self.eval_dataset = eval_dataset
- self._precomputed_eval_ref_log_probs = True
-
- return super().get_eval_dataloader(eval_dataset=eval_dataset)
-
- @contextmanager
- def null_ref_context(self):
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
- with (
- self.accelerator.unwrap_model(self.model).disable_adapter()
- if self.is_peft_model and not self.ref_adapter_name
- else nullcontext()
- ):
- if self.ref_adapter_name:
- self.model.set_adapter(self.ref_adapter_name)
- yield
- if self.ref_adapter_name:
- self.model.set_adapter(self.model_adapter_name or "default")
-
- def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
- compte_ref_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
- with torch.no_grad(), compte_ref_context_manager:
- if self.ref_model is None:
- with self.null_ref_context():
- ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
- else:
- ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
- return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]
-
- @staticmethod
- def concatenated_inputs(
- batch: dict[str, Union[list, torch.LongTensor]], padding_value: int
- ) -> dict[str, torch.LongTensor]:
- """
- Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and
- completion sequences.
-
- Args:
- batch (`dict[str, Union[list, torch.LongTensor]]`):
- A batch of input data. The batch must contain the following keys:
-
- - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input
- IDs.
- - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen
- completion input IDs.
- - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected
- completion input IDs.
- - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available.
- - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available.
-
- padding_value (`int`):
- The padding value to use for the concatenated completion sequences (`chosen_input_ids` and
- `rejected_input_ids`).
-
- Returns:
- `dict[str, torch.LongTensor]`: A dictionary containing:
-
- - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`.
- - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 *
- batch_size, max_completion_length)`.
- - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size,
- prompt_length)`.
- - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 *
- batch_size, max_completion_length)`.
- - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present.
- - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if
- `"prompt_pixel_attention_mask"` are present.
-
- Notes:
- The completion input IDs and attention masks are padded to the maximum completion length of the chosen or
- rejected sequences.
- """
- output = {}
-
- # For the prompt, the input_ids are the same for both the chosen and rejected responses
- output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0)
- output["prompt_attention_mask"] = torch.cat(
- [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0
- )
- if "pixel_values" in batch:
- output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0)
-
- if "pixel_attention_mask" in batch:
- output["pixel_attention_mask"] = torch.cat(
- [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0
- )
- if "image_sizes" in batch:
- output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
- if "token_type_ids" in batch:
- output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
-
- # Concatenate the chosen and rejected completions
- max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
- output["completion_input_ids"] = torch.cat(
- (
- pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value),
- pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value),
- ),
- )
- output["completion_attention_mask"] = torch.cat(
- (
- pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0),
- pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0),
- ),
- )
-
- return output
-
- def dpo_loss(
- self,
- chosen_logps: torch.FloatTensor,
- rejected_logps: torch.FloatTensor,
- ref_chosen_logps: torch.FloatTensor,
- ref_rejected_logps: torch.FloatTensor,
- loss_type: str = "sigmoid",
- model_output: dict[str, torch.FloatTensor] = None,
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- """
- Compute the DPO loss for a batch of policy and reference model log probabilities.
-
- Args:
- chosen_logps (`torch.FloatTensor`):
- Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`.
- rejected_logps (`torch.FloatTensor`):
- Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`.
- ref_chosen_logps (`torch.FloatTensor`):
- Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`.
- ref_rejected_logps (`torch.FloatTensor`):
- Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`.
- loss_type (`str`, defaults to `"sigmoid"`):
- The type of loss to compute. One of:
- - `"sigmoid"`: Sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
- - `"hinge"`: Hinge loss on the normalized likelihood from the
- [SLiC](https://huggingface.co/papers/2305.10425) paper.
- - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
- - `"exo_pair"`: Pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper.
- - `"nca_pair"`: Pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper.
- - `"robust"`: Unbiased estimate of the DPO loss that is robust to preference noise from the [Robust
- DPO](https://huggingface.co/papers/2403.00409) paper.
- - `"bco_pair"`: Pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper.
- - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675)
- paper.
- - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
- - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882)
- paper.
- - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the
- [DiscoPOP](https://huggingface.co/papers/2406.08414) paper.
- - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss).
- model_output (`dict[str, torch.FloatTensor]`, *optional*):
- The output of the model's forward pass. This is used to compute auxiliary losses if enabled.
-
- Returns:
- A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO
- loss for each example in the batch. The `chosen_rewards` and `rejected_rewards` tensors contain the rewards
- for the chosen and rejected responses, respectively.
- """
- device = self.accelerator.device
-
- # Get the log ratios for the chosen and rejected responses
- chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
- rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
-
- if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE:
- # The alpha-divergence formula: (1 - u^-alpha) / alpha
- # The divergence difference between the chosen and rejected sample is:
- # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
- # = (u[l]^-alpha - u[w]^-alpha) / alpha
- # where u[w] and u[l] are the policy/reference probability ratios
- # for the chosen and rejected samples, respectively.
- alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
- if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
- alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
- logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
- else:
- logratios = chosen_logps - rejected_logps
- if self.reference_free:
- ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
- else:
- ref_logratios = ref_chosen_logps - ref_rejected_logps
-
- logratios = logratios.to(self.accelerator.device)
- ref_logratios = ref_logratios.to(self.accelerator.device)
- logits = logratios - ref_logratios
-
- if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE:
- # The js-divergence formula: log(2 * u / (1 + u))
- # The divergence difference between the chosen and rejected sample is:
- # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
- # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
- # where u[w] and u[l] are the policy/reference probability ratios
- # for the chosen and rejected samples, respectively.
- logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
-
- # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
- # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the
- # labels and calculates a conservative DPO loss.
- if loss_type == "sigmoid":
- losses = (
- -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- - F.logsigmoid(-self.beta * logits) * self.label_smoothing
- )
-
- elif loss_type == "robust":
- losses = (
- -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- + F.logsigmoid(-self.beta * logits) * self.label_smoothing
- ) / (1 - 2 * self.label_smoothing)
-
- elif loss_type == "exo_pair":
- # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856
- import math
-
- if self.label_smoothing == 0:
- self.label_smoothing = 1e-3
- losses = (self.beta * logits).sigmoid() * (
- F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing)
- ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing))
-
- elif loss_type == "hinge":
- losses = torch.relu(1 - self.beta * logits)
-
- elif loss_type == "ipo":
- # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
- losses = (logits - 1 / (2 * self.beta)) ** 2
-
- elif loss_type == "bco_pair":
- chosen_logratios = chosen_logps - ref_chosen_logps
- rejected_logratios = rejected_logps - ref_rejected_logps
- chosen_rewards = self.beta * chosen_logratios
- rejected_rewards = self.beta * rejected_logratios
- rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
- self.running.update(rewards)
- delta = self.running.mean
- losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
- -(self.beta * rejected_logratios - delta)
- )
-
- elif loss_type == "sppo_hard":
- # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
- # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
- # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
- # set to 1 for the winner and 0 for the loser.
- a = chosen_logps - ref_chosen_logps
- b = rejected_logps - ref_rejected_logps
- losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2
-
- elif loss_type == "nca_pair":
- chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta
- rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta
- losses = (
- -F.logsigmoid(chosen_rewards)
- - 0.5 * F.logsigmoid(-chosen_rewards)
- - 0.5 * F.logsigmoid(-rejected_rewards)
- )
-
- elif loss_type == "aot_pair":
- chosen_logratios = chosen_logps - ref_chosen_logps
- rejected_logratios = rejected_logps - ref_rejected_logps
- chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0)
- rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0)
- delta = chosen_logratios_sorted - rejected_logratios_sorted
- losses = (
- -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
- - F.logsigmoid(-self.beta * delta) * self.label_smoothing
- )
-
- elif loss_type == "aot":
- logratios = chosen_logps - rejected_logps
- ref_logratios = ref_chosen_logps - ref_rejected_logps
- logratios_sorted, _ = torch.sort(logratios, dim=0)
- ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0)
- delta = logratios_sorted - ref_logratios_sorted
- losses = (
- -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
- - F.logsigmoid(-self.beta * delta) * self.label_smoothing
- )
-
- elif loss_type == "apo_zero":
- # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
- # Use this loss when you believe the chosen outputs are better than your model's default output
- losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood
- losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood
- losses = losses_chosen + losses_rejected
-
- elif loss_type == "apo_down":
- # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
- # Use this loss when you believe the chosen outputs are worse than your model's default output.
- # Decrease chosen likelihood and decrease rejected likelihood more
- losses_chosen = F.sigmoid(self.beta * chosen_logratios)
- losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios))
- losses = losses_chosen + losses_rejected
-
- elif loss_type == "discopop":
- # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414)
- # This loss was discovered with LLM discovery
- logratios = chosen_logps - rejected_logps
- ref_logratios = ref_chosen_logps - ref_rejected_logps
- logits = logratios - ref_logratios
- logits = logits * self.beta
- # Modulate the mixing coefficient based on the log ratio magnitudes
- log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau)
- logistic_component = -F.logsigmoid(logits)
- exp_component = torch.exp(-logits)
- # Blend between logistic and exponential component based on log ratio modulation
- losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation
-
- elif loss_type == "sft":
- # SFT loss is the negative log likelihood loss on chosen responses
- # This acts as the generation loss component in MPO
- sft_loss = model_output["nll_loss"]
- # Create losses tensor with same shape as other losses (per-sample)
- batch_size = chosen_logps.shape[0]
- losses = sft_loss.expand(batch_size)
- # For SFT, we don't have preference rewards, so use zeros
- chosen_rewards = torch.zeros_like(chosen_logps)
- rejected_rewards = torch.zeros_like(rejected_logps)
-
- else:
- raise ValueError(
- f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', "
- "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', "
- "'apo_down', 'sft']"
- )
-
- chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
- rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
-
- return losses, chosen_rewards, rejected_rewards
-
- def _compute_loss_liger(
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
- ) -> dict[str, torch.Tensor]:
- unwrapped_model = self.accelerator.unwrap_model(model)
- concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id)
-
- model_kwargs = {}
- if self.aux_loss_enabled:
- model_kwargs["output_router_logits"] = True
-
- # Add the pixel values and attention masks for vision models
- if "pixel_values" in concatenated_batch:
- model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
- if "pixel_attention_mask" in concatenated_batch:
- model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
- if "image_sizes" in concatenated_batch:
- model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
-
- prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
- completion_attention_mask = concatenated_batch["completion_attention_mask"]
-
- if self.is_encoder_decoder:
- # 1. Get encoder outputs
- encoder_outputs = unwrapped_model.get_encoder()(
- concatenated_batch["prompt_input_ids"],
- attention_mask=concatenated_batch["prompt_attention_mask"],
- return_dict=True,
- )
- # 2. Prepare decoder inputs
- decoder_input_ids = shift_tokens_right(
- concatenated_batch["completion_input_ids"],
- unwrapped_model.config.decoder_start_token_id,
- )
- # 3. Get decoder outputs
- decoder_outputs = unwrapped_model.get_decoder()(
- input_ids=decoder_input_ids,
- attention_mask=concatenated_batch["completion_attention_mask"],
- encoder_hidden_states=encoder_outputs.last_hidden_state,
- encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
- use_cache=False,
- )
- hidden_states = decoder_outputs.last_hidden_state
-
- ref_hidden_states = None
- if not self.reference_free and self.ref_model is not None:
- unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
- ref_encoder_outputs = unwrapped_ref_model.get_encoder()(
- concatenated_batch["prompt_input_ids"],
- attention_mask=concatenated_batch["prompt_attention_mask"],
- return_dict=True,
- )
- ref_decoder_outputs = unwrapped_ref_model.get_decoder()(
- input_ids=decoder_input_ids,
- attention_mask=concatenated_batch["completion_attention_mask"],
- encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
- encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
- use_cache=False,
- )
- ref_hidden_states = ref_decoder_outputs.last_hidden_state
- elif not self.reference_free:
- with self.null_ref_context():
- ref_encoder_outputs = unwrapped_model.get_encoder()(
- concatenated_batch["prompt_input_ids"],
- attention_mask=concatenated_batch["prompt_attention_mask"],
- return_dict=True,
- )
- ref_decoder_outputs = unwrapped_model.get_decoder()(
- input_ids=decoder_input_ids,
- attention_mask=concatenated_batch["completion_attention_mask"],
- encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
- encoder_attention_mask=concatenated_batch["prompt_attention_mask"],
- use_cache=False,
- )
- ref_hidden_states = ref_decoder_outputs.last_hidden_state
-
- labels = concatenated_batch["completion_input_ids"]
- loss_mask = completion_attention_mask.bool()
- else:
- # For decoder-only models
- input_ids = torch.cat(
- (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1
- )
- attention_mask = torch.cat(
- (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]),
- dim=1,
- )
- # Mask the prompt but not the completion for the loss
- loss_mask = torch.cat(
- (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
- dim=1,
- )
-
- # Flush and truncate
- if self.max_length is not None and self.max_length < attention_mask.size(1):
- if self.truncation_mode == "keep_start":
- # Flush left to reduce the memory usage
- # [[0, 0, x, x, x, x], -> [[x, x, x, x],
- # [0, x, x, x, 0, 0]] [x, x, x, 0]]
- attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
- attention_mask = attention_mask[:, : self.max_length]
- input_ids = input_ids[:, : self.max_length]
- loss_mask = loss_mask[:, : self.max_length]
- elif self.truncation_mode == "keep_end":
- # Flush right before truncating left, then flush left
- # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
- # [0, x, x, x, 0, 0]] [0, x, x, x]]
- attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
- input_ids = input_ids[:, -self.max_length :]
- attention_mask = attention_mask[:, -self.max_length :]
- loss_mask = loss_mask[:, -self.max_length :]
- attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
- else:
- raise ValueError(
- f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
- "'keep_start']."
- )
- else:
- # Flush left to reduce the memory usage
- # [[0, 0, x, x, x, x], -> [[x, x, x, x],
- # [0, x, x, x, 0, 0]] [x, x, x, 0]]
- attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
-
- # Add logits_to_keep optimization
- if self.use_logits_to_keep:
- first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
- logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1
- model_kwargs["logits_to_keep"] = logits_to_keep
-
- model_kwargs["output_hidden_states"] = True
-
- # Add padding-free training support
- if self.padding_free:
- input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
- loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
- position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
- model_kwargs["position_ids"] = position_ids
- else:
- model_kwargs["attention_mask"] = attention_mask
-
- # Get the base model outputs (before LM head)
- if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None:
- base_model = unwrapped_model.get_decoder()
- else:
- base_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name)
- base_model = getattr(unwrapped_model, base_attr, unwrapped_model)
-
- outputs = base_model(
- input_ids,
- use_cache=False,
- **model_kwargs,
- )
- hidden_states = outputs.last_hidden_state[:, :-1]
-
- # Get reference hidden states if needed
- ref_hidden_states = None
- if not self.reference_free and self.ref_model is not None:
- unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
- if hasattr(unwrapped_ref_model, "get_decoder") and unwrapped_ref_model.get_decoder() is not None:
- ref_base_model = unwrapped_ref_model.get_decoder()
- else:
- ref_attr = getattr(unwrapped_ref_model, "base_model_prefix", self.args.base_model_attribute_name)
- ref_base_model = getattr(unwrapped_ref_model, ref_attr, unwrapped_ref_model)
-
- ref_outputs = ref_base_model(
- input_ids,
- use_cache=False,
- **model_kwargs,
- )
- ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
- elif not self.reference_free:
- if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None:
- ref_base_model = unwrapped_model.get_decoder()
- else:
- ref_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name)
- ref_base_model = getattr(unwrapped_model, ref_attr, unwrapped_model)
- with self.null_ref_context():
- ref_outputs = ref_base_model(
- input_ids,
- use_cache=False,
- **model_kwargs,
- )
- ref_hidden_states = ref_outputs.last_hidden_state[:, :-1]
-
- masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id)
- labels = masked_input_ids[:, 1:] # Shift right for casual LM
-
- # Get the LM head
- lm_head = unwrapped_model.get_output_embeddings()
-
- # Get reference model weights if needed
- ref_weight = None
- ref_bias = None
- if not self.reference_free:
- if self.ref_model is not None:
- unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model)
- ref_lm_head = unwrapped_ref_model.get_output_embeddings()
- else:
- with self.null_ref_context():
- ref_lm_head = unwrapped_model.get_output_embeddings()
- ref_weight = ref_lm_head.weight
- ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
-
- # Compute loss using Liger kernel
- loss_output = self.dpo_loss_fn(
- lm_head.weight,
- hidden_states,
- labels,
- bias=lm_head.bias if hasattr(lm_head, "bias") else None,
- ref_input=ref_hidden_states if not self.reference_free else None,
- ref_weight=ref_weight if not self.reference_free else None,
- ref_bias=ref_bias if not self.reference_free else None,
- )
- (
- loss,
- (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs),
- ) = loss_output
-
- output = {
- "loss": loss,
- "chosen_logps": chosen_logps,
- "rejected_logps": rejected_logps,
- "mean_chosen_logits": chosen_logits_mean,
- "mean_rejected_logits": rejected_logits_mean,
- "nll_loss": nll_loss,
- "chosen_rewards": aux_outputs[0],
- "rejected_rewards": aux_outputs[1],
- }
- if self.aux_loss_enabled:
- output["aux_loss"] = outputs.aux_loss
-
- return output
-
- def concatenated_forward(
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
- ) -> dict[str, torch.Tensor]:
- """
- Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
-
- We do this to avoid doing two forward passes, because it's faster for FSDP.
-
- Args:
- model:
- Model to run the forward pass on.
- batch:
- Batch of input data.
- is_ref_model:
- Whether this method is being called for the reference model. If `True`, length desensitization is not
- applied.
- """
- num_examples = batch["prompt_input_ids"].shape[0]
-
- concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id)
-
- model_kwargs = {"use_cache": False}
- if self.aux_loss_enabled:
- model_kwargs["output_router_logits"] = True
-
- # Add the pixel values and attention masks for vision models
- if "pixel_values" in concatenated_batch:
- model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
- if "pixel_attention_mask" in concatenated_batch:
- model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
- if "image_sizes" in concatenated_batch:
- model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
-
- prompt_input_ids = concatenated_batch["prompt_input_ids"]
- prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
- completion_input_ids = concatenated_batch["completion_input_ids"]
- completion_attention_mask = concatenated_batch["completion_attention_mask"]
- if self.is_encoder_decoder:
- labels = completion_input_ids
- labels[completion_attention_mask == 0] = self.label_pad_token_id
- outputs = model(
- input_ids=prompt_input_ids,
- attention_mask=prompt_attention_mask,
- labels=labels, # we need the labels for the logits to be returned
- **model_kwargs,
- )
- logits = outputs.logits
- loss_mask = completion_attention_mask.bool()
- else:
- # Concatenate the prompt and completion inputs
- input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
- attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
- if "token_type_ids" in concatenated_batch:
- prompt_token_type_ids = concatenated_batch["token_type_ids"]
- token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
- # Mask the prompt but not the completion for the loss
- loss_mask = torch.cat(
- (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
- dim=1,
- )
-
- # Flush and truncate
- if self.max_length is not None and self.max_length < attention_mask.size(1):
- if self.truncation_mode == "keep_start":
- # Flush left to reduce the memory usage
- # [[0, 0, x, x, x, x], -> [[x, x, x, x],
- # [0, x, x, x, 0, 0]] [x, x, x, 0]]
- if "token_type_ids" in concatenated_batch:
- attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
- attention_mask, input_ids, loss_mask, token_type_ids
- )
- else:
- attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
- attention_mask = attention_mask[:, : self.max_length]
- input_ids = input_ids[:, : self.max_length]
- loss_mask = loss_mask[:, : self.max_length]
- elif self.truncation_mode == "keep_end":
- # Flush right before truncating left, then flush left
- # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
- # [0, x, x, x, 0, 0]] [0, x, x, x]]
- if "token_type_ids" in concatenated_batch:
- attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
- attention_mask, input_ids, loss_mask, token_type_ids
- )
- token_type_ids = token_type_ids[:, -self.max_length :]
- else:
- attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
- input_ids = input_ids[:, -self.max_length :]
- attention_mask = attention_mask[:, -self.max_length :]
- loss_mask = loss_mask[:, -self.max_length :]
- if "token_type_ids" in concatenated_batch:
- attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
- attention_mask, input_ids, loss_mask, token_type_ids
- )
- else:
- attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
- else:
- raise ValueError(
- f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
- "'keep_start']."
- )
- else:
- # Flush left to reduce the memory usage
- # [[0, 0, x, x, x, x], -> [[x, x, x, x],
- # [0, x, x, x, 0, 0]] [x, x, x, 0]]
- if "token_type_ids" in concatenated_batch:
- attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
- attention_mask, input_ids, loss_mask, token_type_ids
- )
- else:
- attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
-
- if "token_type_ids" in concatenated_batch:
- model_kwargs["token_type_ids"] = token_type_ids
-
- if self.use_logits_to_keep:
- # Compute logits_to_keep based on loss_mask pattern:
- # [[0, 0, 0, x, x, x, x],
- # [0, 0, 0, x, x, x, 0]]
- # ^ start computing logits from here ([:, -(7-3+1):])
- first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
- logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label
- model_kwargs["logits_to_keep"] = logits_to_keep
-
- model_kwargs["output_hidden_states"] = True
-
- if self.padding_free:
- # Flatten the input_ids, position_ids, and loss_mask
- # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]]
- # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]]
- input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
- loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
- position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
- model_kwargs["position_ids"] = position_ids
- else:
- model_kwargs["attention_mask"] = attention_mask
-
- outputs = model(input_ids, **model_kwargs)
- logits = outputs.logits
-
- # Offset the logits by one to align with the labels
- labels = torch.roll(input_ids, shifts=-1, dims=1)
- loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()
-
- if self.use_logits_to_keep:
- # Align labels with logits
- # logits: -, -, [x2, x3, x4, x5, x6]
- # ^ --------- ^ after logits[:, :-1, :]
- # labels: [y0, y1, y2, y3, y4, y5, y6]
- # ^ --------- ^ with logits_to_keep=4, [:, -4:]
- # loss_mask: [0, 0, 0, 1, 1, 1, 1]
- labels = labels[:, -logits_to_keep:]
- loss_mask = loss_mask[:, -logits_to_keep:]
-
- if logits.shape[:2] != labels.shape[:2]:
- # for LLaVA, the returned logits include the image tokens (placed before the text tokens)
- seq_len = labels.shape[1]
- logits = logits[:, -seq_len:]
-
- # Compute the log probabilities of the labels
- labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
- per_token_logps = selective_log_softmax(logits, labels)
- per_token_logps[~loss_mask] = 0
- per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)
-
- if self.padding_free:
- # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len])
- batch_size, seq_len = attention_mask.shape
- per_token_logps_ = torch.zeros(
- batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype
- )
- per_token_logps_[attention_mask.bool()] = per_token_logps
- per_token_logps = per_token_logps_
-
- all_logps = per_token_logps[:, 1:].sum(-1)
-
- output = {}
-
- if self.use_weighting:
- with torch.no_grad():
- # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
- logprobs = F.log_softmax(logits, dim=-1)
- weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space
- per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
- all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)
- chosen_weights = all_weights[:num_examples]
- rejected_weights = all_weights[num_examples:]
- output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)
-
- if self.args.rpo_alpha is not None or "sft" in self.loss_type:
- # Only use the chosen logits for the RPO loss or SFT loss
- chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples]
- chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples]
-
- # Compute the log probabilities of the labels
- output["nll_loss"] = F.cross_entropy(
- torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0
- )
-
- if "ipo" in self.loss_type:
- all_logps = all_logps / loss_mask.sum(-1)
-
- if self.args.ld_alpha is not None and not is_ref_model:
- # Compute response lengths based on loss_mask
- completion_lengths = loss_mask.sum(dim=1)
-
- chosen_lengths = completion_lengths[:num_examples]
- rejected_lengths = completion_lengths[num_examples:]
- public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper
- public_lengths = torch.cat([public_lengths, public_lengths], dim=0)
-
- seq_len = per_token_logps.size(1)
- position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
-
- ld_mask = position_ids < public_lengths.unsqueeze(1)
- mask = position_ids < completion_lengths.unsqueeze(1)
-
- front_mask = (ld_mask & mask).float()
- rear_mask = (~ld_mask & mask).float()
- front_logps = (per_token_logps * front_mask).sum(dim=1)
- rear_logps = (per_token_logps * rear_mask).sum(dim=1)
-
- all_logps = front_logps + self.args.ld_alpha * rear_logps
-
- output["chosen_logps"] = all_logps[:num_examples]
- output["rejected_logps"] = all_logps[num_examples:]
-
- # Compute the mean logits
- if self.padding_free:
- # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]).
- # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens,
- # and the second half to the rejected tokens.
- # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id.
- split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples]
- mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean()
- mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean()
- else:
- mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean()
- mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean()
-
- output["mean_chosen_logits"] = mean_chosen_logits
- output["mean_rejected_logits"] = mean_rejected_logits
-
- if self.aux_loss_enabled:
- output["aux_loss"] = outputs.aux_loss
-
- return output
-
- def get_batch_loss_metrics(
- self,
- model: Union[PreTrainedModel, nn.Module],
- batch: dict[str, Union[list, torch.LongTensor]],
- train_eval: Literal["train", "eval"] = "train",
- ) -> tuple[torch.Tensor, dict[str, float]]:
- """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
- metrics = {}
-
- if self.args.use_liger_loss:
- model_output = self._compute_loss_liger(model, batch)
- losses = model_output["loss"]
- chosen_rewards = model_output["chosen_rewards"]
- rejected_rewards = model_output["rejected_rewards"]
- else:
- model_output = self.concatenated_forward(model, batch)
-
- # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
- if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
- ref_chosen_logps = batch["ref_chosen_logps"]
- ref_rejected_logps = batch["ref_rejected_logps"]
- else:
- ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
-
- # Initialize combined losses
- losses = 0
- chosen_rewards = 0
- rejected_rewards = 0
-
- # Compute losses for each loss type
- for idx, loss_type in enumerate(self.loss_type):
- # Compute individual loss using standard DPO loss function
- _losses, _chosen_rewards, _rejected_rewards = self.dpo_loss(
- model_output["chosen_logps"],
- model_output["rejected_logps"],
- ref_chosen_logps,
- ref_rejected_logps,
- loss_type,
- model_output,
- )
-
- # Add weighted contributions
- weight = self.loss_weights[idx] if self.loss_weights else 1.0
- losses = losses + _losses * weight
- chosen_rewards = chosen_rewards + _chosen_rewards * weight
- rejected_rewards = rejected_rewards + _rejected_rewards * weight
-
- reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
- if self.args.rpo_alpha is not None:
- losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper
-
- if self.use_weighting:
- losses = losses * model_output["policy_weights"]
-
- if self.aux_loss_enabled:
- losses = losses + self.aux_loss_coef * model_output["aux_loss"]
-
- prefix = "eval_" if train_eval == "eval" else ""
- metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
- metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
- metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
- metrics[f"{prefix}rewards/margins"] = (
- self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
- )
- metrics[f"{prefix}logps/chosen"] = (
- self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
- )
- metrics[f"{prefix}logps/rejected"] = (
- self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
- )
- metrics[f"{prefix}logits/chosen"] = (
- self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
- )
- metrics[f"{prefix}logits/rejected"] = (
- self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
- )
- if self.args.rpo_alpha is not None or "sft" in self.loss_type:
- metrics[f"{prefix}nll_loss"] = (
- self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
- )
- if self.aux_loss_enabled:
- metrics[f"{prefix}aux_loss"] = (
- self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
- )
-
- return losses.mean(), metrics
-
- def compute_loss(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- return_outputs=False,
- num_items_in_batch=None,
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
- compute_loss_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
- with compute_loss_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
-
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
- loss = loss.to(self.args.device)
- # force log the metrics
- self.store_metrics(metrics, train_eval="train")
-
- if return_outputs:
- return loss, metrics
-
- return loss
-
- def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
- """Generate samples from the model and reference model for the given batch of inputs."""
-
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
- # the torch amp context manager as some hidden states are silently casted to full precision.
- generate_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with generate_context_manager:
- policy_output = model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.pad_token_id,
- )
-
- # if ref_output in batch use that otherwise use the reference model
- if "ref_output" in batch:
- ref_output = batch["ref_output"]
- else:
- if self.ref_model is None:
- with self.null_ref_context():
- ref_output = self.model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.pad_token_id,
- )
- else:
- ref_output = self.ref_model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.pad_token_id,
- )
-
- policy_output = pad_to_length(policy_output, self.max_length, self.pad_token_id)
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
-
- ref_output = pad_to_length(ref_output, self.max_length, self.pad_token_id)
- ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True)
-
- return policy_output_decoded, ref_output_decoded
-
- def prediction_step(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- prediction_loss_only: bool,
- ignore_keys: Optional[list[str]] = None,
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
- if ignore_keys is None:
- if hasattr(model, "config"):
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
- else:
- ignore_keys = []
-
- prediction_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with torch.no_grad(), prediction_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
-
- # force log the metrics
- self.store_metrics(metrics, train_eval="eval")
-
- if prediction_loss_only:
- return loss.detach(), None, None
-
- # logits for the chosen and rejected samples from model
- logits_dict = {
- "eval_logits/chosen": metrics["eval_logits/chosen"],
- "eval_logits/rejected": metrics["eval_logits/rejected"],
- }
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
- logits = torch.tensor(logits, device=self.accelerator.device)
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
- return (loss.detach(), logits, labels)
-
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
- for key, value in metrics.items():
- self._stored_metrics[train_eval][key].append(value)
-
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: Optional[bool] = None,
- ignore_keys: Optional[list[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
- `Trainer.evaluate()` and `Trainer.predict()`.
-
- Works both with or without labels.
- """
-
- # Sample and save to game log if requested (for one batch to save time)
- if self.generate_during_eval:
- # Generate random indices within the range of the total number of samples
- num_samples = len(dataloader.dataset)
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
-
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
- random_batch_dataset = dataloader.dataset.select(random_indices)
- random_batch = self.data_collator(random_batch_dataset)
- random_batch = self._prepare_inputs(random_batch)
-
- policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch)
-
- table = pd.DataFrame(
- columns=["Prompt", "Policy", "Ref Model"],
- data=[
- [prompt, pol[len(prompt) :], ref[len(prompt) :]]
- for prompt, pol, ref in zip(
- random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded
- )
- ],
- )
- if "wandb" in self.args.report_to and self.accelerator.is_main_process:
- wandb.log({"game_log": wandb.Table(data=table)})
-
- if "comet_ml" in self.args.report_to:
- log_table_to_comet_experiment(
- name="game_log.csv",
- table=table,
- )
-
- if "mlflow" in self.args.report_to and self.accelerator.is_main_process:
- mlflow.log_table(data=table, artifact_file="game_log.json")
-
- # Base evaluation
- initial_output = super().evaluation_loop(
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
- )
-
- return initial_output
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- """
- Log `logs` on the various objects watching training, including stored metrics.
-
- Args:
- logs (`dict[str, float]`):
- The values to log.
- start_time (`float`, *optional*):
- Start time of the training.
- """
- # logs either has 'loss' or 'eval_loss'
- train_eval = "train" if "loss" in logs else "eval"
- # Add averaged stored metrics to logs
- for key, metrics in self._stored_metrics[train_eval].items():
- logs[key] = torch.tensor(metrics).mean().item()
- del self._stored_metrics[train_eval]
- return super().log(logs, start_time)
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothDPOTrainer(_UnslothDPOTrainer):
- """
-
- Trainer for Direct Preference Optimization (DPO) method.
-
- This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
-
- Args:
- model (`Union[str, PreTrainedModel]`):
- Model to be trained. Can be either:
-
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
- `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- ref_model ([`PreTrainedModelWrapper`]):
- Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
- and loss. If no reference model is provided, the trainer will create a reference model with the same
- architecture as the model to be optimized.
- args ([`DPOConfig`], *optional*):
- Configuration for this trainer. If `None`, a default configuration is used.
- data_collator ([`~transformers.DataCollator`], *optional*):
- Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
- Will default to [`DataCollatorForPreference`].
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
- Dataset to use for training. DPO supports [preference](#preference) type and. The format of the samples can
- be either:
-
- - [Standard](dataset_formats#standard): Each sample contains plain text.
- - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
- and content).
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
- Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If `None`, the processing class is loaded from the model's name
- with [`~transformers.AutoTokenizer.from_pretrained`].
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
- a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to
- `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered
- after the last eval batch to signal that the function needs to calculate and return the global summary
- statistics rather than accumulating the batch-level statistics.
- callbacks (list of [`~transformers.TrainerCallback`], *optional*):
- List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
- in [here](https://huggingface.co/docs/transformers/main_classes/callback).
-
- If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
- method.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
- model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
- optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
- A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
- `args`. Incompatible with the `optimizers` argument.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
- A function that preprocess the logits right before caching them at each evaluation step. Must take two
- tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
- by this function will be reflected in the predictions received by `compute_metrics`.
-
- Note that the labels (second parameter) will be `None` if the dataset does not have them.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
-
- """
- def __init__(
- self,
- model,
- ref_model = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- compute_metrics = None,
- callbacks = None,
- optimizer_cls_and_kwargs = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- **kwargs
- ):
- if args is None: args = UnslothDPOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('dpo_trainer', other_metrics)
- if hasattr(train_dataset, 'column_names'):
- column_names = set(train_dataset.column_names)
- check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',
- 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',
- 'prompt_input_ids', 'prompt_attention_mask']
- if all(x in column_names for x in check):
- train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])
- del check, column_names
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- ref_model = ref_model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothGKDTrainer.py b/unsloth_compiled_cache/UnslothGKDTrainer.py
deleted file mode 100644
index 07d22ac..0000000
--- a/unsloth_compiled_cache/UnslothGKDTrainer.py
+++ /dev/null
@@ -1,1311 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, nn, os, prepare_deepspeed, torch, warnings)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothGKDConfig(GKDConfig):
- """
-
- Configuration class for [`GKDTrainer`].
-
- This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
-
- Args:
- temperature (`float`, *optional*, defaults to `0.9`):
- Temperature for sampling. The higher the temperature, the more random the completions.
- lmbda (`float`, *optional*, defaults to `0.5`):
- Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
- student-generated outputs).
- beta (`float`, *optional*, defaults to `0.5`):
- Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
- beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
- max_new_tokens (`int`, *optional*, defaults to `128`):
- Maximum number of tokens to generate per completion.
- teacher_model_name_or_path (`str`, *optional*):
- Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
- trained.
- teacher_model_init_kwargs (`dict[str, Any]]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
- from a string.
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model.
- seq_kd (`bool`, *optional*, defaults to `False`):
- Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
- teacher-generated output).
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- model_init_kwargs = None,
- chat_template_path = None,
- dataset_text_field = 'text',
- dataset_kwargs = None,
- dataset_num_proc = None,
- eos_token = None,
- pad_token = None,
- max_length = 1024,
- packing = False,
- packing_strategy = 'bfd',
- padding_free = False,
- pad_to_multiple_of = None,
- eval_packing = None,
- completion_only_loss = None,
- assistant_only_loss = False,
- loss_type = 'nll',
- activation_offloading = False,
- temperature = 0.9,
- lmbda = 0.5,
- beta = 0.5,
- max_new_tokens = 128,
- teacher_model_name_or_path = None,
- teacher_model_init_kwargs = None,
- disable_dropout = True,
- seq_kd = False,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
- from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
- if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
- from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
- pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
-
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- model_init_kwargs = model_init_kwargs,
- chat_template_path = chat_template_path,
- dataset_text_field = dataset_text_field,
- dataset_kwargs = dataset_kwargs,
- dataset_num_proc = dataset_num_proc,
- eos_token = eos_token,
- pad_token = pad_token,
- max_length = max_length,
- packing = packing,
- packing_strategy = packing_strategy,
- padding_free = padding_free,
- pad_to_multiple_of = pad_to_multiple_of,
- eval_packing = eval_packing,
- completion_only_loss = completion_only_loss,
- assistant_only_loss = assistant_only_loss,
- loss_type = loss_type,
- activation_offloading = activation_offloading,
- temperature = temperature,
- lmbda = lmbda,
- beta = beta,
- max_new_tokens = max_new_tokens,
- teacher_model_name_or_path = teacher_model_name_or_path,
- teacher_model_init_kwargs = teacher_model_init_kwargs,
- disable_dropout = disable_dropout,
- seq_kd = seq_kd,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothGKDTrainer(SFTTrainer):
- """"""
-
- _tag_names = ["trl", "gkd"]
- _name = "GKD"
- _paper = {
- "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
- "id": "2306.13649",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @inproceedings{agarwal2024on-policy,
- title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
- author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
- year = 2024,
- booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
- publisher = {OpenReview.net},
- url = {https://openreview.net/forum?id=3zKtaqxLhW},
- }"""),
- }
-
- def __init__(
- self,
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
- teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
- args: Optional[GKDConfig] = None,
- data_collator: Optional[DataCollator] = None, # type: ignore
- train_dataset: Optional[Dataset] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional["PeftConfig"] = None,
- formatting_func: Optional[Callable] = None,
- ):
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- # Ensure Trainer does not drop non-signature columns used by the collator [e.g., "prompts"]
- args.remove_unused_columns = False
- # Respect a user-provided data_collator; otherwise, provide a ChatML collator that
- if data_collator is None:
- data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
-
- # Ensure SFTTrainer does not pre-process the dataset when using a ChatML collator,
- # so that raw conversational fields [e.g., "messages"] remain available to the collator.
- if args.dataset_kwargs is None:
- args.dataset_kwargs = {"skip_prepare_dataset": True}
- else:
- args.dataset_kwargs["skip_prepare_dataset"] = True
-
- # Liger fused GKD loss [JSD]
- self.use_liger_gkd_loss = False
- if args.use_liger_kernel:
- self.liger_jsd_loss = LigerFusedLinearJSDLoss(
- beta=args.beta,
- ignore_index=-100,
- temperature=args.temperature,
- compiled=False,
- )
- self.use_liger_gkd_loss = True
-
- super().__init__(
- model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- peft_config=peft_config,
- formatting_func=formatting_func,
- )
-
- if args.teacher_model_init_kwargs is None:
- teacher_model_init_kwargs = {}
- elif not isinstance(teacher_model, str):
- raise ValueError(
- "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
- )
- else:
- teacher_model_init_kwargs = args.teacher_model_init_kwargs
- teacher_model_init_kwargs["dtype"] = (
- teacher_model_init_kwargs["dtype"]
- if teacher_model_init_kwargs["dtype"] in ["auto", None]
- else getattr(torch, teacher_model_init_kwargs["dtype"])
- )
-
- if isinstance(teacher_model, str):
- teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
-
- # Disable dropout in the model
- if args.disable_dropout:
- disable_dropout_in_model(self.model)
-
- if self.is_deepspeed_enabled:
- self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
- else:
- self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
-
- self.lmbda = args.lmbda
- self.beta = args.beta
- self.temperature = args.temperature
- self.seq_kd = args.seq_kd
-
- self.generation_config = GenerationConfig(
- max_new_tokens=args.max_new_tokens,
- temperature=args.temperature,
- do_sample=True,
- top_k=0,
- use_cache=False if args.gradient_checkpointing else True,
- pad_token_id=self.processing_class.pad_token_id,
- )
- # Set custom EOS tokens if they are specified by the model's generation
- # config. This is important for models with the Llama 3 chat template,
- # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
- # turns or messages.
- if (
- hasattr(self.model.generation_config, "eos_token_id")
- and self.model.generation_config.eos_token_id is not None
- ):
- self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
-
- @staticmethod
- def generalized_jsd_loss(
- student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
- ):
- """
- Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
- of https://huggingface.co/papers/2306.13649 for the definition.
-
- Args:
- student_logits:
- Tensor of shape (batch_size, sequence_length, vocab_size)
- teacher_logits:
- Tensor of shape (batch_size, sequence_length, vocab_size)
- labels:
- Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing
- loss
- beta:
- Interpolation coefficient between 0 and 1 (default: 0.5)
- temperature:
- Softmax temperature (default: 1.0)
- reduction:
- Specifies the reduction to apply to the output (default: 'batchmean')
-
- Returns:
- loss: Scalar tensor with the generalized JSD loss
- """
-
- # Apply temperature scaling
- student_logits = student_logits / temperature
- teacher_logits = teacher_logits / temperature
-
- # Compute log probabilities for student and probabilities for teacher
- student_log_probs = F.log_softmax(student_logits, dim=-1)
- teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
-
- if beta == 0:
- jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
- elif beta == 1:
- jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
- else:
- # Compute the log of the mixture distribution
- # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
- beta = torch.tensor(beta, dtype=student_log_probs.dtype)
- mixture_log_probs = torch.logsumexp(
- torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
- dim=0,
- )
-
- # Compute KL divergences using F.kl_div
- # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
- kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
- kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
-
- # Compute the Generalized Jensen-Shannon Divergence
- jsd = beta * kl_teacher + (1 - beta) * kl_student
-
- # Masking
- if labels is not None:
- mask = labels != -100
- jsd = jsd[mask]
-
- # Apply reduction
- if reduction == "batchmean":
- return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0)
- elif reduction == "sum":
- return jsd.sum()
- elif reduction == "mean":
- return jsd.mean()
- else:
- return jsd
-
- def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
- if self.use_liger_gkd_loss:
- # Forward only through the base models (avoid lm_head to save memory)
- unwrapped_student = self.accelerator.unwrap_model(model)
- if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None:
- base_student = unwrapped_student.get_decoder()
- else:
- base_student = getattr(
- unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student
- )
-
- student_outputs = base_student(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- output_hidden_states=True,
- use_cache=False,
- )
-
- self.teacher_model.eval()
- unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model)
- if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None:
- base_teacher = unwrapped_teacher.get_decoder()
- else:
- base_teacher = getattr(
- unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher
- )
- with torch.no_grad():
- teacher_outputs = base_teacher(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- output_hidden_states=True,
- use_cache=False,
- )
-
- # hidden states (shifted)
- student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous()
- teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous()
-
- # labels mask and labels (shifted)
- labels_mask = inputs["labels"] != -100
- masked_input_ids = torch.where(
- labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100)
- )
- true_labels = masked_input_ids[:, 1:].contiguous()
-
- # heads
- student_head = unwrapped_student.get_output_embeddings()
- teacher_head = unwrapped_teacher.get_output_embeddings()
-
- # liger fused jsd loss
- loss = self.liger_jsd_loss(
- student_input=student_hidden,
- student_weight=student_head.weight,
- teacher_input=teacher_hidden,
- teacher_weight=teacher_head.weight,
- true_labels=true_labels,
- student_bias=getattr(student_head, "bias", None),
- teacher_bias=getattr(teacher_head, "bias", None),
- )
- else:
- # compute student output
- student_outputs = model(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- )
-
- # compute teacher output in eval mode
- self.teacher_model.eval()
- with torch.no_grad():
- teacher_outputs = self.teacher_model(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- )
-
- # slice the logits for the generated tokens using the inputs["prompts"] lengths
- prompt_lengths = inputs["prompts"].shape[1]
- shifted_student_logits = student_outputs.logits[:, prompt_lengths - 1 : -1, :]
- shifted_teacher_logits = teacher_outputs.logits[:, prompt_lengths - 1 : -1, :]
- shifted_labels = inputs["labels"][:, prompt_lengths:]
-
- # compute loss
- loss = self.generalized_jsd_loss(
- student_logits=shifted_student_logits,
- teacher_logits=shifted_teacher_logits,
- labels=shifted_labels,
- beta=self.beta,
- )
-
- # empty cache
- empty_cache()
-
- # Return loss
- return (loss, student_outputs) if return_outputs else loss
-
- @staticmethod
- def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
- # Generate output with respect to the prompt-only
- generated_outputs = model.generate(
- input_ids=inputs["prompts"],
- attention_mask=inputs.get("prompt_attention_mask", None),
- generation_config=generation_config,
- return_dict_in_generate=True,
- )
-
- # Get the generated token IDs
- generated_tokens = generated_outputs.sequences
- # Calculate new attention mask
- new_attention_mask = torch.ones_like(generated_tokens)
- new_labels = generated_tokens.clone()
-
- # If there's pad_token_id, set attention mask to 0 for padding tokens
- if pad_token_id is not None:
- new_labels[new_labels == pad_token_id] = -100
- new_attention_mask[generated_tokens == pad_token_id] = 0
-
- return generated_tokens, new_attention_mask, new_labels
-
- def training_step(
- self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
- ) -> torch.Tensor:
- """
- Perform a training step for the Generalized Knowledge Distillation (GKD) model.
-
- This method implements the on-policy learning approach described in the GKD paper. With probability
- `self.lmbda`, it generates new responses using the student model, which are then used for training instead of
- the original inputs.
- """
- if self.seq_kd:
- with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
- new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
- unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
- )
- inputs["input_ids"] = new_input_ids
- inputs["attention_mask"] = new_attention_mask
- inputs["labels"] = new_labels
- if random.random() <= self.lmbda:
- with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
- new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
- unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
- )
- inputs["input_ids"] = new_input_ids
- inputs["attention_mask"] = new_attention_mask
- inputs["labels"] = new_labels
-
- loss = super().training_step(model, inputs, num_items_in_batch)
- return loss
-class UnslothGKDTrainer(_UnslothGKDTrainer):
- """
- Trainer for Generalized Knowledge Distillation (GKD) of language models.
-
- For details on GKD, see the paper: [On-Policy Distillation of Language Models: Learning from Self-Generated
- Mistakes](https://huggingface.co/papers/2306.13649).
-
- Args:
- model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
- Model to be trained, or the string identifier of the model to be instantiated from a pretrained model.
- teacher_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*):
- Teacher model for knowledge distillation, or the string identifier of the model to be instantiated from a
- pretrained model.
- args ([`GKDConfig`], *optional*):
- Training arguments.
- data_collator ([`~transformers.DataCollator`], *optional*):
- Data collator to batch samples from the dataset. It defaults to a [`DataCollatorForChatML`] using the
- `processing_class`.
- train_dataset ([`~datasets.Dataset`], *optional*):
- Dataset for training.
- eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
- Dataset for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Class to process the data.
- compute_metrics (`Callable`, *optional*):
- Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a
- dictionary string to float.
- callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
- Callbacks to use during training.
- optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
- Tuple containing the optimizer and the learning rate scheduler to use for training.
- preprocess_logits_for_metrics (`Callable`, *optional*):
- Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and
- return the logits to be used for metrics computation.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
- wrapped with the specified PEFT adapter.
- formatting_func (`Callable`, *optional*):
- Function to format the dataset. Must take in an example and return an example.
-
- """
- def __init__(
- self,
- model = None,
- teacher_model = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- compute_metrics = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- formatting_func = None,
- **kwargs
- ):
- if args is None: args = UnslothGKDConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('gkd_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- teacher_model = teacher_model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,
- formatting_func = formatting_func,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
diff --git a/unsloth_compiled_cache/UnslothGRPOTrainer.py b/unsloth_compiled_cache/UnslothGRPOTrainer.py
deleted file mode 100644
index 380d394..0000000
--- a/unsloth_compiled_cache/UnslothGRPOTrainer.py
+++ /dev/null
@@ -1,4196 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.grpo_trainer import (Any, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, BaseTrainer, DataLoader, Dataset, FSDP, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RepeatSampler, RewardFunc, Sampler, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, _ForwardRedirection, apply_chat_template, broadcast_object_list, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, gather, gather_object, identity, inspect, is_conversational, is_datasets_available, is_flash_attn_2_available, is_liger_kernel_available, is_peft_model, is_rich_available, is_vllm_available, logger, logging, maybe_apply_chat_template, nanmax, nanmin, nanstd, nn, nullcontext, os, pad, partial, prepare_deepspeed, prepare_fsdp, prepare_multimodal_messages, print_prompt_completions_sample, profiling_context, profiling_decorator, seed_worker, selective_log_softmax, set_seed, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, textwrap, torch, transformers, unsplit_pixel_values_by_grid, unwrap_model_for_generation, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardFunc, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, identity, inspect, is_liger_kernel_available, is_peft_model, is_vllm_available, logger, nn, os, pad, prepare_deepspeed, prepare_fsdp, set_seed, torch, transformers, Any, Union, gather, gather_object, is_conversational, logging, nanmax, nanmin, nanstd, os, pad, torch, FSDP, Optional, apply_chat_template, broadcast_object_list, gather, gather_object, is_flash_attn_2_available, maybe_apply_chat_template, nullcontext, os, pad, prepare_multimodal_messages, profiling_context, torch, transformers, unwrap_model_for_generation, os, pad, selective_log_softmax, torch, transformers, Any, Union, profiling_decorator, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, torch, unsplit_pixel_values_by_grid, Optional, PreTrainedModel, logger, os, torch, FSDP, nn, os, FSDP, nn, torch, GRPOTrainer, gather, nanmax, nanmin, os, pad, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9,
- "cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9,
- "cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9,
- "cuda.compile_opt_level" : "-O2",
- "cuda.enable_cuda_lto" : True,
- }
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-def grpo_compute_loss(
- ref,
- new,
- old,
- sampling_per_token_logps,
- input_ids,
- mask,
- beta,
- advantages,
- **kwargs
-):
- # All Unsloth Zoo code licensed under AGPL3
- # Set defaults for optional arguments
- loss_type = kwargs.get("loss_type", "grpo")
- epsilon_low = kwargs.get("epsilon_low", 0.2)
- epsilon_high = kwargs.get("epsilon_high", 0.2)
- max_completion_length = kwargs.get("max_completion_length", 8192)
- delta = kwargs.get("delta", None)
- importance_sampling_level = kwargs.get("importance_sampling_level", "token")
- num_items_in_batch = kwargs.get("num_items_in_batch", None)
- current_gradient_accumulation_steps = kwargs.get("current_gradient_accumulation_steps", 1)
- num_processes = kwargs.get("num_processes", 1)
- use_vllm = kwargs.get("use_vllm", False)
- vllm_importance_sampling_cap = kwargs.get("vllm_importance_sampling_cap", 2.0)
- get_sapo_token_loss = kwargs.get("get_sapo_token_loss", None)
- sapo_temperature_pos = kwargs.get("sapo_temperature_pos", 1.0)
- sapo_temperature_neg = kwargs.get("sapo_temperature_neg", 1.05)
- get_off_policy_mask = kwargs.get("get_off_policy_mask", None)
- off_policy_mask_threshold = kwargs.get("off_policy_mask_threshold", None)
- input_ids = input_ids.unsqueeze(-1)
-
- if advantages.dim() == 1:
- advantages = advantages.unsqueeze(1)
-
- if off_policy_mask_threshold is not None:
- off_policy_mask = get_off_policy_mask(
- advantages=advantages,
- per_token_logps=new,
- old_per_token_logps=old,
- mask=mask,
- off_policy_threshold=off_policy_mask_threshold,
- )
-
- with torch.no_grad():
- if use_vllm and sampling_per_token_logps is not None:
- #must filter out extra prompt tokens in begining after making input_ids left padded
- importance_sampling_ratio = torch.exp((old * mask) - sampling_per_token_logps)
- importance_sampling_ratio = torch.clamp(
- importance_sampling_ratio, max=vllm_importance_sampling_cap
- )
- pass
-
- # Must detach - otherwise gradients are not propagated correctly!
- # exp(x - x) == 1
- # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
- if old is not None:
- log_ratio = new - old
- else:
- log_ratio = new - new.detach()
-
- if importance_sampling_level == "token":
- log_importance_weights = log_ratio
- elif importance_sampling_level == "sequence":
- log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
- log_importance_weights = log_importance_weights.unsqueeze(-1)
- else:
- raise ValueError(
- f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
- "and 'sequence'."
- )
-
- coef_1 = torch.exp(log_importance_weights)
-
- # Reverse KL
- # Note that this is a low variance low bias estimator for the KL divergence as used in GRPO paper
- if beta != 0.0:
- kl_i = torch.exp(ref - new) - (ref - new) - 1.0
-
- else:
- # set kl_i to a tensor of zeros with the correct shape
- if importance_sampling_level == "sequence":
- kl_i = new.new_zeros(new.size(0), 1)
- else:
- kl_i = torch.zeros_like(new)
- # Full correct reverse KL divergence?? Missing term maybe?
- # kl_i = torch.exp(new) * kl_i
-
- # Below is forward KL (normal KL)
- # kl_i = torch.exp(old) * (old - new)
- if loss_type == "cispo":
- clamped_ratios = torch.clamp(coef_1, max=epsilon_high).detach()
- loss_i = -clamped_ratios * advantages * new
- #breakpoint()
- elif loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
- coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
-
- if delta is not None:
- loss_1 = torch.clamp(coef_1, max=delta) * advantages
- else:
- loss_1 = coef_1 * advantages
- pass
- loss_2 = coef_2 * advantages
- loss_i = -torch.min(loss_1, loss_2)
- elif loss_type == "sapo":
- if get_sapo_token_loss is None:
- raise Exception(f"sapo is only available in TRL 0.26.0+")
- loss_i = torch.empty_like(coef_1)
- positive_advantages_mask = advantages.repeat([1, coef_1.shape[1]]) > 0
- #since we have n_chunks some tensors may error if they dont have elements in them
- if coef_1[positive_advantages_mask].numel() != 0:
- loss_i[positive_advantages_mask] = get_sapo_token_loss(
- coef_1[positive_advantages_mask], sapo_temperature_pos
- )
- if coef_1[~positive_advantages_mask].numel() != 0:
- loss_i[~positive_advantages_mask] = get_sapo_token_loss(
- coef_1[~positive_advantages_mask], sapo_temperature_neg
- )
- loss_i = -loss_i * advantages
- else:
- raise ValueError(f"Unknown loss type: {loss_type}")
-
- if off_policy_mask_threshold is not None:
- loss_i = loss_i * off_policy_mask
-
- if use_vllm and sampling_per_token_logps is not None:
- loss_i = loss_i * importance_sampling_ratio
- #delta for metric
- with torch.no_grad():
- delta = torch.abs(old - sampling_per_token_logps)
- delta = delta * mask
- flat_is_ratio = importance_sampling_ratio * mask
- else:
- delta = torch.tensor([]).detach()
- flat_is_ratio = torch.tensor([]).detach()
- if beta != 0.0:
- loss_i = loss_i + beta * kl_i
-
- mask = mask.to(torch.float32)
- n_mask_per_reward = mask.sum(1)
-
- # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624
- if loss_type in ["grpo", "sapo"]:
- loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
- loss = loss / current_gradient_accumulation_steps
- elif loss_type == "bnpo":
- loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0)
- loss = loss / current_gradient_accumulation_steps
- elif loss_type == "dr_grpo":
- loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length)
- loss = loss / current_gradient_accumulation_steps
- elif loss_type in ["cispo", "dapo"]:
- normalizer = num_items_in_batch/ num_processes
- loss = (loss_i * mask).sum() / normalizer
- else:
- raise ValueError(f"Unknown loss type: {loss_type}")
-
- # loss = (loss_i * mask).sum() / mask.sum()
-
- # Get metrics as well which are folded
- def masked_batch_mean(x):
- with torch.inference_mode():
- completion_length = n_mask_per_reward.mean()
- if x.shape[1] == 1: # when importance_sampling_level == "sequence"
- return completion_length, x.mean()
- else:
- mean_kl_per_reward = (x * mask).sum(1) / n_mask_per_reward
- mean_kl = mean_kl_per_reward.mean()
- return completion_length, mean_kl
- completion_length, mean_kl = masked_batch_mean(kl_i)
- return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1
-
-class UnslothEfficientGRPO(torch.autograd.Function):
- # All Unsloth Zoo code licensed under AGPL3
- @staticmethod
- def forward(ctx, _new_logps, _old_logps, _ref_logps, _sampling_per_token_logps, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, extra_kwargs=None):
- if extra_kwargs is None:
- extra_kwargs = {}
- def compute_loss(new_logps, old_logps, ref_logps, sampling_per_token_logps, input_ids, mask, advantages, scaling):
- loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(
- ref_logps,
- new_logps,
- old_logps,
- sampling_per_token_logps,
- input_ids,
- mask,
- beta,
- advantages,
- **extra_kwargs,
- )
-
- # Scale loss if needed for mixed precision training
- scaled_loss = loss * scaling
- # Must add .loss.detach otherwise autograd uses 2x VRAM
- return scaled_loss, (loss.detach(), completion_length, mean_kl, delta, flat_is_ratio, coef_1)
- pass
-
- device =_new_logps.device
- grad_inputs = torch.empty_like(_new_logps)
- accumulated_loss = torch.zeros(1, device = device)
- accumulated_completion_length = torch.zeros(1, device = device)
- accumulated_mean_kl = torch.zeros(1, device = device)
- accumulated_delta = []
- accumulated_flat_is_ratio = []
- accumulated_coef_1 = []
-
- def accumulate_chunk(
- new_logps_j,
- old_logps_j,
- ref_logps_j,
- sampling_per_token_logps_j,
- input_ids_j,
- mask_j,
- advantages_j,
- scaling,
- grad_inputs_j,
- ):
- (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl, chunk_delta, chunk_flat_is_ratio, chunk_coef_1)) = torch.func.grad_and_value(
- compute_loss,
- argnums = (0,),
- has_aux = True,
- )(new_logps_j, old_logps_j, ref_logps_j, sampling_per_token_logps_j, input_ids_j, mask_j, advantages_j, scaling)
- accumulated_loss .add_(unscaled_loss)
- accumulated_completion_length.add_(chunk_completion_length)
- accumulated_mean_kl .add_(chunk_mean_kl)
- accumulated_delta .append(chunk_delta)
- accumulated_flat_is_ratio .append(chunk_flat_is_ratio)
- accumulated_coef_1 .append(chunk_coef_1)
- grad_inputs_j[:] = chunk_grad_input
- pass
-
- accumulate_chunk = torch.compile(
- accumulate_chunk,
- fullgraph = True,
- # [TODO] Dynamic marking causes torch.compile errors if sequence length is long
- dynamic = True,
- options = torch_compile_options,
- )
-
- grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
- new_logps = torch.chunk(_new_logps, chunks = n_chunks, dim = 0)
- if _old_logps is not None:
- old_logps = torch.chunk(_old_logps, chunks = n_chunks, dim = 0)
- else:
- old_logps = [None] * n_chunks
- if _ref_logps is not None:
- ref_logps = torch.chunk(_ref_logps, chunks = n_chunks, dim = 0)
- else:
- ref_logps = [None] * n_chunks
- if _sampling_per_token_logps is not None:
- sampling_per_token_logps = torch.chunk(_sampling_per_token_logps, chunks = n_chunks, dim = 0)
- else:
- sampling_per_token_logps = [None] * n_chunks
- input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
- mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
- advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
-
- # Get mixed precision scaling if seen
- scaling = scaler.get_scale() if scaler is not None else 1.0
-
- # Force torch.compile to use dynamic shapes for seqlen dim
- # mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
-
- for (grad_inputs_j, new_logps_j, old_logps_j, ref_logps_j, sampling_per_token_logps_j, input_ids_j, mask_j, advantages_j, ) in\
- zip(grad_inputs_chunks, new_logps, old_logps, ref_logps, sampling_per_token_logps, input_ids, mask, advantages):
-
- # [TODO] Dynamic marking causes torch.compile errors if sequence length is long
-
- # mark_dynamic(new_hidden_states_j)
- # mark_dynamic(ref_hidden_states_j)
- # if old_hidden_states_j is not None:
- # mark_dynamic(old_hidden_states_j)
- # mark_dynamic(input_ids_j)
- # mark_dynamic(mask_j)
- accumulate_chunk(
- new_logps_j,
- old_logps_j,
- ref_logps_j,
- sampling_per_token_logps_j,
- input_ids_j,
- mask_j,
- advantages_j,
- scaling,
- grad_inputs_j,
- )
- pass
-
- grad_inputs .div_(n_chunks)
- accumulated_loss .div_(n_chunks)
- accumulated_completion_length.div_(n_chunks)
- accumulated_mean_kl .div_(n_chunks)
-
- if _sampling_per_token_logps is not None:
- accumulated_delta = torch.cat(accumulated_delta, dim=0)
- accumulated_flat_is_ratio = torch.cat(accumulated_flat_is_ratio, dim=0)
- else:
- accumulated_delta = None
- accumulated_flat_is_ratio = None
- accumulated_coef_1 = torch.cat(accumulated_coef_1, dim=0)
- ctx.save_for_backward(grad_inputs)
- return (
- accumulated_loss,
- accumulated_completion_length,
- accumulated_mean_kl,
- accumulated_delta,
- accumulated_flat_is_ratio,
- accumulated_coef_1
- )
- pass
-
- @staticmethod
- def backward(ctx, grad_output, dcompletion_length, dmean_kl, ddelta, ddflat_is_ratio, dcoef_1):
- (grad_input,) = ctx.saved_tensors
- return (grad_input, None, None, None, None, None, None, None, None, None, None, None)
- pass
-
-def grpo_accumulated_loss(
- trainer,
- input_ids,
- attention_mask,
- logits_to_keep,
- completion_mask,
- advantages,
- old_logps,
- ref_logps,
- n_chunks = -1,
- **kwargs,
-):
- # All Unsloth Zoo code licensed under AGPL3
- bsz, qlen = input_ids.shape
-
- pixel_values = kwargs.get('pixel_values',None)
- image_grid_thw = kwargs.get('image_grid_thw',None)
- pixel_attention_mask = kwargs.get('pixel_attention_mask',None)
- image_sizes = kwargs.get('image_sizes',None)
- sampling_per_token_logps = kwargs.get("sampling_per_token_logps", None) if getattr(trainer, "vllm_importance_sampling_correction", False) else None
- temperature = kwargs.get("temperature", 1.0)
- logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0)
- logit_scale_divide = kwargs.get("logit_scale_divide", 0.0)
- logit_softcapping = kwargs.get("logit_softcapping", 0.0)
- prev_max_left_pad = kwargs.get("max_left_pad", 0) #Always get max_left_pad for when training LLMs, enabled by deafult.
-
- #Delete this from kwargs so less issues
- _ = kwargs.pop("sampling_per_token_logps", None)
- kwargs["vllm_importance_sampling_cap"] = trainer.vllm_importance_sampling_cap if sampling_per_token_logps is not None else None
- kwargs["get_sapo_token_loss"] = trainer.get_sapo_token_loss if hasattr(trainer, "get_sapo_token_loss") else None
- kwargs["sapo_temperature_pos"] = trainer.args.sapo_temperature_pos if hasattr(trainer.args, "sapo_temperature_pos") else None
- kwargs["sapo_temperature_neg"] = trainer.args.sapo_temperature_neg if hasattr(trainer.args, "sapo_temperature_neg") else None
- kwargs["get_off_policy_mask"] = trainer.get_off_policy_mask if hasattr(trainer, "get_off_policy_mask") else None
- kwargs["off_policy_mask_threshold"] = trainer.args.off_policy_mask_threshold if hasattr(trainer.args, "off_policy_mask_threshold") else None
- kwargs["use_vllm"] = trainer.use_vllm
- # Find closest multiple
- factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
- if n_chunks == -1: n_chunks = bsz
- n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
-
- if not hasattr(trainer, '_autocast_dtype'):
- trainer._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': trainer._autocast_dtype = None
- pass
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-
- lm_head = trainer.model.get_output_embeddings().weight
- dtype_bytes = 16 if trainer._autocast_dtype in [torch.float16, torch.bfloat16] else 32
-
- total_rows = input_ids.shape[0]
- seq_len = input_ids.shape[1]
- hidden_dim = lm_head.shape[1]
- vocab_dim = lm_head.shape[0]
-
- if trainer.args.unsloth_grpo_mini_batch is None:
- if not hasattr(trainer, "_has_autotuned"):
- trainer._has_autotuned = True
- B, multiplier = autotune_batch_and_chunks(
- total_rows, seq_len, hidden_dim, vocab_dim, dtype_bytes, trainer.args.unsloth_logit_chunk_multiplier
- )
- trainer.args.unsloth_grpo_mini_batch = total_rows//B
- trainer.args.unsloth_logit_chunk_multiplier = multiplier
- B = trainer.args.unsloth_grpo_mini_batch
- multiplier = trainer.args.unsloth_logit_chunk_multiplier
- elif trainer._step % trainer.current_gradient_accumulation_steps == 0:
- B = trainer.args.unsloth_grpo_mini_batch
- multiplier = trainer.args.unsloth_logit_chunk_multiplier
- del trainer._has_autotuned
- del trainer.args.unsloth_grpo_mini_batch
- del trainer.args.unsloth_logit_chunk_multiplier
- else:
- B = trainer.unsloth_grpo_mini_batch
- multiplier = trainer.args.unsloth_logit_chunk_multiplier
- else:
- if trainer.args.unsloth_grpo_mini_batch > total_rows:
- B = total_rows
- else:
- B = trainer.args.unsloth_grpo_mini_batch
-
- if trainer.args.unsloth_logit_chunk_multiplier is None:
- multiplier = max(4, seq_len // 4096)
- else:
- multiplier = trainer.args.unsloth_logit_chunk_multiplier
-
- if pixel_values is None:
- left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep, trainer.processing_class.pad_token_id)
-
- # Determine max_left_pad from precomputed logprobs shape for consistency
- if old_logps is not None:
- max_left_pad = old_logps.shape[1] - logits_to_keep
- elif ref_logps is not None:
- max_left_pad = ref_logps.shape[1] - logits_to_keep
- else:
- max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
-
- input_ids = left_pack_padding(input_ids, trainer.processing_class.pad_token_id)
-
- completion_input_ids = input_ids[:, -(logits_to_keep +max_left_pad):]
-
- completion_mask = create_completion_attention_mask(completion_input_ids, left_pad_tokens_per_prompt, max_left_pad, trainer.processing_class.pad_token_id).to(attention_mask.dtype)
-
- if trainer.use_vllm and sampling_per_token_logps is not None and getattr(trainer, "vllm_importance_sampling_correction", False):
- sampling_per_token_logps = align_logprobs_with_mask(sampling_per_token_logps, completion_mask)
- else:
- sampling_per_token_logps = None
- attention_mask = input_ids != trainer.processing_class.pad_token_id
- attention_mask = attention_mask.to(attention_mask.dtype)
- else:
- completion_input_ids = input_ids[:, -logits_to_keep:]
-
- unwrapped_model = trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False)
-
- for module in unwrapped_model.modules():
- if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_decice"):
- module._hf_hook.io_same_decice = False
- pass
-
- all_logprobs_list = []
-
- attention_mask_chunks = torch.chunk(attention_mask, chunks=B, dim=0)
- completion_ids_chunks = torch.chunk(completion_input_ids, chunks=B, dim=0)
-
- def chunk_optional(tensor, chunks):
- if tensor is None:
- return [None] * chunks
- return torch.chunk(tensor, chunks=chunks, dim=0)
-
- import math
- total_samples = input_ids.shape[0]
- batch_size = math.ceil(total_samples / B)
-
- input_ids_chunks = []
- attention_mask_chunks = []
- pixel_values_chunks = []
- image_grid_thw_chunks = []
- pixel_attention_mask_chunks = []
-
- current_pixel_idx = 0
- #TRL 0.23.0 batching logic
- for start in range(0, total_samples, batch_size):
- end = start + batch_size
-
- input_ids_chunks.append(input_ids[start:end])
- attention_mask_chunks.append(attention_mask[start:end])
-
- if image_grid_thw is not None and pixel_values is not None:
-
- grid_slice = image_grid_thw[start:end]
- image_grid_thw_chunks.append(grid_slice)
- batch_pixel_count = grid_slice.prod(dim=-1).sum().item()
-
- start_pixel_idx = current_pixel_idx
- end_pixel_idx = current_pixel_idx + batch_pixel_count
-
- pixel_values_chunks.append(pixel_values[start_pixel_idx:end_pixel_idx])
-
- if pixel_attention_mask is not None:
- pixel_attention_mask_chunks.append(
- pixel_attention_mask[start_pixel_idx:end_pixel_idx]
- )
- else:
- pixel_attention_mask_chunks.append(None)
-
- current_pixel_idx = end_pixel_idx
-
- else:
- pixel_values_chunks.append(None)
- image_grid_thw_chunks.append(None)
- pixel_attention_mask_chunks.append(None)
-
- if image_sizes is not None and not isinstance(image_sizes, torch.Tensor):
- image_sizes_chunks = [[size] for size in image_sizes]
- else:
- image_sizes_chunks = chunk_optional(image_sizes, B)
-
- zipped_inputs = zip(
- input_ids_chunks,
- attention_mask_chunks,
- pixel_values_chunks,
- image_grid_thw_chunks,
- pixel_attention_mask_chunks,
- image_sizes_chunks,
- completion_ids_chunks
- )
-
- if trainer._autocast_dtype is None:
- autocaster = nullcontext()
- else:
- autocaster = torch.amp.autocast(device_type = trainer.model.device.type, dtype = trainer._autocast_dtype)
-
- def to_device(tensor, device, non_blocking=True):
- if tensor is None: return None
- return tensor.to(device, non_blocking=non_blocking)
-
- class Unsloth_Offloaded_Log_Softmax(torch.autograd.Function):
- """
- Manual Gradient Checkpointing/CPU Offloading for Log Softmax.
- """
- @staticmethod
- def forward(ctx, hidden_states, lm_head, index, chunks,
- logit_scale_multiply, logit_scale_divide,
- logit_softcapping, temperature):
-
- ctx.saved_hidden_states = to_device(hidden_states, "cpu", non_blocking=True)
- ctx.device = hidden_states.device
- ctx.dtype = hidden_states.dtype
-
- ctx.lm_head = lm_head
- ctx.lm_head_requires_grad = lm_head.requires_grad
- ctx.index = index
- ctx.args = (chunks, logit_scale_multiply, logit_scale_divide, logit_softcapping, temperature)
-
- with torch.no_grad():
- output = chunked_hidden_states_selective_log_softmax(
- hidden_states, lm_head, index, *ctx.args
- )
-
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- hidden_states = to_device(ctx.saved_hidden_states, ctx.device)
- hidden_states = hidden_states.to(ctx.dtype)
- hidden_states.requires_grad_(True)
-
- lm_head = ctx.lm_head
- # #Possibly redundant lines
- # if ctx.lm_head_requires_grad:
- # hidden_states.requires_grad_(True)
- # else:
- # lm_head = lm_head.detach()
-
- index = ctx.index
-
- with torch.enable_grad():
- output = chunked_hidden_states_selective_log_softmax(
- hidden_states, lm_head, index, *ctx.args
- )
-
- torch.autograd.backward(output, grad_output)
-
- return (
- hidden_states.grad,
- lm_head.grad if ctx.lm_head_requires_grad else None,
- None,
- None,
- None,
- None,
- None,
- None,
- )
-
- def efficient_log_softmax(hidden_states, lm_head, index, chunks=32,
- logit_scale_multiply=0.0, logit_scale_divide=0.0,
- logit_softcapping=0.0, temperature=1, batch_size=8):
- if (index.shape[1] <= 1024 and batch_size <= 8) or batch_size==1:
- #We save a gigabyte or speed with the normal path under these specific conditions
- return chunked_hidden_states_selective_log_softmax(
- hidden_states,
- lm_head,
- index,
- chunks,
- logit_scale_multiply,
- logit_scale_divide,
- logit_softcapping,
- temperature
- )
- else:
- return Unsloth_Offloaded_Log_Softmax.apply(
- hidden_states, lm_head, index, chunks,
- logit_scale_multiply, logit_scale_divide,
- logit_softcapping, temperature
- )
- for (
- input_ids_chunk,
- attention_mask_chunk,
- pixel_values_chunk,
- image_grid_thw_chunk,
- pixel_attention_mask_chunk,
- image_sizes_chunk,
- completion_ids
- ) in zipped_inputs:
- with autocaster:
- if pixel_values is None:
- new_hidden_states_chunk = unwrapped_model(
- input_ids = input_ids_chunk,
- attention_mask = attention_mask_chunk,
- pixel_values = pixel_values_chunk,
- image_grid_thw = image_grid_thw_chunk,
- pixel_attention_mask = pixel_attention_mask_chunk,
- image_sizes = image_sizes_chunk,
- ).logits
-
- new_hidden_states_chunk = new_hidden_states_chunk[:, -(logits_to_keep + max_left_pad + 1): , :]
- new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :]
- else:
- new_hidden_states_chunk = unwrapped_model(
- input_ids = input_ids_chunk,
- attention_mask = attention_mask_chunk,
- pixel_values = pixel_values_chunk,
- image_grid_thw = image_grid_thw_chunk,
- pixel_attention_mask = pixel_attention_mask_chunk,
- image_sizes = image_sizes_chunk,
- logits_to_keep = logits_to_keep + 1,
- ).logits
-
- new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :]
-
- logprobs_chunk = efficient_log_softmax(
- new_hidden_states_chunk,
- lm_head,
- completion_ids,
- chunks=input_ids_chunk.shape[0]*multiplier,
- logit_scale_multiply=logit_scale_multiply,
- logit_scale_divide=logit_scale_divide,
- logit_softcapping=logit_softcapping,
- temperature=temperature,
- batch_size = B
- )
- #This is needed to avoid race conditions with GPT OSS offload_embbed=True
- #However, it seems that this line does not slow down or disrupt models.
- device_synchronize()
- all_logprobs_list.append(logprobs_chunk)
-
- new_logprobs = torch.cat(all_logprobs_list, dim=0)
-
- with autocaster:
- loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = UnslothEfficientGRPO.apply(
- new_logprobs,
- old_logps,
- ref_logps,
- sampling_per_token_logps,
- lm_head,
- completion_input_ids,
- completion_mask,
- advantages,
- trainer.beta,
- trainer.accelerator.scaler,
- 1,
- kwargs
- )
-
- # Must force not returning hidden states but logits otherwise gibberish
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
-
- return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1
- # Old non efficient code path
- new_logits = torch.matmul(new_hidden_states, lm_head.t())
- new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
- old_logits = torch.matmul(old_hidden_states, lm_head.t())
- old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
- loss, completion_length, mean_kl = grpo_compute_loss(
- old_logits,
- new_logits,
- completion_input_ids,
- completion_mask,
- trainer.beta,
- advantages,
- )
- return loss, completion_length, mean_kl
- pass
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
-def grpo_compute_loss_slow(
- ref,
- new,
- old,
- sampling_per_token_logps,
- input_ids,
- mask,
- beta,
- advantages,
- **kwargs
-):
- # All Unsloth Zoo code licensed under AGPL3
- # Set defaults for optional arguments
- loss_type = kwargs.get("loss_type", "grpo")
- epsilon_low = kwargs.get("epsilon_low", 0.2)
- epsilon_high = kwargs.get("epsilon_high", 0.2)
- max_completion_length = kwargs.get("max_completion_length", 8192)
- delta = kwargs.get("delta", None)
- importance_sampling_level = kwargs.get("importance_sampling_level", "token")
- num_items_in_batch = kwargs.get("num_items_in_batch", None)
- current_gradient_accumulation_steps = kwargs.get("current_gradient_accumulation_steps", 1)
- num_processes = kwargs.get("num_processes", 1)
- use_vllm = kwargs.get("use_vllm", False)
- vllm_importance_sampling_cap = kwargs.get("vllm_importance_sampling_cap", 2.0)
- get_sapo_token_loss = kwargs.get("get_sapo_token_loss", None)
- sapo_temperature_pos = kwargs.get("sapo_temperature_pos", 1.0)
- sapo_temperature_neg = kwargs.get("sapo_temperature_neg", 1.05)
- get_off_policy_mask = kwargs.get("get_off_policy_mask", None)
- off_policy_mask_threshold = kwargs.get("off_policy_mask_threshold", None)
- input_ids = input_ids.unsqueeze(-1)
-
- if advantages.dim() == 1:
- advantages = advantages.unsqueeze(1)
-
- if off_policy_mask_threshold is not None:
- off_policy_mask = get_off_policy_mask(
- advantages=advantages,
- per_token_logps=new,
- old_per_token_logps=old,
- mask=mask,
- off_policy_threshold=off_policy_mask_threshold,
- )
-
- with torch.no_grad():
- if use_vllm and sampling_per_token_logps is not None:
- #must filter out extra prompt tokens in begining after making input_ids left padded
- importance_sampling_ratio = torch.exp((old * mask) - sampling_per_token_logps)
- importance_sampling_ratio = torch.clamp(
- importance_sampling_ratio, max=vllm_importance_sampling_cap
- )
- pass
-
- # Must detach - otherwise gradients are not propagated correctly!
- # exp(x - x) == 1
- # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
- if old is not None:
- log_ratio = new - old
- else:
- log_ratio = new - new.detach()
-
- if importance_sampling_level == "token":
- log_importance_weights = log_ratio
- elif importance_sampling_level == "sequence":
- log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
- log_importance_weights = log_importance_weights.unsqueeze(-1)
- else:
- raise ValueError(
- f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
- "and 'sequence'."
- )
-
- coef_1 = torch.exp(log_importance_weights)
-
- # Reverse KL
- # Note that this is a low variance low bias estimator for the KL divergence as used in GRPO paper
- if beta != 0.0:
- kl_i = torch.exp(ref - new) - (ref - new) - 1.0
-
- else:
- # set kl_i to a tensor of zeros with the correct shape
- if importance_sampling_level == "sequence":
- kl_i = new.new_zeros(new.size(0), 1)
- else:
- kl_i = torch.zeros_like(new)
- # Full correct reverse KL divergence?? Missing term maybe?
- # kl_i = torch.exp(new) * kl_i
-
- # Below is forward KL (normal KL)
- # kl_i = torch.exp(old) * (old - new)
- if loss_type == "cispo":
- clamped_ratios = torch.clamp(coef_1, max=epsilon_high).detach()
- loss_i = -clamped_ratios * advantages * new
- #breakpoint()
- elif loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
- coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
-
- if delta is not None:
- loss_1 = torch.clamp(coef_1, max=delta) * advantages
- else:
- loss_1 = coef_1 * advantages
- pass
- loss_2 = coef_2 * advantages
- loss_i = -torch.min(loss_1, loss_2)
- elif loss_type == "sapo":
- if get_sapo_token_loss is None:
- raise Exception(f"sapo is only available in TRL 0.26.0+")
- loss_i = torch.empty_like(coef_1)
- positive_advantages_mask = advantages.repeat([1, coef_1.shape[1]]) > 0
- #since we have n_chunks some tensors may error if they dont have elements in them
- if coef_1[positive_advantages_mask].numel() != 0:
- loss_i[positive_advantages_mask] = get_sapo_token_loss(
- coef_1[positive_advantages_mask], sapo_temperature_pos
- )
- if coef_1[~positive_advantages_mask].numel() != 0:
- loss_i[~positive_advantages_mask] = get_sapo_token_loss(
- coef_1[~positive_advantages_mask], sapo_temperature_neg
- )
- loss_i = -loss_i * advantages
- else:
- raise ValueError(f"Unknown loss type: {loss_type}")
-
- if off_policy_mask_threshold is not None:
- loss_i = loss_i * off_policy_mask
-
- if use_vllm and sampling_per_token_logps is not None:
- loss_i = loss_i * importance_sampling_ratio
- #delta for metric
- with torch.no_grad():
- delta = torch.abs(old - sampling_per_token_logps)
- delta = delta * mask
- flat_is_ratio = importance_sampling_ratio * mask
- else:
- delta = torch.tensor([]).detach()
- flat_is_ratio = torch.tensor([]).detach()
- if beta != 0.0:
- loss_i = loss_i + beta * kl_i
-
- mask = mask.to(torch.float32)
- n_mask_per_reward = mask.sum(1)
-
- # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624
- if loss_type in ["grpo", "sapo"]:
- loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
- loss = loss / current_gradient_accumulation_steps
- elif loss_type == "bnpo":
- loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0)
- loss = loss / current_gradient_accumulation_steps
- elif loss_type == "dr_grpo":
- loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length)
- loss = loss / current_gradient_accumulation_steps
- elif loss_type in ["cispo", "dapo"]:
- normalizer = num_items_in_batch/ num_processes
- loss = (loss_i * mask).sum() / normalizer
- else:
- raise ValueError(f"Unknown loss type: {loss_type}")
-
- # loss = (loss_i * mask).sum() / mask.sum()
-
- # Get metrics as well which are folded
- def masked_batch_mean(x):
- with torch.inference_mode():
- completion_length = n_mask_per_reward.mean()
- if x.shape[1] == 1: # when importance_sampling_level == "sequence"
- return completion_length, x.mean()
- else:
- mean_kl_per_reward = (x * mask).sum(1) / n_mask_per_reward
- mean_kl = mean_kl_per_reward.mean()
- return completion_length, mean_kl
- completion_length, mean_kl = masked_batch_mean(kl_i)
- return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1
-
-def grpo_update_SamplingParams(SamplingParams, generation_kwargs, vllm_sampling_params = None):
- good_sampling_params_keys = inspect.signature(SamplingParams).parameters.keys()
-
- # Filter generation_kwargs
- new_generation_kwargs = {}
- for key in generation_kwargs.keys():
- if key in good_sampling_params_keys:
- new_generation_kwargs[key] = generation_kwargs[key]
- generation_kwargs = new_generation_kwargs
-
- if vllm_sampling_params is not None:
- for key in good_sampling_params_keys:
- if hasattr(vllm_sampling_params, key):
- overwrited_key = getattr(vllm_sampling_params, key)
- if overwrited_key is not None and (type(overwrited_key) in (list, tuple,) and len(overwrited_key) != 0):
- generation_kwargs[key] = overwrited_key
- return generation_kwargs
-
-def _get_inference_mode_context_manager(model: torch.nn.Module):
- """
- If the state dict was quantized using torchao, we will run into
- the following error when calling ops like aten.t() in inference mode.
- This is a bug in PyTorch that affects all tensor subclasses.
-
- Cannot set version_counter for inference tensor
-
- For now, we work around this issue by using `torch.no_grad()` in this case.
- See https://github.com/pytorch/pytorch/issues/164872 for more details.
- Otherwise, just return `torch.inference_mode()`.
- """
- torchao_config = getattr(model, "torchao_config", None)
- if torchao_config is not None and torchao_config.qat_scheme is None:
- return torch.no_grad()
- else:
- return torch.inference_mode()
-
-def vLLMSamplingParams(**kwargs):
- from vllm import SamplingParams
-
- sampling_params = SamplingParams(**kwargs)
- sampling_params._set_kwargs = kwargs
- return sampling_params
-@dataclass
-class UnslothGRPOConfig(GRPOConfig):
- """
-
- Configuration class for the [`GRPOTrainer`].
-
- This class includes only the parameters that are specific to GRPO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- > Parameters that control the model and reference model
-
- model_init_kwargs (`str`, `dict[str, Any]`, *optional*):
- Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
- argument of the [`GRPOTrainer`] is provided as a string.
- disable_dropout (`bool`, *optional*, defaults to `False`):
- Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
- the model from generating different logprobs for the same input.
-
- > Parameters that control the data preprocessing
-
- remove_unused_columns (`bool`, *optional*, defaults to `False`):
- Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
- requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
- num_generations (`int` or `None`, *optional*, defaults to `8`):
- Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size
- * gradient_accumulation_steps) must be evenly divisible by this value.
- max_completion_length (`int` or `None`, *optional*, defaults to `256`):
- Maximum length of the generated completion.
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
- capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
- with vLLM generation.
- shuffle_dataset (`bool`, *optional*, defaults to `True`):
- Whether to shuffle the training dataset.
-
- > Parameters that control generation
-
- generation_batch_size: (`int`, *optional*):
- Batch size to use for generation. If `None`, it defaults to the effective training batch size:
- `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one
- generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`.
- steps_per_generation: (`int`, *optional*):
- Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive
- with `generation_batch_size`.
- temperature (`float`, defaults to `1.0`):
- Temperature for sampling. The higher the temperature, the more random the completions.
- top_p (`float`, *optional*, defaults to `1.0`):
- Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
- `1.0` to consider all tokens.
- top_k (`int`, *optional*):
- Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
- disabled and all tokens are considered.
- min_p (`float`, *optional*):
- Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
- value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
- repetition_penalty (`float`, *optional*, defaults to `1.0`):
- Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
- Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
- tokens.
- use_transformers_paged (`bool`, *optional*, defaults to `False`):
- Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers`
- paged implementation will be used for generation instead of the default padded implementation. This
- parameter is only effective when `use_vllm` is set to `False`.
- cache_implementation (`str`, *optional*):
- Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
- generation_kwargs (`dict[str, Any]`, *optional*):
- Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
- `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
- generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
- with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
-
- > Parameters that control generation acceleration powered by vLLM
-
- use_vllm (`bool`, *optional*, defaults to `False`):
- Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation
- instead of the default model.generate(). Requires `vllm` to be installed.
- vllm_mode (`str`, *optional*, defaults to `"server"`):
- Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or
- `"colocate"`.
-
- - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM
- server is running (start with `trl vllm-serve`).
- - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
- separate server but may cause resource contention with training.
- vllm_model_impl (`str`, *optional*, defaults to `"vllm"`):
- Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
- the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
- implementation.
- vllm_guided_decoding_regex (`str`, *optional*):
- Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
-
- > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
-
- vllm_server_base_url (`str`, *optional*):
- Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
- `vllm_server_port` are ignored.
- vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
- Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
- vllm_server_port (`int`, *optional*, defaults to `8000`):
- Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
- vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
- Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
- timeout, a `ConnectionError` is raised.
-
- > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
-
- vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`):
- Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
- `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
- launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
- vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
- Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
- `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
- launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
- vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
- Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step and woken
- for weight sync and generation.
-
- > Parameters that control the training
-
- beta (`float`, *optional*, defaults to `0.0`):
- KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and improving
- training speed.
- num_iterations (`int`, *optional*, defaults to `1`):
- Number of iterations per batch (denoted as μ in the algorithm).
- epsilon (`float`, *optional*, defaults to `0.2`):
- Epsilon value for clipping.
- delta (`float`, *optional*):
- Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` (default), standard
- GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This method is introduced in
- the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291).
- epsilon_high (`float`, *optional*):
- Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
- specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
- importance_sampling_level (`str`, *optional*, defaults to `"token"`):
- Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
- keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
- log-probability ratios across valid tokens to produce a single ratio per sequence. The [GSPO
- paper](https://huggingface.co/papers/2507.18071) shows that sequence-level sampling often yields more
- stable training and better alignment with sequence-level rewards.
- reward_weights (`list[float]`, *optional*):
- Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
- weighted equally with weight `1.0`.
- scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`):
- Specifies the scaling strategy for rewards. Supported values are:
-
- - `True` or `"group"` (default): rewards are scaled by the standard deviation within each group, ensuring
- unit variance within a group.
- - `"batch"`: rewards are scaled by the standard deviation across the entire batch, as recommended in the
- [PPO Lite paper](https://huggingface.co/papers/2508.08221).
- - `False` or `"none"`: no scaling is applied. The [Dr. GRPO
- paper](https://huggingface.co/papers/2503.20783) recommends not scaling rewards, as scaling by the
- standard deviation introduces a question-level difficulty bias.
- loss_type (`str`, *optional*, defaults to `"dapo"`):
- Specifies the loss formulation to use. Supported values are:
-
- - `"grpo"`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to
- length bias—this approach tends to prefer shorter completions with positive advantages and longer ones
- with negative advantages.
- - `"dr_grpo"`: Aggregates token-level losses by normalizing with a global constant. This method was
- introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) to eliminate length bias.
- The value of the constant corresponds to `max_completion_length`.
- - `"dapo"` (default): Aggregates token-level losses by normalizing with the number of active token in the
- global accumulated batch. This method was introduced in the [DAPO
- paper](https://huggingface.co/papers/2503.14476) to eliminate length bias.
- - `"bnpo"`: Aggregates token-level losses by normalizing with the number of active token in the local
- batch. Note that normalization is performed over the local batch only, so results may slightly vary
- depending on the local batch size, despite a constant effective batch size. When using
- `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
- mask_truncated_completions (`bool`, *optional*, defaults to `False`):
- When enabled, truncated completions are excluded from the loss calculation, preventing them from being
- incorrectly penalized and introducing noise during training. According to the
- [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability.
- sync_ref_model (`bool`, *optional*, defaults to `False`):
- Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
- the `ref_model_mixup_alpha` parameter. This synchronization originates from the
- [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
- ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
- α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
- between the current policy and the previous reference policy during updates. The reference policy is
- updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
- must set `sync_ref_model=True`.
- ref_model_sync_steps (`int`, *optional*, defaults to `512`):
- τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
- frequently the current policy is synchronized with the reference policy. To use this parameter, you must
- set `sync_ref_model=True`.
- top_entropy_quantile (`float`, *optional*, defaults to `1.0`):
- ρ parameter from [Beyond the 80/20 Rule](https://huggingface.co/papers/2506.01939). Keeps in the policy
- loss term only the top-ρ quantile of tokens by entropy of the probability distribution at each sequence
- position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token;
- `1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with
- `mask_truncated_completions=True`, only tokens from non-truncated completions are considered.
- use_liger_loss (`bool`, *optional*, defaults to `False`):
- Whether to use the Liger GRPO loss.
- vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`):
- Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed
- logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL
- Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework
- (such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation
- and training backends. TIS is proposed as a remedy for this issue.
- vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`):
- Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance
- sampling ratio, improving training stability.
-
- > Parameters that control the logging
-
- log_completions (`bool`, *optional*, defaults to `False`):
- Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
- it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
- num_completions_to_print (`int`, *optional*):
- Number of completions to print with `rich`. If `None`, all completions are logged.
- wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):
- Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all prompts
- are logged.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
-
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = False,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- model_init_kwargs = None,
- disable_dropout = False,
- max_prompt_length = 512,
- num_generations = 8,
- max_completion_length = 256,
- ds3_gather_for_generation = True,
- shuffle_dataset = True,
- generation_batch_size = None,
- steps_per_generation = None,
- temperature = 1.0,
- top_p = 1.0,
- top_k = None,
- min_p = None,
- generation_kwargs = {},
- repetition_penalty = 1.0,
- use_transformers_paged = False,
- cache_implementation = None,
- use_vllm = False,
- vllm_mode = 'colocate',
- vllm_model_impl = 'vllm',
- vllm_enable_sleep_mode = False,
- vllm_guided_decoding_regex = None,
- vllm_server_base_url = None,
- vllm_server_host = '0.0.0.0',
- vllm_server_port = 8000,
- vllm_server_timeout = 240.0,
- vllm_gpu_memory_utilization = 0.3,
- vllm_tensor_parallel_size = 1,
- beta = 0.001,
- num_iterations = 1,
- epsilon = 0.2,
- delta = None,
- epsilon_high = None,
- importance_sampling_level = 'token',
- reward_weights = None,
- scale_rewards = 'group',
- loss_type = 'bnpo',
- mask_truncated_completions = False,
- sync_ref_model = False,
- ref_model_mixup_alpha = 0.6,
- ref_model_sync_steps = 512,
- top_entropy_quantile = 1.0,
- use_liger_loss = False,
- vllm_importance_sampling_correction = False,
- vllm_importance_sampling_cap = 2.0,
- log_completions = False,
- num_completions_to_print = None,
- wandb_log_unique_prompts = False,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
-
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- if loss_type.lower() == 'dr_grpo':
- loss_type = 'dr_grpo'
- elif loss_type.lower() == 'dapo':
- loss_type = 'dapo'
- if loss_type.lower() == 'dr_grpo':
- if scale_rewards == None:
- scale_rewards = True
- elif scale_rewards == True:
- print('Unsloth: The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')
- scale_rewards = False
- elif loss_type.lower() == 'dapo':
- if mask_truncated_completions != True:
- print('Unsloth: The DAPO paper recommends `mask_truncated_completions = True` - we will set it.')
- if epsilon_high != 0.28:
- print('Unsloth: The DAPO paper recommends `epsilon_high = 0.28` - we will set it.')
- if beta != 0.0:
- print(f'[WARNING] Unsloth: The DAPO paper recommends setting `beta = 0.0` to remove the KL term - You have set it to {beta}.')
- mask_truncated_completions = True
- epsilon_high = 0.28
-
- if steps_per_generation is None and generation_batch_size is None:
- ga = gradient_accumulation_steps
- world_size = int(os.environ.get('WORLD_SIZE', '1'))
- if (ga * world_size * per_device_train_batch_size) % num_generations != 0:
- print('Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
- per_device_train_batch_size = num_generations
-
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
- if use_vllm and (top_k is None or top_k == 0): top_k = -1
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- model_init_kwargs = model_init_kwargs,
- disable_dropout = disable_dropout,
- max_prompt_length = max_prompt_length,
- num_generations = num_generations,
- max_completion_length = max_completion_length,
- ds3_gather_for_generation = ds3_gather_for_generation,
- shuffle_dataset = shuffle_dataset,
- generation_batch_size = generation_batch_size,
- steps_per_generation = steps_per_generation,
- temperature = temperature,
- top_p = top_p,
- top_k = top_k,
- min_p = min_p,
- generation_kwargs = generation_kwargs,
- repetition_penalty = repetition_penalty,
- use_transformers_paged = use_transformers_paged,
- cache_implementation = cache_implementation,
- use_vllm = use_vllm,
- vllm_mode = vllm_mode,
- vllm_model_impl = vllm_model_impl,
- vllm_enable_sleep_mode = vllm_enable_sleep_mode,
- vllm_guided_decoding_regex = vllm_guided_decoding_regex,
- vllm_server_base_url = vllm_server_base_url,
- vllm_server_host = vllm_server_host,
- vllm_server_port = vllm_server_port,
- vllm_server_timeout = vllm_server_timeout,
- vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
- vllm_tensor_parallel_size = vllm_tensor_parallel_size,
- beta = beta,
- num_iterations = num_iterations,
- epsilon = epsilon,
- delta = delta,
- epsilon_high = epsilon_high,
- importance_sampling_level = importance_sampling_level,
- reward_weights = reward_weights,
- scale_rewards = scale_rewards,
- loss_type = loss_type,
- mask_truncated_completions = mask_truncated_completions,
- sync_ref_model = sync_ref_model,
- ref_model_mixup_alpha = ref_model_mixup_alpha,
- ref_model_sync_steps = ref_model_sync_steps,
- top_entropy_quantile = top_entropy_quantile,
- use_liger_loss = use_liger_loss,
- vllm_importance_sampling_correction = vllm_importance_sampling_correction,
- vllm_importance_sampling_cap = vllm_importance_sampling_cap,
- log_completions = log_completions,
- num_completions_to_print = num_completions_to_print,
- wandb_log_unique_prompts = wandb_log_unique_prompts,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
-
-
-pass
-
-class _UnslothGRPOTrainer(BaseTrainer):
- """"""
-
- _tag_names = ["trl", "grpo"]
- _name = "GRPO"
- _paper = {
- "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
- "id": "2402.03300",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{shao2024deepseekmath,
- title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
- author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
- year = 2024,
- eprint = {arXiv:2402.03300},
- }
- """),
- }
-
- def __init__(
- self,
- model: Union[str, PreTrainedModel],
- reward_funcs: Union[RewardFunc, list[RewardFunc]],
- args: Optional[GRPOConfig] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
- processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
- reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
- peft_config: Optional["PeftConfig"] = None,
- ):
-
- if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):
- if (getattr(args, 'use_vllm', False) == False):
- args.use_vllm = True
- args.vllm_mode='colocate'
- if os.environ.get('UNSLOTH_VLLM_STANDBY', '0') == '1':
- args.vllm_enable_sleep_mode=True
- # Args
- if args is None:
- model_name = model if isinstance(model, str) else model.config._name_or_path
- model_name = model_name.split("/")[-1]
- args = GRPOConfig(f"{model_name}-GRPO")
-
- # Models
- # Trained model
- model_init_kwargs = args.model_init_kwargs or {}
- if isinstance(model, str):
- model_id = model
- dtype = model_init_kwargs.get("dtype")
- if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
- pass # dtype is already a torch.dtype or "auto" or None
- elif isinstance(dtype, str): # it's a str, but not "auto"
- dtype = getattr(torch, dtype)
- model_init_kwargs["dtype"] = dtype
- else:
- raise ValueError(
- "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
- f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
- )
- # Disable caching if gradient checkpointing is enabled [not supported]
- config = AutoConfig.from_pretrained(model_id)
- architecture = getattr(transformers, config.architectures[0])
- model = architecture.from_pretrained(model_id, **model_init_kwargs)
- else:
- model_id = model.config._name_or_path
- if args.model_init_kwargs is not None:
- logger.warning(
- "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
- "The `model_init_kwargs` will be ignored."
- )
-
- # Some models [SmolVLM/Idefics3] don't support `logits_to_keep` argument and error out if we pass it
- # Inspect the forward method before we wrap the model with PEFT
- self.model_kwarg_keys = (
- inspect.signature(model.forward).parameters.keys()
- if not hasattr(model, "get_base_model")
- else inspect.signature(model.get_base_model().forward).parameters.keys()
- )
-
- if False:
- pass
-
- # Processing class
- if processing_class is None:
- processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
-
- # Handle pad token for processors or tokenizers
- if isinstance(processing_class, ProcessorMixin):
- tokenizer = processing_class.tokenizer
- elif isinstance(processing_class, PreTrainedTokenizerBase):
- tokenizer = processing_class
- else:
- raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
-
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
-
- self.pad_token = tokenizer.pad_token
- self.pad_token_id = tokenizer.pad_token_id
- self.eos_token_id = tokenizer.eos_token_id
-
- # Reward functions
- if not isinstance(reward_funcs, list):
- reward_funcs = [reward_funcs]
- self.reward_func_names = []
- for i, reward_func in enumerate(reward_funcs):
- if isinstance(reward_func, str):
- reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
- reward_func, num_labels=1, **model_init_kwargs
- )
- if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models
- self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
- else:
- self.reward_func_names.append(reward_funcs[i].__name__)
- self.reward_funcs = reward_funcs
-
- # Reward weights
- if args.reward_weights is not None:
- if len(args.reward_weights) != len(reward_funcs):
- raise ValueError(
- f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
- f"functions ({len(reward_funcs)})"
- )
- self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
- else:
- self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
-
- # Reward processing class
- if reward_processing_classes is None:
- reward_processing_classes = [None] * len(reward_funcs)
- elif not isinstance(reward_processing_classes, list):
- reward_processing_classes = [reward_processing_classes]
- if len(reward_processing_classes) != len(reward_funcs):
- raise ValueError(
- f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of "
- f"reward functions ({len(reward_funcs)})."
- )
-
- for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
- if isinstance(reward_func, PreTrainedModel):
- if reward_processing_class is None:
- reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
- if reward_processing_class.pad_token_id is None:
- reward_processing_class.pad_token = reward_processing_class.eos_token
- # The reward model computes the reward for the latest non-padded token in the input sequence.
- # So it's important to set the pad token ID to the padding token ID of the processing class.
- reward_func.config.pad_token_id = reward_processing_class.pad_token_id
- reward_processing_classes[i] = reward_processing_class
-
- self.reward_processing_classes = reward_processing_classes
-
- # Training arguments
- self.max_prompt_length = args.max_prompt_length
- self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
- self.num_generations = args.num_generations # = G in the GRPO paper
- self.temperature = args.temperature
- self.top_p = args.top_p
- self.top_k = args.top_k
- self.min_p = args.min_p
- self.repetition_penalty = args.repetition_penalty
- self.use_transformers_paged = args.use_transformers_paged
- self.use_vllm = args.use_vllm
- self.vllm_mode = args.vllm_mode
- self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
- self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
- self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction
- self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap
- self.use_liger_loss = args.use_liger_loss
- self.loss_type = args.loss_type
- self.scale_rewards = args.scale_rewards
- self.importance_sampling_level = args.importance_sampling_level
- self.mask_truncated_completions = args.mask_truncated_completions
- self.top_entropy_quantile = args.top_entropy_quantile
- if self.use_liger_loss and self.top_entropy_quantile < 1.0:
- raise NotImplementedError(
- "Liger Kernels don't currently support masking token positions based on entropy."
- )
- if self.use_liger_loss and not self.importance_sampling_level == "token":
- raise NotImplementedError(
- "Liger Kernels currently only support token-level importance sampling. Please set"
- "`importance_sampling_level` to 'token'."
- )
-
- # Datasets
- self.shuffle_dataset = args.shuffle_dataset
-
- if (
- isinstance(train_dataset, IterableDataset)
- or isinstance(eval_dataset, IterableDataset)
- or (
- isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
- )
- ):
- # See https://github.com/huggingface/trl/issues/3213
- raise NotImplementedError(
- "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead."
- )
-
- # Multi-step
- self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
- self.epsilon_low = args.epsilon
- self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
- # Tracks the number of iterations [forward + backward passes], including those within a grad accum cycle
- self._step = 0
- # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
- # `_get_train_sampler` and `_prepare_inputs`.
- self._buffered_inputs = None
-
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
- # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
- # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
- # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
- # This acts as a flag to indicate that the warning has already been issued.
- model.warnings_issued["estimate_tokens"] = True
-
- super().__init__(
- model=model,
- args=args,
- data_collator=identity, # No data collation is needed in GRPO
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- callbacks=callbacks,
- optimizers=optimizers,
- # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func`
- # is None. For DAPO, loss scaling instead depends on the total number of completions tokens across the
- # global accumulated batch. To control scaling ourselves, we must disable Trainer’s built-in scaling. The
- # simplest [though a bit hacky] way is to set `compute_loss_func` to any non-None value, which bypasses
- # that behavior without rewriting `training_step`.
- compute_loss_func="non-None value to disable scaling",
- )
-
- # Reference model
- self.beta = args.beta
- if self.beta == 0.0:
- # If beta is 0.0, the reference model is not needed
- self.ref_model = None
- elif is_peft_model(model):
- # If PEFT is used, the reference model is not needed since the adapter can be disabled
- # to revert to the initial model.
- self.ref_model = None
- else:
- # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
- config = AutoConfig.from_pretrained(model_id)
- architecture = getattr(transformers, config.architectures[0])
- self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
-
- # Disable dropout in the models
- if args.disable_dropout:
- disable_dropout_in_model(model)
- if self.ref_model is not None:
- disable_dropout_in_model(self.ref_model)
-
- # Liger loss
- if self.use_liger_loss:
- if not is_liger_kernel_available():
- raise ImportError(
- "Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`."
- )
- # redirect the model.module forward to the model forward to ensure pre-forward hooks are called
- self._forward_redirection = _ForwardRedirection()
-
- self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
- beta=self.beta,
- epsilon_low=self.epsilon_low,
- epsilon_high=self.epsilon_high,
- temperature=self.temperature,
- use_ref_model=self.beta != 0.0,
- loss_type=self.loss_type,
- max_completion_length=self.max_completion_length,
- )
-
- # Initialize the metrics
- self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
- self._total_train_tokens = 0
- self.log_completions = args.log_completions
- self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
- self.num_completions_to_print = args.num_completions_to_print
- # Keep logs sized to the generation batch to record only outputs from the latest model update.
- self._logs = {
- "images": deque(maxlen=args.generation_batch_size),
- "prompt": deque(maxlen=args.generation_batch_size),
- "completion": deque(maxlen=args.generation_batch_size),
- "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
- "advantages": deque(maxlen=args.generation_batch_size),
- }
-
- # Ensure each process receives a unique seed to prevent duplicate completions when generating with
- # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
- # it's safer to set it in all cases.
- set_seed(args.seed, device_specific=True)
-
- if self.use_vllm:
- if not is_vllm_available():
- raise ImportError(
- "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
- "`pip install trl[vllm]` to use it."
- )
-
- if self.vllm_mode == "server":
- if self.accelerator.is_main_process:
- if args.vllm_server_base_url is not None:
- base_url = args.vllm_server_base_url
- else:
- base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
- self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
- self.vllm_client.init_communicator(device=torch.cuda.current_device())
-
- elif self.vllm_mode == "colocate":
- if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0:
- raise ValueError(
- f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size "
- f"({self.accelerator.num_processes}) evenly."
- )
-
- if self.vllm_tensor_parallel_size > 1:
- self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
- [
- list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size))
- for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size)
- ]
- )
- os.environ["RANK"] = str(self.accelerator.process_index)
- os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index)
- os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes)
- ensure_master_addr_port()
-
- if self.max_prompt_length is not None and self.max_completion_length is not None:
- max_model_len = self.max_prompt_length + self.max_completion_length
- else:
- max_model_len = None
- self.llm = model.vllm_engine
- if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
- else:
- raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
- self.guided_decoding_regex = args.vllm_guided_decoding_regex
-
- self._last_loaded_step = -1
- self.accelerator.wait_for_everyone()
- else:
- generation_kwargs = {
- "max_new_tokens": self.max_completion_length,
- "do_sample": True,
- "pad_token_id": tokenizer.pad_token_id,
- "bos_token_id": tokenizer.bos_token_id,
- "eos_token_id": tokenizer.eos_token_id,
- "temperature": self.temperature,
- "top_p": self.top_p,
- "top_k": self.top_k,
- "min_p": self.min_p,
- "repetition_penalty": self.repetition_penalty,
- "cache_implementation": args.cache_implementation,
- }
- if args.generation_kwargs is not None:
- generation_kwargs.update(args.generation_kwargs)
- self.generation_config = GenerationConfig(**generation_kwargs)
-
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
- # self.model_accepts_loss_kwargs to False to enable scaling.
- self.model_accepts_loss_kwargs = False
-
- # Add tags to the model
- self.model.add_model_tags(self._tag_names)
-
- if self.ref_model is not None:
- if self.is_deepspeed_enabled:
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
- elif self.is_fsdp_enabled:
- self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
- else:
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-
- if args.sync_ref_model:
- self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
-
- for i, reward_func in enumerate(self.reward_funcs):
- if isinstance(reward_func, PreTrainedModel):
- if self.is_deepspeed_enabled:
- self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
- else:
- # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
- self.reward_funcs[i] = self.accelerator.prepare_model(
- reward_func, evaluation_mode=True, device_placement=True
- )
-
- def _set_signature_columns_if_needed(self):
- # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
- # By default, this method sets `self._signature_columns` to the model's expected inputs.
- # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
- # Instead, we set them to the columns expected by the `training_step` method, hence the override.
- if self._signature_columns is None:
- self._signature_columns = ["prompt", "image", "images"]
-
- # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
- # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an
- # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions
- # once every steps_per_generation step—rather than once per accumulation step—which is significantly more
- # efficient. The only change from the original implementation is multiplying the batch size by
- # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
- # splitting internally.
- # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
- # modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line
- # apart from the super method, ensuring easier maintenance in the future.
- def get_train_dataloader(self):
- if self.train_dataset is None:
- raise ValueError("Trainer: training requires a train_dataset.")
-
- train_dataset = self.train_dataset
- data_collator = self.data_collator
- if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
- train_dataset = self._remove_unused_columns(train_dataset, description="training")
- else:
- data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
-
- dataloader_params = {
- "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
-
- if not isinstance(train_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_train_sampler()
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["worker_init_fn"] = partial(
- seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
- )
-
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
- return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
-
- def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
- # Returns a sampler that
- # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
- # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
- # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
- # in group formation.
- # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
- # _prepare_inputs to see how the generations are stored and reused.
-
- # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
- # second row shows the second sampled batch, and so on.
- #
- # | GPU 0 | GPU 1 |
- #
- # global_step step <-───> num_generations=2
- # <-───────> per_device_train_batch_size=3
- # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
- # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss
- # |
- # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss
- # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss
- #
- # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
- # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss
- # ...
- if dataset is None:
- dataset = self.train_dataset
- return RepeatSampler(
- data_source=dataset,
- mini_repeat_count=self.num_generations,
- batch_size=self.args.generation_batch_size // self.num_generations,
- repeat_count=self.num_iterations * self.args.steps_per_generation,
- shuffle=self.shuffle_dataset,
- seed=self.args.seed,
- )
-
- def _get_eval_sampler(self, eval_dataset) -> Sampler:
- # See _get_train_sampler for an explanation of the sampler.
- return RepeatSampler(
- data_source=eval_dataset,
- mini_repeat_count=self.num_generations,
- seed=self.args.seed,
- )
-
- @profiling_decorator
- def _get_last_hidden_state(
- self,
- unwrapped_model,
- input_ids,
- attention_mask,
- logits_to_keep,
- pixel_values=None,
- image_grid_thw=None,
- pixel_attention_mask=None,
- image_sizes=None,
- ):
- if is_peft_model(unwrapped_model):
- unwrapped_model = unwrapped_model.base_model.model
-
- # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
- model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-
- # For Qwen models:
- if image_grid_thw is not None and pixel_values is not None:
- model_inputs["image_grid_thw"] = image_grid_thw
- # For Gemma, SmolVLM2, LLaVa-Next etc.:
- if pixel_values is not None:
- model_inputs["pixel_values"] = pixel_values
- # For SmolVLM2
- if pixel_attention_mask is not None:
- model_inputs["pixel_attention_mask"] = pixel_attention_mask
- # For LLaVa-Next
- if image_sizes is not None:
- model_inputs["image_sizes"] = image_sizes
-
- # Only add logits_to_keep if the model supports it
- if "logits_to_keep" in self.model_kwarg_keys:
- # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
- model_inputs["logits_to_keep"] = logits_to_keep + 1
-
- model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings
-
- last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state
- # Exclude the last value: it corresponds to the next token pred
- last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H)
- # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op.
- last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
- return last_hidden_state
-
- def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
- """
- Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold.
-
- Args:
- entropies (`torch.Tensor`):
- Tensor of shape (batch_size, seq_len) with per-token entropy values.
- mask (`torch.Tensor`):
- Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding.
- threshold (`float`):
- Quantile threshold between `0.0` and `1.0` to select high-entropy tokens.
-
- Returns:
- `torch.Tensor`:
- Boolean mask of shape (batch_size, seq_len), where `True` indicates tokens with entropy >= threshold
- and `False` otherwise.
- """
- local = entropies[mask.bool()].float()
-
- # Use a negative pad_value as a sentinel because entropy values are always >= 0.
- # This guarantees that the sentinel cannot collide with any real entropy value.
- pad_value = -1e9
-
- # Pad across processes so that every rank has the same tensor length
- padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value)
- gathered = self.accelerator.gather(padded)
-
- # Drop sentinel values (safe because no entropy can be negative)
- gathered = gathered[gathered != pad_value]
-
- if gathered.numel() == 0:
- return torch.zeros_like(entropies, dtype=torch.bool)
-
- entropy_threshold = torch.quantile(gathered, threshold)
- masked_entropies = entropies * mask.float()
- entropy_mask = masked_entropies >= entropy_threshold
- return entropy_mask & mask.bool() # ensure padding tokens are always masked out
-
- def _get_per_token_logps_and_entropies(
- self,
- model,
- input_ids,
- attention_mask,
- logits_to_keep,
- batch_size = None,
- compute_entropy = False,
- compute_efficient = False,
- *args,
- **kwargs,
- ):
- # All Unsloth code here in this function is licensed under AGPL3
- # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
- # return None, None # logps, entropies Unsloth efficient GRPO
- if compute_efficient:
- return None, None
- else:
- if not hasattr(self, "_autocast_dtype"):
- self._autocast_dtype = (
- torch.float16
- if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
- else torch.bfloat16
- )
- if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
- self._autocast_dtype = torch.float16
-
- pixel_values, image_grid_thw = (
- kwargs.get("pixel_values", None),
- kwargs.get("image_grid_thw", None),
- )
- pixel_attention_mask, image_sizes = (
- kwargs.get("pixel_attention_mask", None),
- kwargs.get("image_sizes", None),
- )
-
- unwrapped_model = self.accelerator.unwrap_model(
- model, keep_fp32_wrapper = False
- )
-
- lm_head = self.model.get_output_embeddings().weight
-
- dtype_bytes = (
- 16 if self._autocast_dtype in [torch.float16, torch.bfloat16] else 32
- )
- total_rows = input_ids.shape[0]
- seq_len = input_ids.shape[1]
- hidden_dim = lm_head.shape[1]
- vocab_dim = lm_head.shape[0]
-
- if self.args.unsloth_grpo_mini_batch is None:
- B, multiplier = autotune_batch_and_chunks(
- total_rows,
- seq_len,
- hidden_dim,
- vocab_dim,
- dtype_bytes,
- self.args.unsloth_logit_chunk_multiplier,
- )
- B = total_rows // B
- else:
- B = self.args.unsloth_grpo_mini_batch
-
- if self.args.unsloth_logit_chunk_multiplier is None:
- multiplier = max(4, seq_len // 4096)
- else:
- multiplier = self.args.unsloth_logit_chunk_multiplier
-
- all_logprobs_list = []
- if pixel_values is None:
- left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(
- input_ids, logits_to_keep, self.processing_class.pad_token_id
- )
- max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
- input_ids = left_pack_padding(
- input_ids, self.processing_class.pad_token_id
- )
- attention_mask = input_ids != self.processing_class.pad_token_id
- attention_mask = attention_mask.to(attention_mask.dtype)
- else:
- max_left_pad = 0
-
- # input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0)
- attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0)
-
- def chunk_optional(tensor, chunks):
- if tensor is None:
- return [None] * chunks
- return torch.chunk(tensor, chunks = chunks, dim = 0)
-
- import math
-
- total_samples = input_ids.shape[0]
- batch_size = math.ceil(total_samples / B)
-
- input_ids_chunks = []
- attention_mask_chunks = []
- pixel_values_chunks = []
- image_grid_thw_chunks = []
- pixel_attention_mask_chunks = []
-
- current_pixel_idx = 0
- # TRL 0.23.0 batching logic
- for start in range(0, total_samples, batch_size):
- end = start + batch_size
-
- input_ids_chunks.append(input_ids[start:end])
- attention_mask_chunks.append(attention_mask[start:end])
-
- if image_grid_thw is not None and pixel_values is not None:
- grid_slice = image_grid_thw[start:end]
- image_grid_thw_chunks.append(grid_slice)
-
- batch_pixel_count = grid_slice.prod(dim = -1).sum().item()
-
- start_pixel_idx = current_pixel_idx
- end_pixel_idx = current_pixel_idx + batch_pixel_count
-
- pixel_values_chunks.append(
- pixel_values[start_pixel_idx:end_pixel_idx]
- )
-
- if pixel_attention_mask is not None:
- pixel_attention_mask_chunks.append(
- pixel_attention_mask[start_pixel_idx:end_pixel_idx]
- )
- else:
- pixel_attention_mask_chunks.append(None)
-
- current_pixel_idx = end_pixel_idx
-
- else:
- pixel_values_chunks.append(None)
- image_grid_thw_chunks.append(None)
- pixel_attention_mask_chunks.append(None)
-
- if image_sizes is not None and not isinstance(image_sizes, torch.Tensor):
- image_sizes_chunks = [[size] for size in image_sizes]
- else:
- image_sizes_chunks = chunk_optional(image_sizes, B)
-
- temperature = self.temperature
- logit_softcapping = getattr(model.config, "final_logit_softcapping", 0)
- if logit_softcapping is None:
- logit_softcapping = 0
- logit_scale_multiply = getattr(model.config, "logit_scale", 0)
- if logit_scale_multiply is None:
- logit_scale_multiply = 0
- logit_scale_divide = getattr(model.config, "logits_scaling", 0)
- if logit_scale_divide is None:
- logit_scale_divide = 0
-
- zipped_inputs = zip(
- input_ids_chunks,
- attention_mask_chunks,
- pixel_values_chunks,
- image_grid_thw_chunks,
- pixel_attention_mask_chunks,
- image_sizes_chunks,
- )
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
-
- with _get_inference_mode_context_manager(model):
- for (
- input_ids_chunk,
- attention_mask_chunk,
- pixel_values_chunk,
- image_grid_thw_chunk,
- pixel_attention_mask_chunk,
- image_sizes_chunk,
- ) in zipped_inputs:
- with torch.amp.autocast(
- device_type = "cuda", dtype = self._autocast_dtype
- ):
- if pixel_values is None:
- logits_chunk = unwrapped_model(
- input_ids = input_ids_chunk,
- attention_mask = attention_mask_chunk,
- pixel_values = pixel_values_chunk,
- image_grid_thw = image_grid_thw_chunk,
- pixel_attention_mask = pixel_attention_mask_chunk,
- image_sizes = image_sizes_chunk,
- ).logits
-
- completion_input_ids_chunk = input_ids_chunk[
- :, -(logits_to_keep + max_left_pad) :
- ]
- logits_chunk = logits_chunk[
- :, -(logits_to_keep + max_left_pad + 1) :, :
- ]
- logits_chunk = logits_chunk[:, :-1, :]
- else:
- # Essentially, for VLMs we do not go via the optimized path in models/,
- # so we don't encounter the Flash Attn left-padding issue.
- logits_chunk = unwrapped_model(
- input_ids = input_ids_chunk,
- attention_mask = attention_mask_chunk,
- pixel_values = pixel_values_chunk,
- image_grid_thw = image_grid_thw_chunk,
- pixel_attention_mask = pixel_attention_mask_chunk,
- image_sizes = image_sizes_chunk,
- logits_to_keep = logits_to_keep + 1,
- ).logits
-
- logits_chunk = logits_chunk[:, :-1, :]
- completion_input_ids_chunk = input_ids_chunk[
- :, -logits_to_keep:
- ]
-
- logprobs_chunk = chunked_hidden_states_selective_log_softmax(
- logits_chunk,
- lm_head,
- completion_input_ids_chunk,
- chunks = input_ids_chunk.shape[0] * multiplier,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- logit_softcapping = logit_softcapping,
- temperature = temperature,
- )
- # This is needed to avoid race conditions with GPT OSS offload_embbed=True
- # However, it seems that this line does not slow down or disrupt models.
- device_synchronize()
- all_logprobs_list.append(logprobs_chunk)
- logprobs = torch.cat(all_logprobs_list, dim = 0)
- entropies = None
-
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
-
- return logprobs.detach(), entropies # logps, entropies
- # input_ids = input_ids[:, -logits_to_keep:]
- # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
- # See https://github.com/huggingface/trl/issues/2770
- # logits = logits[:, -logits_to_keep:]
- # return logits
- # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
- # logits = logits / self.temperature
- # logps = selective_log_softmax(logits, input_ids)
-
- # row_indices, col_indices = torch.where(logps < -20)
-
- # # Method 1: Check if tensors have elements
- # if len(row_indices) > 0 and len(col_indices) > 0:
- # breakpoint() # Breakpoint triggered here
- # print("Found high values!")
- # return logps # compute logprobs for the input tokens
-
- def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
- extra_prefixes = extra_prefixes or []
- prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
- for prefix in prefixes:
- name = name.replace(prefix, "")
- return name
-
- def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
- """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
- # For FSDP1, we need to recurse into children and also use summon_full_params
- if visited is None:
- visited = set()
- for child_name, child_module in module.named_children():
- child_prefix = f"{prefix}.{child_name}" if prefix else child_name
- self._sync_fsdp1_params_to_vllm(
- child_module, prefix=child_prefix, visited=visited
- ) # recurse into the child
-
- if isinstance(module, FSDP):
- with FSDP.summon_full_params(module, recurse=False, writeback=False):
- for param_name, param in module.named_parameters():
- full_name = f"{prefix}.{param_name}" if prefix else param_name
- full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])
-
- if full_name in visited:
- continue # skip FSDP subtrees already traversed
- visited.add(full_name)
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(full_name, param.data)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
- # For FSDP2, module already covers all parameters, so no need for recursion
- for name, param in module.items():
- if param.is_cpu:
- param = param.to(torch.device("cuda"))
- param = param.full_tensor()
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(name, param)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- def _move_model_to_vllm(self, *args, **kwargs):
- return None
-
- @profiling_decorator
- def _prepare_inputs(
- self, generation_batch: dict[str, Union[torch.Tensor, Any]]
- ) -> dict[str, Union[torch.Tensor, Any]]:
- # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
- # During training:
- # - Receives the local generation batch (Per-GPU batch size × steps per generation)
- # from the modified training dataloader instead of the standard local batch
- # - Generates completions once for the entire generation batch and splits it into batches of size
- # `per_device_train_batch_size`
- # - Buffers these completions and returns the appropriate slice for the current accumulation step
- # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations)
- # During evaluation:
- # - The input is treated as a standard local batch (no accumulation, no multiple iterations)
- # - Completions are generated for each batch without buffering or reuse
- # Returns a single local batch in both cases.
-
- mode = "train" if self.model.training else "eval"
- if mode == "train":
- generate_every = self.args.steps_per_generation * self.num_iterations
- if self._step % generate_every == 0 or self._buffered_inputs is None:
- # self._buffered_inputs=None can occur when resuming from a checkpoint
- generation_batch = self._generate_and_score_completions(generation_batch)
- generation_batch = split_pixel_values_by_grid(generation_batch)
-
- try: generation_batch = shuffle_sequence_dict(generation_batch)
-
- except: pass
- generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
- self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
- inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
- self._step += 1
- else:
- # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
- # local generation batch == local eval batch
- inputs = self._generate_and_score_completions(generation_batch)
- return inputs
-
- @profiling_decorator
- def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
- device = self.accelerator.device
- rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
-
- # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
- keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
- reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
-
- # This allows for dynamic reward shaping based on training progress.
- reward_kwargs["trainer_state"] = self.state
-
- for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
- zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
- ):
- with profiling_context(self, reward_func_name):
- if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models
- if is_conversational(inputs[0]):
- messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
- texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
- else:
- texts = [p + c for p, c in zip(prompts, completions)]
- reward_inputs = reward_processing_class(
- text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
- )
- reward_inputs = super()._prepare_inputs(reward_inputs)
- with torch.inference_mode():
- rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
- else:
- output_reward_func = reward_func(
- prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
- )
- # Convert None values to NaN
- output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
-
- rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
-
- # If all reward functions return None for a given row, issue a detailed warning
- if torch.isnan(rewards_per_func).all(dim=1).any():
- nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
- row_reward_kwargs = {
- key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state"
- }
- row_reward_kwargs["prompt"] = prompts[nan_row_idx]
- row_reward_kwargs["completion"] = completions[nan_row_idx]
- logger.warning(
- f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n"
- "Please ensure that at least one reward function returns a valid reward."
- )
-
- # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
- # completions may be distributed across processes
- rewards_per_func = gather(rewards_per_func)
- return rewards_per_func
-
- def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
- device = self.accelerator.device
-
- # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
- # [{"role": "user", "content": "What color is the sky?"}] to
- # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
- kwargs = {}
- if images is not None:
- kwargs = {"images": images}
- for prompt, image_list in zip(prompts, images):
- if isinstance(prompt, list): # i.e., when using conversational data
- prepare_multimodal_messages(prompt, num_images=len(image_list))
-
-
- _chat_template_ = getattr(self.processing_class, "chat_template", None)
- if _chat_template_ is None: _chat_template_ = ""
- _supported_keys_ = set(("prompt", "chosen", "rejected", "completion", "messages", "label"))
- _batch_chat_kwargs_ = getattr(self, "_unsloth_batch_chat_kwargs", None)
-
- prompts_text = []
- for _idx_, _example_ in enumerate(prompts):
- _tokenizer_kwargs_ = {}
- if type(_example_) is not dict:
- _example_ = {"prompt": _example_}
- _left_keys_ = _example_.keys() - _supported_keys_
- for k in _left_keys_:
- if k in _chat_template_:
- v = _example_[k]
- if type(v) is str:
- _tokenizer_kwargs_[k] = v
- if _batch_chat_kwargs_ is not None and _idx_ < len(_batch_chat_kwargs_):
- for _bk_, _bv_ in _batch_chat_kwargs_[_idx_].items():
- if _bk_ not in _tokenizer_kwargs_:
- _tokenizer_kwargs_[_bk_] = _bv_
- _x_ = maybe_apply_chat_template(_example_, self.processing_class, **_tokenizer_kwargs_)["prompt"]
- prompts_text.append(_x_)
- if images is not None:
- prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
- prompt_inputs = super()._prepare_inputs(prompt_inputs)
- forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
- else:
- forward_kwargs = {}
-
- # Generate completions using either vLLM or regular generation
- if self.use_vllm:
- if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
- # wake up colocated vLLM instances if needed
- torch.cuda.empty_cache() # required to avoid OOM in some cases
- self.llm.wake_up()
-
- # First, update the vLLM weights if needed
- if self.state.global_step != self._last_loaded_step:
- self._move_model_to_vllm()
- self._last_loaded_step = self.state.global_step
-
- # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
- if self.vllm_mode == "server":
- all_prompts_text = gather_object(prompts_text)
- if images is not None:
- all_images = gather_object(images)
-
- if self.accelerator.is_main_process:
- # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
- # num_generations outputs for each one. This is faster than generating outputs for each duplicate
- # prompt individually.
- ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
-
- if images is not None:
- ordered_set_of_images = all_images[:: self.num_generations]
- else:
- ordered_set_of_images = None
-
- with profiling_context(self, "vLLM.generate"):
- output = self.vllm_client.generate(
- prompts=ordered_set_of_prompts,
- images=ordered_set_of_images,
- n=self.num_generations,
- repetition_penalty=self.repetition_penalty,
- temperature=self.temperature,
- top_p=self.top_p,
- top_k=-1 if self.top_k is None else self.top_k,
- min_p=0.0 if self.min_p is None else self.min_p,
- max_tokens=self.max_completion_length,
- truncate_prompt_tokens=self.max_prompt_length,
- guided_decoding_regex=self.guided_decoding_regex,
- generation_kwargs=self.args.generation_kwargs,
- )
- payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
- else:
- payload = None
-
- # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
- obj_list = [payload]
- broadcast_object_list(obj_list, from_process=0)
- all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
-
- # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
- all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
-
- process_slice = slice(
- self.accelerator.process_index * len(prompts),
- (self.accelerator.process_index + 1) * len(prompts),
- )
- prompt_ids = all_prompt_ids[process_slice]
- completion_ids = all_completion_ids[process_slice]
- logprobs = all_logprobs[process_slice]
-
- # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
- elif self.vllm_mode == "colocate":
- if self.guided_decoding_regex:
- guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
- else:
- guided_decoding = None
-
- generation_kwargs = {
- "n": 1, # vLLM on each GPU generates only 1 in colocate mode
- "repetition_penalty": self.repetition_penalty,
- "temperature": self.temperature,
- "top_p": self.top_p,
- "top_k": -1 if self.top_k is None else self.top_k,
- "min_p": 0.0 if self.min_p is None else self.min_p,
- "max_tokens": self.max_completion_length,
- "truncate_prompt_tokens": self.max_prompt_length,
- "guided_decoding": guided_decoding,
- "logprobs": 0, # only return the logprob of the generated token
- }
- if self.args.generation_kwargs is not None:
- generation_kwargs.update(self.args.generation_kwargs)
- sampling_params = SamplingParams(**grpo_update_SamplingParams(SamplingParams, generation_kwargs, getattr(self.args, 'vllm_sampling_params', None)))
-
- if self.vllm_tensor_parallel_size > 1:
- # Gather prompts from all ranks in the TP group and flatten.
- # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
- orig_size = len(prompts_text)
- gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
- torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
- all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
-
- if images is not None:
- gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
- torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
- all_images = [img for sublist in gathered_images for img in sublist]
- else:
- all_images = None
- else:
- all_prompts_text = prompts_text
- all_images = images
-
- if images is not None and all_images:
- vllm_inputs = []
- for prompt, image_list in zip(all_prompts_text, all_images):
- vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
-
- else:
- vllm_inputs = all_prompts_text
-
- with profiling_context(self, "vLLM.generate"):
- all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
-
- all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
- all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
- all_logprobs = [
- [next(iter(lp.values())).logprob for lp in output.logprobs]
- for outputs in all_outputs
- for output in outputs.outputs
- ]
-
- if self.vllm_tensor_parallel_size > 1:
- # Slice completions for this rank within its TP group.
- # Each rank generates all outputs — we keep only our share.
- local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
- tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
- prompt_ids = all_prompt_ids[tp_slice]
- completion_ids = all_completion_ids[tp_slice]
- logprobs = all_logprobs[tp_slice]
- else:
- prompt_ids = all_prompt_ids
- completion_ids = all_completion_ids
- logprobs = all_logprobs
-
- if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
-
- elif self.use_transformers_paged:
- # Re-process inputs for paged generation if needed
- # Note: images are already validated and preprocessed above
- paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
- previous_attn = self.model_wrapped.config._attn_implementation
-
- if is_flash_attn_2_available():
- self.model_wrapped.config._attn_implementation = "paged_attention"
- else:
- self.model_wrapped.config._attn_implementation = "sdpa_paged"
- with (
- profiling_context(self, "transformers.generate_batch"),
- unwrap_model_for_generation(
- self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model,
- torch.no_grad(),
- FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
- ):
- # Cast to the appropriate dtype based on training configuration
- if self.args.bf16:
- unwrapped_model.to(torch.bfloat16)
- elif self.args.fp16:
- unwrapped_model.to(torch.float16)
- with torch.inference_mode():
- all_outputs = unwrapped_model.generate_batch(
- paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
- )
- unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
- completion_ids = [output.generated_tokens for output in all_outputs.values()]
- prompt_ids = paged_prompt_inputs.input_ids
- # Restore the original attention implementation, training mode
- self.model_wrapped.config._attn_implementation = previous_attn
- logprobs = None # not used in this case
-
- else:
- # Regular generation path
- generate_inputs = self.processing_class(
- text=prompts_text,
- return_tensors="pt",
- padding=True,
- padding_side="left",
- **kwargs,
- )
- generate_inputs = super()._prepare_inputs(generate_inputs)
-
- with (
- profiling_context(self, "transformers.generate"),
- unwrap_model_for_generation(
- self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model,
- torch.no_grad(),
- FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
- ):
- prompt_completion_ids = unwrapped_model.generate(
- **generate_inputs, generation_config=self.generation_config, disable_compile=True
- )
- # Compute prompt length and extract completion ids
- prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
- prompt_length = prompt_ids.size(1)
- completion_ids = prompt_completion_ids[:, prompt_length:]
-
- # Mask everything after the first EOS token
- is_eos = completion_ids == self.eos_token_id
- eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
- eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
- sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
- completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
- prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
- completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
- logprobs = None # not used in this case
-
- return prompt_ids, completion_ids, logprobs, forward_kwargs
-
- def _generate(self, prompts: list[str], images: Optional[list]):
- device = self.accelerator.device
- mode = "train" if self.model.training else "eval"
-
- prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
-
- # Get completion length per sequence, used for logging
- prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
- completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device)
- agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
- agg_completion_lengths = self.accelerator.gather(completion_lengths)
- total_prompt_tokens = agg_prompt_lengths.sum()
- total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss
-
- # Log the metrics
- if mode == "train":
- self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item()
- self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
-
- # Log completion lengths, mean, min, max
- self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
- self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
- self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
-
- # Identify sequences that terminated with EOS and log their lengths
- eos_and_pad = [self.eos_token_id, self.pad_token_id]
- is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
- agg_is_truncated = self.accelerator.gather(is_truncated)
- self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
- term_completion_lengths = agg_completion_lengths[~agg_is_truncated]
- if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found
- term_completion_lengths = torch.zeros(1, device=device)
- self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
- self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
- self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
-
- return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
-
- def _generate_and_score_completions(
- self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
- ) -> dict[str, Union[torch.Tensor, Any]]:
- device = self.accelerator.device
- mode = "train" if self.model.training else "eval"
-
- prompts = [x["prompt"] for x in inputs]
- # Unsloth: Extract per-sample chat_template_kwargs before metadata is lost
- _ct_ = getattr(self.processing_class, 'chat_template', None) or ''
- _sk_ = {'prompt', 'chosen', 'rejected', 'completion', 'messages', 'label',
- 'images', 'image', 'videos', 'video', 'audios', 'audio'}
- self._unsloth_batch_chat_kwargs = []
- for _inp_ in inputs:
- _kw_ = {}
- if isinstance(_inp_, dict):
- for _k_ in _inp_.keys() - _sk_:
- if _k_ in _ct_ and isinstance(_inp_[_k_], str):
- _kw_[_k_] = _inp_[_k_]
- self._unsloth_batch_chat_kwargs.append(_kw_)
- if "images" in inputs[0]:
- images = [example.get("images") for example in inputs]
- elif "image" in inputs[0]:
- images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
- else:
- images = None
- # Transformers requires at least one image in the batch, otherwise it throws an error
- if images is not None and all(img_list == [] for img_list in images):
- images = None
-
- (
- prompt_ids_list,
- completion_ids_list,
- num_items_in_batch,
- sampling_per_token_logps_list,
- forward_kwargs,
- ) = self._generate(prompts, images)
-
- # Convert lists of token IDs to padded tensors
- prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
- prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
- prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
- prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
- completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
- completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
- completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
- completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
- if sampling_per_token_logps_list is not None:
- sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
- sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
- else:
- sampling_per_token_logps = None
-
- # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
- if self.mask_truncated_completions:
- eos_and_pad = [self.eos_token_id, self.pad_token_id]
- is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
- completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
-
- # Concatenate prompt_mask with completion_mask for logit computation
- prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
- attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
- # If token_type_ids are used, extend them with zeros for the completion part
- if "token_type_ids" in forward_kwargs:
- token_type_ids = forward_kwargs["token_type_ids"]
- forward_kwargs["token_type_ids"] = torch.cat(
- [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
- )
-
- logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
-
- max_left_pad = None
- batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
- try:
- # TRL 0.23.1 and below path
- if not has_images:
- # Left pad prompt before calculation old and ref hidden states
- left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
- max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
- except:
- # TRL 0.24.0 and below path
- if images is None:
- # Left pad prompt before calculation old and ref hidden states
- left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
- max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
- self.model.for_training()
-
- num_images = [len(img_list) for img_list in images] if images is not None else None
-
- with torch.no_grad():
- # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
- # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
- # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
- # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
- # old_per_token_logps to None.
- # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
- # distribution mismatch between vLLM and the training model can be large and harm the training.
- generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
-
- if self.args.gradient_accumulation_steps % generate_every != 0 or (
- self.use_vllm
- ):
- old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
- self.model,
- prompt_completion_ids,
- attention_mask,
- logits_to_keep,
- batch_size,
- num_images=num_images,
- **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
- )
- else:
- old_per_token_logps = None
-
- # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
- if False and self.use_vllm and self.vllm_importance_sampling_correction:
- importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps)
- importance_sampling_ratio = torch.clamp(
- importance_sampling_ratio, max=self.vllm_importance_sampling_cap
- )
-
- # Compute the per-token log probabilities for the reference model
- if self.beta != 0.0:
- if self.ref_model is not None:
- ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
- self.ref_model,
- prompt_completion_ids,
- attention_mask,
- logits_to_keep,
- batch_size=batch_size,
- num_images=num_images,
- **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
- )
- else:
- with self.accelerator.unwrap_model(self.model).disable_adapter():
- ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
- self.model,
- prompt_completion_ids,
- attention_mask,
- logits_to_keep,
- batch_size=batch_size,
- num_images=num_images,
- **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
- )
- else:
- ref_per_token_logps = None
-
- # Decode
- prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
- completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
- if is_conversational(inputs[0]):
- completions = []
- for prompt, completion in zip(prompts, completions_text):
- bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
- completions.append([{"role": "assistant", "content": bootstrap + completion}])
- else:
- completions = completions_text
-
- # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
- # important because rewards will be normalized per group, and completions are distributed. We will later slice
- # rewards_per_func to extract each process's subset.
- if images is not None:
- rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)
- else:
- rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
-
- # Apply weights to each reward function's output and sum
- rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
-
- # Compute grouped-wise rewards
- mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
-
- # Normalize the rewards to compute the advantages
- mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
- advantages = rewards - mean_grouped_rewards
-
- if self.scale_rewards in ["group", "none"]:
- # If self.scale_rewards = "none", we'll still log group level std
- std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
- std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
- elif self.scale_rewards == "batch":
- # Compute global std
- std_rewards = rewards.std().expand_as(rewards)
- else:
- raise ValueError(
- f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
- )
-
- is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
- if self.scale_rewards != "none":
- advantages = advantages / (std_rewards + 1e-4)
-
- # Slice to keep only the local part of the data
- process_slice = slice(
- self.accelerator.process_index * len(prompts),
- (self.accelerator.process_index + 1) * len(prompts),
- )
- all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
- advantages = advantages[process_slice]
-
- # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
- for i, reward_func_name in enumerate(self.reward_func_names):
- mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
- self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
- std_func_rewards = nanstd(rewards_per_func[:, i]).item()
- self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
- self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
- self._metrics[mode]["reward_std"].append(std_rewards.mean().item())
- self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
-
- # Log prompt and completion texts
- self._logs["prompt"].extend(gather_object(prompts_text))
- self._logs["completion"].extend(gather_object(completions_text))
- for i, name in enumerate(self.reward_func_names):
- self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
- self._logs["advantages"].extend(all_process_advantages.tolist())
-
- if images is not None:
- self._logs["images"].extend(gather_object(images))
-
- if False and self.use_vllm and self.vllm_importance_sampling_correction:
- delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
- delta = delta[completion_mask.bool()]
- mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
- max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
- self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
- self.accelerator.gather(mean_delta).mean().item()
- )
- self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
- self.accelerator.gather(max_delta).max().item()
- )
-
- flat_is_ratio = importance_sampling_ratio[completion_mask.bool()]
- min_importance_sampling_ratio = (
- torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
- )
- mean_importance_sampling_ratio = (
- torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
- )
- max_importance_sampling_ratio = (
- torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
- nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
- self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
- nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
- )
-
- output = {
- "prompt_ids": prompt_ids,
- "prompt_mask": prompt_mask,
- "completion_ids": completion_ids,
- "completion_mask": completion_mask,
- "advantages": advantages,
- "num_items_in_batch": num_items_in_batch,
- }
- if old_per_token_logps is not None:
- output["old_per_token_logps"] = old_per_token_logps
- if False and self.use_vllm and self.vllm_importance_sampling_correction:
- output["importance_sampling_ratio"] = importance_sampling_ratio
- if ref_per_token_logps is not None:
- output["ref_per_token_logps"] = ref_per_token_logps
- if "pixel_values" in forward_kwargs:
- output["pixel_values"] = forward_kwargs["pixel_values"]
- if "image_grid_thw" in forward_kwargs:
- output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
- if "pixel_attention_mask" in forward_kwargs:
- output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
- if "image_sizes" in forward_kwargs:
- output["image_sizes"] = forward_kwargs["image_sizes"]
- if "token_type_ids" in forward_kwargs:
- output["token_type_ids"] = forward_kwargs["token_type_ids"]
- if images is not None:
- output["num_images"] = num_images
- if max_left_pad is not None:
- output["max_left_pad"] = torch.tensor(prompt_ids.shape[0] * [max_left_pad]).unsqueeze(-1)
- try:
- if self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False):
- output["sampling_per_token_logps"] = sampling_per_token_logps
- except NameError:
- output["sampling_per_token_logps"] = None
- return output
-
- def compute_liger_loss(self, unwrapped_model, inputs):
- # Compute the per-token log probabilities for the model
- prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
- completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
- input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
- attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
- logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
-
- # Get the last hidden state of the model
- last_hidden_state = self._get_last_hidden_state(
- unwrapped_model,
- input_ids,
- attention_mask,
- logits_to_keep,
- inputs.get("pixel_values"),
- inputs.get("image_grid_thw"),
- inputs.get("pixel_attention_mask"),
- inputs.get("image_sizes"),
- )
-
- # compute loss and metrics using liger grpo loss
- loss, metrics = self.liger_grpo_loss(
- _input=last_hidden_state,
- lin_weight=unwrapped_model.lm_head.weight,
- selected_token_ids=completion_ids,
- attention_mask=completion_mask,
- advantages=inputs["advantages"],
- bias=unwrapped_model.lm_head.bias,
- old_per_token_logps=inputs.get("old_per_token_logps"),
- ref_per_token_logps=inputs.get("ref_per_token_logps"),
- )
- # Extract metrics from the liger_grpo_loss output
- # KL divergence is the first metric when beta is non-zero
- mean_kl = metrics[0] if self.beta != 0.0 else None
- clip_ratio = metrics[-1]
-
- mode = "train" if self.model.training else "eval"
- if self.beta != 0.0:
- self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
- self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item())
- return loss / self.current_gradient_accumulation_steps
-
- def compute_loss(
- self, model, inputs, return_outputs = False, num_items_in_batch = None
- ):
- if return_outputs:
- raise ValueError("The GRPOTrainer does not support returning outputs")
- # Compute the per-token log probabilities for the model
-
- prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
- completion_ids, completion_mask = (
- inputs["completion_ids"],
- inputs["completion_mask"],
- )
- pixel_values, image_grid_thw = (
- inputs.get("pixel_values", None),
- inputs.get("image_grid_thw", None),
- )
- pixel_attention_mask, image_sizes = (
- inputs.get("pixel_attention_mask", None),
- inputs.get("image_sizes", None),
- )
- num_items_in_batch = inputs.get("num_items_in_batch", None)
- sampling_per_token_logps = inputs.get("sampling_per_token_logps", None)
- current_gradient_accumulation_steps = self.current_gradient_accumulation_steps
- num_processes = self.accelerator.num_processes
-
- input_ids = torch.cat([prompt_ids, completion_ids], dim = 1)
- bsz, qlen = input_ids.shape
- attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1)
- # attention_mask = None
- logits_to_keep = completion_ids.size(
- 1
- ) # we only need to compute the logits for the completion tokens
- _input_ids = input_ids
- _logits_to_keep = logits_to_keep
-
- get_logps_func = (
- lambda model,
- input_ids,
- attention_mask,
- logits_to_keep,
- batch_size = None,
- compute_entropy = False,
- compute_efficient = False: self._get_per_token_logps(
- model, input_ids, attention_mask, logits_to_keep, compute_efficient
- )
- if hasattr(self, "_get_per_token_logps")
- else self._get_per_token_logps_and_entropies(
- model,
- input_ids,
- attention_mask,
- logits_to_keep,
- batch_size,
- compute_entropy,
- compute_efficient,
- )[0]
- ) # logps
-
- per_token_logps = get_logps_func(
- model, input_ids, attention_mask, logits_to_keep, compute_efficient = True
- )
- # Compute the KL divergence between the model and the reference model
- # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves.
- # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328
- # if self.beta != 0.0:
- # with torch.inference_mode(), model.disable_adapter():
- # ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep)
- # else:
- # ref_per_token_logps = None
- ref_logps = inputs.get("ref_per_token_logps", None)
- # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
- # x - x.detach() allows for preserving gradients from x
- advantages = inputs["advantages"]
- # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
- # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
- # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
- old_logps = inputs.get("old_per_token_logps", None)
-
- input_ids = input_ids[:, -logits_to_keep:]
-
- # Get logit softcapping and logit scale
- logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma
- if logit_softcapping is None:
- logit_softcapping = 0
- logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere
- if logit_scale_multiply is None:
- logit_scale_multiply = 0
- logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite
- if logit_scale_divide is None:
- logit_scale_divide = 0
-
- max_left_pad = inputs.get("max_left_pad", 0)
- if per_token_logps is not None:
- loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
- grpo_compute_loss_slow(
- ref_logps,
- per_token_logps,
- old_logps,
- input_ids,
- completion_mask,
- self.beta,
- advantages,
- pixel_values = pixel_values,
- image_grid_thw = image_grid_thw,
- loss_type = self.args.loss_type,
- importance_sampling_level = self.importance_sampling_level,
- epsilon_low = self.epsilon_low,
- epsilon_high = self.epsilon_high,
- max_completion_length = self.args.max_completion_length,
- delta = self.args.delta,
- temperature = self.args.temperature,
- max_left_pad = max_left_pad,
- logit_softcapping = logit_softcapping,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- num_items_in_batch = num_items_in_batch,
- current_gradient_accumulation_steps = current_gradient_accumulation_steps,
- num_processes = num_processes,
- sampling_per_token_logps = sampling_per_token_logps,
- )
- )
- else:
- if hasattr(self.args, "loss_type"):
- loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
- grpo_accumulated_loss(
- trainer = self,
- input_ids = _input_ids,
- pixel_values = pixel_values,
- image_grid_thw = image_grid_thw,
- logits_to_keep = logits_to_keep,
- completion_mask = completion_mask,
- advantages = advantages,
- old_logps = old_logps,
- ref_logps = ref_logps,
- n_chunks = self.args.unsloth_num_chunks,
- loss_type = self.args.loss_type,
- importance_sampling_level = self.importance_sampling_level,
- epsilon_low = self.epsilon_low,
- epsilon_high = self.epsilon_high,
- max_completion_length = self.args.max_completion_length,
- delta = self.args.delta,
- temperature = self.args.temperature,
- max_left_pad = max_left_pad,
- logit_softcapping = logit_softcapping,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- attention_mask = attention_mask,
- num_items_in_batch = num_items_in_batch,
- current_gradient_accumulation_steps = current_gradient_accumulation_steps,
- num_processes = num_processes,
- sampling_per_token_logps = sampling_per_token_logps,
- )
- )
- else:
- # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
- loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss(
- trainer = self,
- input_ids = _input_ids,
- logits_to_keep = logits_to_keep,
- completion_mask = completion_mask,
- advantages = advantages,
- old_logps = old_logps,
- ref_logps = ref_logps,
- n_chunks = self.args.unsloth_num_chunks,
- temperature = self.args.temperature,
- logit_softcapping = logit_softcapping,
- logit_scale_multiply = logit_scale_multiply,
- logit_scale_divide = logit_scale_divide,
- attention_mask = attention_mask,
- )
- if "train" in self._metrics:
- mode = "eval" if self.control.should_evaluate else "train"
- self._metrics[mode]["completion_length"].append(completion_length.item())
- self._metrics[mode]["kl"].append(mean_kl.item())
- else:
- self._metrics["completion_length"].append(completion_length.item())
- self._metrics["kl"].append(mean_kl.item())
-
- if (
- self.use_vllm
- and delta is not None
- and getattr(self, "vllm_importance_sampling_correction", False)
- ):
- mean_delta = (
- torch.mean(delta)
- if delta.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- max_delta = (
- torch.max(delta)
- if delta.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
- self.accelerator.gather(mean_delta).mean().item()
- )
- self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
- self.accelerator.gather(max_delta).max().item()
- )
-
- min_importance_sampling_ratio = (
- torch.min(flat_is_ratio)
- if flat_is_ratio.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- mean_importance_sampling_ratio = (
- torch.mean(flat_is_ratio)
- if flat_is_ratio.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- max_importance_sampling_ratio = (
- torch.max(flat_is_ratio)
- if flat_is_ratio.numel() > 0
- else torch.tensor(0.0, device = self.model.device)
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
- self.accelerator.gather(min_importance_sampling_ratio)
- .nan_to_num(nan = float("inf"))
- .min()
- .item()
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
- self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
- )
- self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
- self.accelerator.gather(max_importance_sampling_ratio)
- .nan_to_num(nan = float("-inf"))
- .max()
- .item()
- )
-
- completion_token_count = completion_mask.sum().clamp(min = 1.0)
-
- def masked_batch_mean(x):
- if x.shape[1] == 1: # when importance_sampling_level == "sequence"
- return x.mean()
- else:
- return (x * completion_mask).sum() / completion_token_count
-
- if advantages.dim() == 1:
- advantages = advantages.unsqueeze(1)
-
- if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
- # Compute the clipped probability ratios
- is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
- is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
- is_region_clipped = is_low_clipped | is_high_clipped
-
- low_clip = masked_batch_mean(is_low_clipped.float())
- high_clip = masked_batch_mean(is_high_clipped.float())
- clip_ratio = masked_batch_mean(is_region_clipped.float())
-
- gathered_low_clip = self.accelerator.gather(low_clip)
- self._metrics[mode]["clip_ratio/low_mean"].append(
- gathered_low_clip.nanmean().item()
- )
- self._metrics[mode]["clip_ratio/low_min"].append(
- nanmin(gathered_low_clip).item()
- )
- gathered_high_clip = self.accelerator.gather(high_clip)
- self._metrics[mode]["clip_ratio/high_mean"].append(
- gathered_high_clip.nanmean().item()
- )
- self._metrics[mode]["clip_ratio/high_max"].append(
- nanmax(gathered_high_clip).item()
- )
- gathered_clip_ratio = self.accelerator.gather(clip_ratio)
- self._metrics[mode]["clip_ratio/region_mean"].append(
- gathered_clip_ratio.nanmean().item()
- )
- elif self.loss_type == "cispo":
- is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
- cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
- gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
- self._metrics[mode]["cispo_clip_ratio"].append(
- gathered_cispo_clip_ratio.nanmean().item()
- )
-
- return loss
-
- def _compute_loss(self, model, inputs):
- # Compute the per-token log probabilities for the model
- prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
- completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
- input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
- attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
- logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
-
- # Compute the per_token_logps and the entropy at each position in the completion
- per_token_logps, entropies = self._get_per_token_logps_and_entropies(
- model,
- input_ids,
- attention_mask,
- logits_to_keep,
- compute_entropy=True,
- pixel_values=inputs.get("pixel_values"),
- image_grid_thw=inputs.get("image_grid_thw"),
- num_images=inputs.get("num_images"),
- pixel_attention_mask=inputs.get("pixel_attention_mask"),
- image_sizes=inputs.get("image_sizes"),
- token_type_ids=inputs.get("token_type_ids"),
- )
-
- if self.top_entropy_quantile < 1.0:
- entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
- else:
- entropy_mask = None
-
- # Compute the KL divergence between the model and the reference model
- if self.beta != 0.0:
- ref_per_token_logps = inputs["ref_per_token_logps"]
- per_token_kl = (
- torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
- )
-
- # Compute the loss
- advantages = inputs["advantages"]
- # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
- # old_per_token_logps == per_token_logps. In this case we can skip its computation
- # (see _generate_and_score_completions) and instead use per_token_logps.detach().
- # The exception is when using vLLM, where we always compute old_per_token_logps
- # for importance sampling
- old_per_token_logps = inputs.get("old_per_token_logps")
- old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps
-
- log_ratio = per_token_logps - old_per_token_logps
- if self.importance_sampling_level == "token":
- log_importance_weights = log_ratio
- elif self.importance_sampling_level == "sequence":
- log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
- log_importance_weights = log_importance_weights.unsqueeze(-1)
- else:
- raise ValueError(
- f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
- "and 'sequence'."
- )
- # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
- # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
-
- coef_1 = torch.exp(log_importance_weights)
- coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
-
- # Two-sided clipping
- if self.args.delta is not None:
- coef_1 = torch.clamp(coef_1, max=self.args.delta)
-
- per_token_loss1 = coef_1 * advantages.unsqueeze(1)
- per_token_loss2 = coef_2 * advantages.unsqueeze(1)
- per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
- if entropy_mask is not None:
- per_token_loss = per_token_loss * entropy_mask
-
- if self.use_vllm and self.vllm_importance_sampling_correction:
- per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
-
- if self.beta != 0.0:
- per_token_loss = per_token_loss + self.beta * per_token_kl
-
- if self.loss_type == "grpo":
- loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
- loss = loss / self.current_gradient_accumulation_steps
- elif self.loss_type == "bnpo":
- loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
- loss = loss / self.current_gradient_accumulation_steps
- elif self.loss_type == "dr_grpo":
- loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
- loss = loss / self.current_gradient_accumulation_steps
- elif self.loss_type == "dapo":
- normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
- loss = (per_token_loss * completion_mask).sum() / normalizer
- else:
- raise ValueError(f"Unknown loss type: {self.loss_type}")
-
- # Log the metrics
- mode = "train" if self.model.training else "eval"
-
- completion_token_count = completion_mask.sum().clamp(min=1.0)
-
- def masked_batch_mean(x):
- if x.shape[1] == 1: # when importance_sampling_level == "sequence"
- return x.mean()
- else:
- return (x * completion_mask).sum() / completion_token_count
-
- if self.beta != 0.0:
- mean_kl = masked_batch_mean(per_token_kl)
- self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
-
- mean_entropy = masked_batch_mean(entropies)
- self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
-
- # Compute the clipped probability ratios
- is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
- is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
- is_region_clipped = is_low_clipped | is_high_clipped
-
- low_clip = masked_batch_mean(is_low_clipped.float())
- high_clip = masked_batch_mean(is_high_clipped.float())
- clip_ratio = masked_batch_mean(is_region_clipped.float())
-
- gathered_low_clip = self.accelerator.gather(low_clip)
- self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
- self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
- gathered_high_clip = self.accelerator.gather(high_clip)
- self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
- self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
- gathered_clip_ratio = self.accelerator.gather(clip_ratio)
- self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
- return loss
-
- def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
- inputs = self._prepare_inputs(inputs)
- with torch.no_grad():
- with self.compute_loss_context_manager():
- loss = self.compute_loss(model, inputs)
- loss = loss.mean().detach()
- return loss, None, None
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- mode = "train" if self.model.training else "eval"
- metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
-
- # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
- # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
- if mode == "eval":
- metrics = {f"eval_{key}": val for key, val in metrics.items()}
-
- logs = {**logs, **metrics}
- super().log(logs, start_time)
- self._metrics[mode].clear()
-
- if self.accelerator.is_main_process and self.log_completions:
- if is_rich_available():
- print_prompt_completions_sample(
- self._logs["prompt"],
- self._logs["completion"],
- self._logs["rewards"],
- self._logs["advantages"],
- self.state.global_step,
- self.num_completions_to_print,
- )
-
- if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
- import pandas as pd
-
- table = {
- "step": [str(self.state.global_step)] * len(self._logs["prompt"]),
- "prompt": self._logs["prompt"],
- "completion": self._logs["completion"],
- **self._logs["rewards"],
- "advantage": self._logs["advantages"],
- }
-
- if self._logs["images"]:
- table["images"] = []
- for image_list in self._logs["images"]:
- # Convert images to wandb Image objects for proper visualization
- table["images"].append([wandb.Image(image) for image in image_list])
-
- df = pd.DataFrame(table)
- if self.wandb_log_unique_prompts:
- df = df.drop_duplicates(subset=["prompt"])
- wandb.log({"completions": wandb.Table(dataframe=df)})
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothGRPOTrainer(_UnslothGRPOTrainer):
- """
-
- Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
- paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language
- Models](https://huggingface.co/papers/2402.03300).
-
- Example:
-
- ```python
- from datasets import load_dataset
- from trl import GRPOTrainer
-
- dataset = load_dataset("trl-lib/tldr", split="train")
- def reward_func(completions, **kwargs):
- # Dummy reward function that rewards completions with more unique letters.
- return [float(len(set(completion))) for completion in completions]
- trainer = GRPOTrainer(
- model="Qwen/Qwen2-0.5B-Instruct",
- reward_funcs=reward_func,
- train_dataset=dataset,
- )
-
- trainer.train()
- ```
-
- Args:
- model (`Union[str, PreTrainedModel]`):
- Model to be trained. Can be either:
-
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
- `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
- Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
- functions with the prompts and completions and sum the rewards. Can be either:
-
- - A single reward function, such as:
- - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
- keyword arguments in `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
- - A custom reward function: The function is provided with the prompts and the generated completions,
- plus any additional columns in the dataset. It should return a list of rewards. Custom reward
- functions can also return `None` when the reward is not applicable to those samples. This is useful
- for multi-task training where different reward functions apply to different types of samples. When a
- reward function returns `None` for a sample, that reward function is excluded from the reward
- calculation for that sample. For more details, see [Using a custom reward
- function](#using-a-custom-reward-function).
-
- The trainer's state is also passed to the reward function. The trainer's state is an instance of
- [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the
- reward function's signature.
- - A list of reward functions, where each item can independently be any of the above types. Mixing different
- types within the list (e.g., a string model ID and a custom reward function) is allowed.
- args ([`GRPOConfig`], *optional*):
- Configuration for this trainer. If `None`, a default configuration is used.
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
- Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
- ignored. The format of the samples can be either:
-
- - [Standard](dataset_formats#standard): Each sample contains plain text.
- - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
- and content).
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
- Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. The padding side must be set to "left". If `None`, the
- processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
- padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
- `tokenizer.eos_token` will be used as the default.
- reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
- Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
-
- - A single processing class: Used when `reward_funcs` contains only one reward function.
- - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
- If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
- `None`, the tokenizer for the model is automatically loaded using
- [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
- functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes`
- are ignored.
- callbacks (list of [`~transformers.TrainerCallback`], *optional*):
- List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
- in [here](https://huggingface.co/docs/transformers/main_classes/callback).
-
- If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
- method.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
- model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
-
- """
- def __init__(
- self,
- model,
- reward_funcs,
- args = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- reward_processing_classes = None,
- callbacks = None,
- peft_config = None,
- **kwargs
- ):
- if args is None: args = UnslothGRPOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- other_metrics = []
- if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
- else: _reward_funcs = reward_funcs
- for reward_func in _reward_funcs:
- try:
- reward_func_name = reward_func.__name__
- if True:
- other_metrics.append(f'rewards/{reward_func_name}/mean')
- if True:
- other_metrics.append(f'rewards/{reward_func_name}/std')
- if False:
- other_metrics.append(f'rewards/{reward_func_name}')
- except: pass
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('grpo_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- reward_funcs = reward_funcs,
- args = args,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- reward_processing_classes = reward_processing_classes,
- callbacks = callbacks,
- peft_config = peft_config,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothKTOTrainer.py b/unsloth_compiled_cache/UnslothKTOTrainer.py
deleted file mode 100644
index c4f9194..0000000
--- a/unsloth_compiled_cache/UnslothKTOTrainer.py
+++ /dev/null
@@ -1,2377 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, autocast, concatenate_datasets, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, TrainingArguments, Union, autocast, concatenate_datasets, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch, F, nn, np, os, selective_log_softmax, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothKTOConfig(KTOConfig):
- """
-
- Configuration class for the [`KTOTrainer`].
-
- This class includes only the parameters that are specific to KTO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
- to use the default data collator.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
- max_completion_length (`int`, *optional*):
- Maximum length of the completion. This argument is required if you want to use the default data collator
- and your model is an encoder-decoder.
- beta (`float`, *optional*, defaults to `0.1`):
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
- reference model.
- loss_type (`str`, *optional*, defaults to `"kto"`):
- Type of loss to use. Possible values are:
-
- - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
- - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the
- [APO](https://huggingface.co/papers/2408.06266) paper.
-
- desirable_weight (`float`, *optional*, defaults to `1.0`):
- Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
- undesirable_weight (`float`, *optional*, defaults to `1.0`):
- Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
- Label pad token id. This argument is required if you want to use the default data collator.
- padding_value (`int`, *optional*):
- Padding value to use. If `None`, the padding value of the tokenizer is used.
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
- This argument is required if you want to use the default data collator.
- generate_during_eval (`bool`, *optional*, defaults to `False`):
- If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
- during evaluation.
- is_encoder_decoder (`bool`, *optional*):
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
- you need to specify if the model returned by the callable is an encoder-decoder model.
- precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
- Whether to precompute reference model log probabilities for training and evaluation datasets. This is
- useful when training without the reference model to reduce the total GPU memory needed.
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
- string.
- ref_model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
- from a string.
- dataset_num_proc: (`int`, *optional*):
- Number of processes to use for processing the dataset.
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model and reference model.
- use_liger_loss (`bool`, *optional*, defaults to `False`):
- Whether to use Liger loss. It requires liger-kernel to be installed.
- base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
- Name of the attribute in the model that contains the base model. This is used to get the base model from
- the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- max_length = 1024,
- max_prompt_length = 512,
- max_completion_length = None,
- beta = 0.1,
- loss_type = 'kto',
- desirable_weight = 1.0,
- undesirable_weight = 1.0,
- label_pad_token_id = -100,
- padding_value = None,
- truncation_mode = 'keep_end',
- generate_during_eval = False,
- is_encoder_decoder = None,
- disable_dropout = True,
- precompute_ref_log_probs = False,
- model_init_kwargs = None,
- ref_model_init_kwargs = None,
- dataset_num_proc = None,
- use_liger_loss = False,
- base_model_attribute_name = 'model',
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- max_length = max_length,
- max_prompt_length = max_prompt_length,
- max_completion_length = max_completion_length,
- beta = beta,
- loss_type = loss_type,
- desirable_weight = desirable_weight,
- undesirable_weight = undesirable_weight,
- label_pad_token_id = label_pad_token_id,
- padding_value = padding_value,
- truncation_mode = truncation_mode,
- generate_during_eval = generate_during_eval,
- is_encoder_decoder = is_encoder_decoder,
- disable_dropout = disable_dropout,
- precompute_ref_log_probs = precompute_ref_log_probs,
- model_init_kwargs = model_init_kwargs,
- ref_model_init_kwargs = ref_model_init_kwargs,
- dataset_num_proc = dataset_num_proc,
- use_liger_loss = use_liger_loss,
- base_model_attribute_name = base_model_attribute_name,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothKTOTrainer(BaseTrainer):
- r""""""
-
- _tag_names = ["trl", "kto"]
- _name = "KTO"
- _paper = {
- "title": "KTO: Model Alignment as Prospect Theoretic Optimization",
- "id": "2402.01306",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{ethayarajh2024kto,
- title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
- author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
- year = 2024,
- eprint = {arXiv:2402.01306},
- }"""),
- }
-
- def __init__(
- self,
- model: Union[PreTrainedModel, nn.Module, str] = None,
- ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
- args: KTOConfig = None,
- train_dataset: Optional[Dataset] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- data_collator: Optional[DataCollator] = None,
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional[dict] = None,
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
- model_adapter_name: Optional[str] = None,
- ref_adapter_name: Optional[str] = None,
- ):
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if type(args) is TrainingArguments:
- raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
-
- if not isinstance(model, str) and ref_model is model:
- raise ValueError(
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
- "same as `model`, you must mass a copy of it, or `None` if you use peft."
- )
-
- if args.model_init_kwargs is None:
- model_init_kwargs = {}
- elif not isinstance(model, str):
- raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
- else:
- model_init_kwargs = args.model_init_kwargs
- dtype = model_init_kwargs.get("dtype")
- if dtype is not None:
- # Convert to `torch.dtype` if an str is passed
- if isinstance(dtype, str) and dtype != "auto":
- dtype = getattr(torch, dtype)
- if dtype != "auto" and not isinstance(dtype, torch.dtype):
- raise ValueError(
- f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
- )
- model_init_kwargs["dtype"] = dtype
-
- if args.ref_model_init_kwargs is None:
- ref_model_init_kwargs = {}
- elif not isinstance(ref_model, str):
- raise ValueError(
- "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
- )
- else:
- ref_model_init_kwargs = args.ref_model_init_kwargs
- dtype = ref_model_init_kwargs.get("dtype")
- if dtype is not None:
- # Convert to `torch.dtype` if an str is passed
- if isinstance(dtype, str) and dtype != "auto":
- dtype = getattr(torch, dtype)
- if dtype != "auto" and not isinstance(dtype, torch.dtype):
- raise ValueError(
- f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
- )
- ref_model_init_kwargs["dtype"] = dtype
-
- if isinstance(model, str):
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
-
- if isinstance(ref_model, str):
- ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
-
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
- # has been called in order to properly call autocast if needed.
- self._peft_has_been_casted_to_bf16 = False
-
- if not is_peft_available() and peft_config is not None:
- raise ValueError(
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
- )
- elif is_peft_available() and peft_config is not None:
- # if model is a peft model and we have a peft_config, we merge and unload it first
- if isinstance(model, PeftModel):
- model = model.merge_and_unload()
-
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
- _support_gc_kwargs = hasattr(
- args, "gradient_checkpointing_kwargs"
- ) and "gradient_checkpointing_kwargs" in list(
- inspect.signature(prepare_model_for_kbit_training).parameters
- )
-
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
- if _support_gc_kwargs:
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- # get peft model with the given config
- model = model
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(model)
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
- self._peft_has_been_casted_to_bf16 = True
-
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
- # to explicitly have `requires_grad=True`, otherwise training will either silently
- # fail or completely fail.
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
- raise ValueError(
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve."
- )
-
- if model is not None:
- self.is_encoder_decoder = model.config.is_encoder_decoder
- elif args.is_encoder_decoder is None:
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
- else:
- self.is_encoder_decoder = args.is_encoder_decoder
-
- self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
- self.model_adapter_name = model_adapter_name
- self.ref_adapter_name = ref_adapter_name
-
- if ref_model:
- self.ref_model = ref_model
- elif self.is_peft_model or args.precompute_ref_log_probs:
- # The `model` with adapters turned off will be used as the reference model
- self.ref_model = None
- else:
- self.ref_model = create_reference_model(model)
-
- if processing_class is None:
- raise ValueError(
- "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
- )
- if args.max_length is None:
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
- " it will be set to `512` by default, but you should do it yourself in the future.",
- )
- max_length = 512
- if args.max_length is not None:
- max_length = args.max_length
-
- if args.max_prompt_length is None:
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
- " it will be set to `128` by default, but you should do it yourself in the future.",
- )
- max_prompt_length = 128
- if args.max_prompt_length is not None:
- max_prompt_length = args.max_prompt_length
-
- max_completion_length = None
- if args.max_completion_length is None and self.is_encoder_decoder:
- logger.warning(
- "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
- " it will be set to `128` by default, but you should do it yourself in the future.",
- )
- max_completion_length = 128
- if args.max_completion_length is not None and self.is_encoder_decoder:
- max_completion_length = args.max_completion_length
-
- if data_collator is None:
- data_collator = DPODataCollatorWithPadding(
- pad_token_id=processing_class.pad_token_id,
- label_pad_token_id=args.label_pad_token_id,
- is_encoder_decoder=self.is_encoder_decoder,
- )
-
- if args.remove_unused_columns:
- args.remove_unused_columns = False
- # warn users
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
- " we have set it for you, but you should do it yourself in the future.",
- )
-
- self.use_dpo_data_collator = True
- else:
- self.use_dpo_data_collator = False
-
- # Disable dropout in the model and reference model
- if args.disable_dropout:
- disable_dropout_in_model(model)
- if self.ref_model is not None:
- disable_dropout_in_model(self.ref_model)
-
- self.loss_type = args.loss_type
- self.max_length = max_length
- self.generate_during_eval = args.generate_during_eval
- self.label_pad_token_id = args.label_pad_token_id
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
- self.max_prompt_length = max_prompt_length
- self.truncation_mode = args.truncation_mode
- self.max_completion_length = max_completion_length
- self.processing_class = processing_class
- self.precompute_ref_log_probs = args.precompute_ref_log_probs
-
- # Not all losses require a KL calculation
- self.calculate_KL = True
- if self.loss_type in ["apo_zero_unpaired"]:
- self.calculate_KL = False
-
- # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
- # keep track of first called to avoid computation of future calls
- self._precomputed_train_ref_log_probs = False
- self._precomputed_eval_ref_log_probs = False
-
- # metric
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
-
- # KTO parameter
- self.beta = args.beta
- self.desirable_weight = args.desirable_weight
- self.undesirable_weight = args.undesirable_weight
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
- logger.warning(
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
- "loss.",
- )
-
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
- # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
- # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
- # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
- # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
- # issued.
- model.warnings_issued["estimate_tokens"] = True
-
- # Compute that only on the main process for faster data processing.
- # see: https://github.com/huggingface/trl/pull/1255
- with PartialState().main_process_first():
- # Extract the prompt if needed
- train_dataset = train_dataset.map(
- maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
- )
- # Unpair the dataset if needed
- train_dataset = maybe_unpair_preference_dataset(
- train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
- )
- # Apply the chat template if needed
- train_dataset = train_dataset.map(
- maybe_apply_chat_template,
- fn_kwargs={"tokenizer": processing_class},
- num_proc=args.dataset_num_proc,
- desc="Applying chat template to train dataset",
- )
- if eval_dataset is not None:
- eval_dataset = eval_dataset.map(
- maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
- )
- eval_dataset = maybe_unpair_preference_dataset(
- eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
- )
- eval_dataset = eval_dataset.map(
- maybe_apply_chat_template,
- fn_kwargs={"tokenizer": processing_class},
- num_proc=args.dataset_num_proc,
- desc="Applying chat template to eval dataset",
- )
-
- # Tokenize and prepare the training datasets
- train_dataset = train_dataset.map(
- _tokenize,
- batched=True,
- fn_kwargs={"tokenizer": self.processing_class},
- num_proc=args.dataset_num_proc,
- desc="Tokenizing train dataset",
- )
-
- fn_kwargs = {
- "prefix": "",
- "is_encoder_decoder": self.is_encoder_decoder,
- "tokenizer": self.processing_class,
- "max_length": self.max_length,
- "truncation_mode": self.truncation_mode,
- "label_pad_token_id": self.label_pad_token_id,
- "max_prompt_length": self.max_prompt_length,
- "max_completion_length": self.max_completion_length,
- }
-
- train_dataset = train_dataset.map(
- _process_tokens,
- fn_kwargs=fn_kwargs,
- num_proc=args.dataset_num_proc,
- desc="Processing tokenized train dataset",
- )
-
- # Tokenize and prepare the eval datasets
- if eval_dataset is not None:
- eval_dataset = eval_dataset.map(
- _tokenize,
- fn_kwargs={"tokenizer": self.processing_class},
- batched=True,
- num_proc=args.dataset_num_proc,
- desc="Tokenizing eval dataset",
- )
-
- eval_dataset = eval_dataset.map(
- _process_tokens,
- fn_kwargs=fn_kwargs,
- num_proc=args.dataset_num_proc,
- desc="Processing tokenized eval dataset",
- )
-
- # Get KL datasets if needed
- if self.calculate_KL:
- if args.per_device_train_batch_size <= 1:
- raise ValueError(
- "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
- )
-
- # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
- # i.e., [x_1, y_1], ..., [x_n, y_n] --> [x_1, y_n], ..., [x_n, y_1] = [x'_1, y'_1], ..., [x'_n, y'_n]
- train_kl_dataset = train_dataset.map(
- _get_kl_dataset,
- batched=True,
- batch_size=args.per_device_train_batch_size,
- num_proc=args.dataset_num_proc,
- desc="Extracting KL train dataset",
- )
-
- fn_kwargs["prefix"] = "KL_"
- train_kl_dataset = train_kl_dataset.map(
- _process_tokens,
- fn_kwargs=fn_kwargs,
- num_proc=args.dataset_num_proc,
- remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
- desc="Processing tokenized train KL dataset",
- )
-
- # merge the datasets
- train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
-
- if eval_dataset is not None:
- # Get KL dataset
- eval_kl_dataset = eval_dataset.map(
- _get_kl_dataset,
- batched=True,
- batch_size=args.per_device_train_batch_size,
- num_proc=args.dataset_num_proc,
- desc="Extracting eval KL dataset",
- )
-
- eval_kl_dataset = eval_kl_dataset.map(
- _process_tokens,
- fn_kwargs=fn_kwargs,
- num_proc=args.dataset_num_proc,
- remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
- desc="Processing tokenized eval KL dataset",
- )
-
- # merge the datasets
- eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
-
- # calculate dataset desirability balance
- num_desirable = max(sum(train_dataset["label"]), 1)
- num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
-
- if num_desirable != num_undesirable:
- # The lower and upper bounds come from Eq. [8] of https://huggingface.co/papers/2402.01306
- des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
- des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
- und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
- und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
-
- des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
- und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
-
- if not (des_weight_in_range or und_weight_in_range):
- logger.warning(
- "You have different amounts of desirable/positive and undesirable/negative examples but the "
- "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
- f"on your data, we recommend EITHER "
- f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
- f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
- "See the documentation on how to optimally set these weights.",
- )
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- model_init=model_init,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
- # self.model_accepts_loss_kwargs to False to enable scaling.
- self.model_accepts_loss_kwargs = False
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- if not hasattr(self, "accelerator"):
- raise AttributeError(
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
- )
-
- # Deepspeed Zero-3 does not support precompute_ref_log_probs
- if self.is_deepspeed_enabled:
- if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
- raise ValueError(
- "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
- )
-
- if self.ref_model is None:
- if not (self.is_peft_model or self.precompute_ref_log_probs):
- raise ValueError(
- "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
- )
- else:
- if self.is_deepspeed_enabled:
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
- else:
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-
- # Import Liger loss if enabled
- if self.args.use_liger_loss:
- if not is_liger_kernel_available():
- raise ImportError(
- "You set `use_liger_loss=True` but the liger kernel is not available. "
- "Please install liger-kernel first: `pip install liger-kernel`"
- )
- if self.loss_type in ["apo_zero_unpaired"]:
- raise ValueError(
- "You cannot set `loss_type='apo_zero_unpaired'` with liger-kernel."
- "Only KTO loss is supported with liger-kernel."
- )
- if self.precompute_ref_log_probs:
- raise ValueError(
- "You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set "
- "`precompute_ref_log_probs=False`."
- )
- if self.is_peft_model or self.ref_adapter_name is not None:
- raise ValueError(
- "You cannot use `use_liger_loss=True` with Peft models. Please set `use_liger_loss=False`."
- )
- self.kto_loss_fn = LigerFusedLinearKTOLoss(
- ignore_index=self.label_pad_token_id, beta=self.beta, use_ref_model=(self.ref_model is not None)
- )
-
- @contextmanager
- def null_ref_context(self):
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
- with (
- self.accelerator.unwrap_model(self.model).disable_adapter()
- if self.is_peft_model and not self.ref_adapter_name
- else nullcontext()
- ):
- if self.ref_adapter_name:
- self.model.set_adapter(self.ref_adapter_name)
- yield
- if self.ref_adapter_name:
- self.model.set_adapter(self.model_adapter_name or "default")
-
- def get_train_dataloader(self) -> DataLoader:
- """
- Returns the training [`~torch.utils.data.DataLoader`].
-
- Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
- """
-
- if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
- dataloader_params = {
- "batch_size": self.args.per_device_train_batch_size,
- "collate_fn": self.data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "shuffle": False,
- }
-
- # prepare dataloader
- data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
- reference_completion_logps = []
- reference_KL_logps = []
-
- for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
- reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
-
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
- reference_completion_logps.append(reference_completion_logp.cpu())
-
- if self.calculate_KL:
- reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
- reference_KL_logps.append(reference_KL_logp.cpu())
-
- self.train_dataset = self.train_dataset.add_column(
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
- )
-
- if self.calculate_KL:
- self.train_dataset = self.train_dataset.add_column(
- name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
- )
-
- self._precomputed_train_ref_log_probs = True
-
- return super().get_train_dataloader()
-
- def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
- """
- Returns the evaluation [`~torch.utils.data.DataLoader`].
-
- Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
-
- Args:
- eval_dataset (`torch.utils.data.Dataset`, *optional*):
- If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
- by the `model.forward()` method are automatically removed. It must implement `__len__`.
- """
- if eval_dataset is None and self.eval_dataset is None:
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
- eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
- if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
- dataloader_params = {
- "batch_size": self.args.per_device_eval_batch_size,
- "collate_fn": self.data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "shuffle": False,
- }
-
- # prepare dataloader
- data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
-
- reference_completion_logps = []
- reference_KL_logps = []
-
- for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
- reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
-
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
- reference_completion_logps.append(reference_completion_logp.cpu())
-
- if self.calculate_KL:
- reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
- reference_KL_logps.append(reference_KL_logp.cpu())
-
- eval_dataset = eval_dataset.add_column(
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
- )
- if self.calculate_KL:
- eval_dataset = eval_dataset.add_column(
- name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
- )
-
- # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
- if self.eval_dataset is not None:
- self.eval_dataset = eval_dataset
- self._precomputed_eval_ref_log_probs = True
-
- return super().get_eval_dataloader(eval_dataset=eval_dataset)
-
- def compute_reference_log_probs(self, padded_batch: dict) -> dict:
- """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
- with torch.no_grad():
- if self.ref_model is None:
- with self.null_ref_context():
- if self.is_encoder_decoder:
- completion_logits = self.model(
- padded_batch["prompt_input_ids"],
- attention_mask=padded_batch["prompt_attention_mask"],
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
- labels=padded_batch["completion_labels"],
- ).logits
-
- if self.calculate_KL:
- KL_logits = self.model(
- padded_batch["KL_prompt_input_ids"],
- attention_mask=padded_batch["KL_prompt_attention_mask"],
- decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
- labels=padded_batch["KL_completion_labels"],
- ).logits
- else:
- completion_logits = self.model(
- padded_batch["completion_input_ids"],
- attention_mask=padded_batch["completion_attention_mask"],
- ).logits
-
- if self.calculate_KL:
- KL_logits = self.model(
- padded_batch["KL_completion_input_ids"],
- attention_mask=padded_batch["KL_completion_attention_mask"],
- ).logits
- else:
- if self.is_encoder_decoder:
- completion_logits = self.ref_model(
- padded_batch["prompt_input_ids"],
- attention_mask=padded_batch["prompt_attention_mask"],
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
- labels=padded_batch["completion_labels"],
- ).logits
-
- if self.calculate_KL:
- KL_logits = self.ref_model(
- padded_batch["KL_prompt_input_ids"],
- attention_mask=padded_batch["KL_prompt_attention_mask"],
- decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
- labels=padded_batch["KL_completion_labels"],
- ).logits
- else:
- completion_logits = self.ref_model(
- padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
- ).logits
-
- if self.calculate_KL:
- KL_logits = self.ref_model(
- padded_batch["KL_completion_input_ids"],
- attention_mask=padded_batch["KL_completion_attention_mask"],
- ).logits
-
- completion_logps = self.get_batch_logps(
- completion_logits,
- padded_batch["completion_labels"],
- average_log_prob=False,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
-
- if self.calculate_KL:
- KL_logps = self.get_batch_logps(
- KL_logits,
- padded_batch["KL_completion_labels"],
- average_log_prob=False,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
- else:
- KL_logps = None
-
- return completion_logps, KL_logps
-
- @staticmethod
- def get_batch_logps(
- logits: torch.FloatTensor,
- labels: torch.LongTensor,
- average_log_prob: bool = False,
- label_pad_token_id: int = -100,
- is_encoder_decoder: bool = False,
- ) -> torch.FloatTensor:
- """Compute the log probabilities of the given labels under the given logits.
-
- Args:
- logits:
- Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
- labels:
- Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
- ignored. Shape: (batch_size, sequence_length)
- average_log_prob:
- If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
- log probabilities of the (non-masked) tokens.
- label_pad_token_id:
- The label value to ignore when computing log probabilities.
- is_encoder_decoder:
- Whether the model is an encoder-decoder model. If True, the labels are not shifted and the logits are
- assumed to already be aligned with the labels. If False, the labels are shifted to the right by one
- position, and the logits are assumed to be aligned with the shifted labels.
-
- Returns:
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
- given logits.
- """
- if logits.shape[:-1] != labels.shape:
- # Unsloth: auto-truncate to shorter sequence length (model may have truncated input_ids)
- _min_len = min(logits.shape[1], labels.shape[1])
- logits = logits[:, :_min_len, :]
- labels = labels[:, :_min_len]
-
- if not is_encoder_decoder:
- labels = labels[:, 1:].clone()
- logits = logits[:, :-1, :]
- else:
- # Fixes end-dec RuntimeError
- labels = labels.clone()
-
- loss_mask = labels != label_pad_token_id
-
- # dummy token; we'll ignore the losses on these tokens later
- labels[labels == label_pad_token_id] = 0
-
- per_token_logps = selective_log_softmax(logits, labels)
-
- if average_log_prob:
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
- else:
- return (per_token_logps * loss_mask).sum(-1)
-
- def forward(
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- KL_logps = self._compute_kl_logps(model, batch)
-
- model_kwargs = (
- {
- "labels": batch["completion_labels"],
- "decoder_input_ids": batch.get("completion_decoder_input_ids"),
- }
- if self.is_encoder_decoder
- else {}
- )
- if self.aux_loss_enabled:
- model_kwargs["output_router_logits"] = True
-
- outputs = model(
- batch["completion_input_ids"],
- attention_mask=batch["completion_attention_mask"],
- **model_kwargs,
- )
- completion_logits = outputs.logits
-
- completion_logps = self.get_batch_logps(
- completion_logits,
- batch["completion_labels"],
- average_log_prob=False,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
-
- if completion_logps.shape[0] != len(batch["label"]):
- raise ValueError(
- "There is a mismatch between the number of examples in this batch and the number of "
- "examples for which an output sequence was predicted."
- )
-
- chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
- rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
-
- chosen_logps = completion_logps[chosen_idx, ...]
- rejected_logps = completion_logps[rejected_idx, ...]
-
- chosen_logits = completion_logits[chosen_idx, ...]
- rejected_logits = completion_logits[rejected_idx, ...]
-
- if self.aux_loss_enabled:
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
- else:
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
-
- def kto_loss(
- self,
- policy_chosen_logps: torch.FloatTensor,
- policy_rejected_logps: torch.FloatTensor,
- policy_KL_logps: torch.FloatTensor,
- reference_chosen_logps: torch.FloatTensor,
- reference_rejected_logps: torch.FloatTensor,
- reference_KL_logps: torch.FloatTensor,
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- """Compute the KTO loss for a batch of policy and reference model log probabilities.
-
- Args:
- policy_chosen_logps:
- Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
- policy_rejected_logps:
- Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
- policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
- reference_chosen_logps:
- Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
- reference_rejected_logps:
- Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in
- batch_size,)
- reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
-
- Returns:
- A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). The losses tensor contains the KTO
- loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
- the chosen and rejected responses, respectively. The KL tensor contains the detached KL divergence estimate
- between the policy and reference models.
- """
- if self.calculate_KL:
- kl = (policy_KL_logps - reference_KL_logps).mean().detach()
- kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
- else:
- kl = torch.zeros(1).to(policy_chosen_logps.device)
-
- # Chosen losses
- if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
- chosen_logratios = policy_chosen_logps - reference_chosen_logps
-
- if self.loss_type == "kto":
- # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
- chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
- elif self.loss_type == "apo_zero_unpaired":
- # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
- # Use this loss when you believe the chosen outputs are better than your model's default output
- chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
-
- chosen_rewards = self.beta * chosen_logratios.detach()
-
- else:
- # lists can't be empty -- if they are, then accelerate.gather will hang
- chosen_losses = torch.Tensor([]).to(self.accelerator.device)
- chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
-
- # Rejected losses
- if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
- rejected_logratios = policy_rejected_logps - reference_rejected_logps
-
- if self.loss_type == "kto":
- rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
- elif self.loss_type == "apo_zero_unpaired":
- rejected_losses = F.sigmoid(self.beta * rejected_logratios)
-
- rejected_rewards = self.beta * rejected_logratios.detach()
- else:
- # lists can't be empty -- if they are, then accelerate.gather will hang
- rejected_losses = torch.Tensor([]).to(self.accelerator.device)
- rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
-
- losses = torch.cat(
- (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
- 0,
- )
-
- return losses, chosen_rewards, rejected_rewards, kl
-
- def _compute_kl_logps(self, model, batch):
- """Compute KL log probabilities for a given batch."""
- KL_logps = None
- if self.calculate_KL:
- if self.is_encoder_decoder:
- KL_model_kwargs = {
- "input_ids": batch["KL_prompt_input_ids"],
- "attention_mask": batch["KL_prompt_attention_mask"],
- "labels": batch["KL_completion_labels"],
- "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
- }
- else:
- KL_model_kwargs = {
- "input_ids": batch["KL_completion_input_ids"],
- "attention_mask": batch["KL_completion_attention_mask"],
- }
-
- with torch.no_grad():
- KL_logits = model(**KL_model_kwargs).logits
-
- KL_logps = self.get_batch_logps(
- KL_logits,
- batch["KL_completion_labels"],
- average_log_prob=False,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
- return KL_logps
-
- def _compute_loss_liger(self, model, batch):
- """
- Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss.
-
- Args:
- model:
- The policy model used for generating log probabilities and outputs. It could be an encoder-decoder
- model or a regular language model.
- batch: A dictionary containing the input data and labels for the batch.
-
- Returns:
- A dictionary containing the following keys:
- - "loss": The computed KTO loss for the batch.
- - "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model.
- - "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model.
- - "chosen_logps": Log probabilities of the chosen responses from the policy model.
- - "rejected_logps": Log probabilities of the rejected responses from the policy model.
- - "chosen_rewards": Rewards for the chosen responses.
- - "rejected_rewards": Rewards for the rejected responses.
- - "kl": The KL divergence between the policy and reference models (detached).
-
- If auxiliary loss is enabled, the dictionary will also include:
- - "aux_loss": The auxiliary loss from the model outputs.
- """
- policy_KL_logps = self._compute_kl_logps(model, batch)
- reference_KL_logps = self._compute_kl_logps(self.ref_model, batch)
- if self.calculate_KL:
- kl = (policy_KL_logps - reference_KL_logps).mean().detach()
- kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
- else:
- kl = torch.zeros(1).to(self.accelerator.device)
-
- model_kwargs = (
- {
- "labels": batch["completion_labels"],
- "decoder_input_ids": batch.get("completion_decoder_input_ids"),
- }
- if self.is_encoder_decoder
- else {}
- )
- if self.aux_loss_enabled:
- model_kwargs["output_router_logits"] = True
-
- if self.is_encoder_decoder:
- # 1. Get encoder outputs
- encoder_outputs = model.get_encoder()(
- batch["completion_input_ids"],
- attention_mask=batch["completion_attention_mask"],
- return_dict=True,
- **model_kwargs,
- )
- # 2. Get decoder outputs
- outputs = model.get_decoder()(
- input_ids=model_kwargs["decoder_input_ids"],
- encoder_hidden_states=encoder_outputs.last_hidden_state,
- use_cache=False,
- **model_kwargs,
- )
- # 1. Get reference encoder outputs
- ref_encoder_outputs = self.ref_model.get_encoder()(
- batch["completion_input_ids"],
- attention_mask=batch["completion_attention_mask"],
- return_dict=True,
- **model_kwargs,
- )
- # 2. Get reference decoder outputs
- ref_outputs = self.ref_model.get_decoder()(
- input_ids=model_kwargs["decoder_input_ids"],
- encoder_hidden_states=ref_encoder_outputs.last_hidden_state,
- use_cache=False,
- **model_kwargs,
- )
- else:
- # skip the lm head and get the last hidden state
- if hasattr(model, "get_decoder") and model.get_decoder() is not None:
- base_model = model.get_decoder()
- else:
- base_attr = getattr(model, "base_model_prefix", self.args.base_model_attribute_name)
- base_model = getattr(model, base_attr, model)
- outputs = base_model(
- batch["completion_input_ids"],
- attention_mask=batch["completion_attention_mask"],
- use_cache=False,
- **model_kwargs,
- )
-
- # reference model
- if hasattr(self.ref_model, "get_decoder") and self.ref_model.get_decoder() is not None:
- ref_base_model = self.ref_model.get_decoder()
- else:
- ref_attr = getattr(self.ref_model, "base_model_prefix", self.args.base_model_attribute_name)
- ref_base_model = getattr(self.ref_model, ref_attr, self.ref_model)
- ref_outputs = ref_base_model(
- batch["completion_input_ids"],
- attention_mask=batch["completion_attention_mask"],
- use_cache=False,
- **model_kwargs,
- )
- lm_head = model.get_output_embeddings()
- ref_lm_head = self.ref_model.get_output_embeddings()
-
- (
- loss,
- (
- chosen_logps_sum,
- rejected_logps_sum,
- chosen_logits_sum,
- rejected_logits_sum,
- chosen_rewards_sum,
- rejected_rewards_sum,
- ),
- ) = self.kto_loss_fn(
- _input=outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
- lin_weight=lm_head.weight,
- target=batch["completion_labels"][:, 1:],
- bias=lm_head.bias if hasattr(lm_head, "bias") else None,
- preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device),
- ref_input=ref_outputs.last_hidden_state[:, :-1]
- if not self.is_encoder_decoder
- else outputs.last_hidden_state,
- ref_weight=ref_lm_head.weight,
- ref_bias=ref_lm_head.bias if hasattr(lm_head, "bias") else None,
- kl=kl,
- )
-
- output = {
- "loss": loss,
- "chosen_logits_sum": chosen_logits_sum,
- "rejected_logits_sum": rejected_logits_sum,
- "chosen_logps_sum": chosen_logps_sum,
- "rejected_logps_sum": rejected_logps_sum,
- "chosen_rewards_sum": chosen_rewards_sum,
- "rejected_rewards_sum": rejected_rewards_sum,
- "kl": kl,
- }
- if self.aux_loss_enabled:
- output["aux_loss"] = outputs.aux_loss
-
- return output
-
- def get_batch_loss_metrics(
- self,
- model,
- batch: dict[str, Union[list, torch.LongTensor]],
- ):
- """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
- metrics = {}
- batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
-
- labels = torch.tensor(batch["label"])
- num_chosen = labels.sum().to(self.accelerator.device)
- num_rejected = (len(labels) - num_chosen).to(self.accelerator.device)
-
- if self.args.use_liger_loss:
- model_output = self._compute_loss_liger(model, batch)
- losses = model_output["loss"]
- policy_chosen_logits = model_output["chosen_logits_sum"]
- policy_rejected_logits = model_output["rejected_logits_sum"]
- policy_chosen_logps = model_output["chosen_logps_sum"]
- policy_rejected_logps = model_output["rejected_logps_sum"]
- chosen_rewards = model_output["chosen_rewards_sum"]
- rejected_rewards = model_output["rejected_rewards_sum"]
- kl = model_output["kl"]
- if self.aux_loss_enabled:
- aux_loss = model_output["aux_loss"]
- else:
- forward_output = self.forward(model, batch)
- (
- policy_chosen_logps,
- policy_rejected_logps,
- policy_chosen_logits,
- policy_rejected_logits,
- policy_KL_logps,
- ) = forward_output[:5]
- if self.aux_loss_enabled:
- aux_loss = forward_output[5]
-
- # if reference_logps in batch use them, otherwise use the reference model
- if "reference_logps" in batch:
- chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
- rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
-
- reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
- reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
- if self.calculate_KL:
- reference_KL_logps = batch["reference_KL_logps"]
- else:
- reference_KL_logps = None
- else:
- with torch.no_grad():
- if self.ref_model is None:
- with self.null_ref_context():
- (
- reference_chosen_logps,
- reference_rejected_logps,
- _,
- _,
- reference_KL_logps,
- ) = self.forward(self.model, batch)[:5]
- else:
- (
- reference_chosen_logps,
- reference_rejected_logps,
- _,
- _,
- reference_KL_logps,
- ) = self.forward(self.ref_model, batch)[:5]
-
- losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
- policy_chosen_logps,
- policy_rejected_logps,
- policy_KL_logps,
- reference_chosen_logps,
- reference_rejected_logps,
- reference_KL_logps,
- )
-
- metrics["kl"] = kl.item()
-
- all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
- all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
-
- if all_num_chosen > 0:
- metrics["rewards/chosen_sum"] = (
- self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
- )
- metrics["logps/chosen_sum"] = (
- self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
- )
- metrics["logits/chosen_sum"] = (
- self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
- )
- metrics["count/chosen"] = all_num_chosen
-
- if all_num_rejected > 0:
- metrics["rewards/rejected_sum"] = (
- self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
- )
- metrics["logps/rejected_sum"] = (
- self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
- )
- metrics["logits/rejected_sum"] = (
- self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
- )
- metrics["count/rejected"] = all_num_rejected
-
- loss = losses.nanmean()
- if self.aux_loss_enabled:
- loss += self.aux_loss_coef * aux_loss
-
- return loss, metrics
-
- def compute_loss(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- return_outputs=False,
- num_items_in_batch=None,
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
- compute_loss_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with compute_loss_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs)
-
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
- loss = loss.to(self.args.device)
- # force log the metrics
- if self.accelerator.is_main_process:
- self.store_metrics(metrics, train_eval="train")
-
- if return_outputs:
- return (loss, metrics)
- return loss
-
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
- for key, value in metrics.items():
- self._stored_metrics[train_eval][key].append(value)
-
- def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
- if dataset is None:
- dataset = self.train_dataset
- if dataset is None or not has_length(dataset):
- return None
- return SequentialSampler(dataset)
-
- def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
- """Generate samples from the model and reference model for the given batch of inputs."""
-
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
- # the torch amp context manager as some hidden states are silently casted to full precision.
- generate_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with generate_context_manager:
- policy_output = model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
-
- # if reference_output in batch use that otherwise use the reference model
- if "reference_output" in batch:
- reference_output = batch["reference_output"]
- else:
- if self.ref_model is None:
- with self.null_ref_context():
- reference_output = self.model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
- else:
- reference_output = self.ref_model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
-
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
-
- reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
- reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
-
- return policy_output_decoded, reference_output_decoded
-
- def prediction_step(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- prediction_loss_only: bool,
- ignore_keys: Optional[list[str]] = None,
- ):
- if ignore_keys is None:
- if hasattr(model, "config"):
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
- else:
- ignore_keys = []
-
- prediction_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
- with torch.no_grad(), prediction_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs)
-
- # force log the metrics
- if self.accelerator.is_main_process:
- self.store_metrics(metrics, train_eval="eval")
-
- if prediction_loss_only:
- return (loss.detach(), None, None)
-
- # logits for the chosen and rejected samples from model
- logits_dict = {}
- if "logits/chosen_sum" in metrics:
- logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
- if "logits/rejected_sum" in metrics:
- logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
- logits = torch.tensor(logits, device=self.accelerator.device)
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
- return (loss.detach(), logits, labels)
-
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: Optional[bool] = None,
- ignore_keys: Optional[list[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
- `Trainer.evaluate()` and `Trainer.predict()`.
-
- Works both with or without labels.
- """
-
- # Sample and save to game log if requested (for one batch to save time)
- if self.generate_during_eval:
- # Generate random indices within the range of the total number of samples
- num_samples = len(dataloader.dataset)
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
-
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
- random_batch_dataset = dataloader.dataset.select(random_indices)
- random_batch = self.data_collator(random_batch_dataset)
- random_batch = self._prepare_inputs(random_batch)
-
- target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device)
- target_indices = torch.where(~target_labels)[0]
- target_batch = {
- "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
- "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
- "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
- }
- policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
-
- table = pd.DataFrame(
- columns=["Prompt", "Policy", "Ref Model"],
- data=[
- [prompt, pol[len(prompt) :], ref[len(prompt) :]]
- for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
- ],
- )
- if "wandb" in self.args.report_to:
- wandb.log({"game_log": wandb.Table(data=table)})
-
- if "comet_ml" in self.args.report_to:
- log_table_to_comet_experiment(
- name="game_log.csv",
- table=table,
- )
-
- # Base evaluation
- initial_output = super().evaluation_loop(
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
- )
-
- return initial_output
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- """
- Log `logs` on the various objects watching training, including stored metrics.
-
- Args:
- logs (`dict[str, float]`):
- The values to log.
- start_time (`float`, *optional*):
- Start time of the training.
- """
- # logs either has 'loss' or 'eval_loss'
- train_eval = "train" if "loss" in logs else "eval"
- # train metrics should have no prefix, eval should have 'eval_'
- prefix = "eval_" if train_eval == "eval" else ""
- # accumulate average metrics from sums and lengths
- for split in ["chosen", "rejected"]:
- if f"count/{split}" in self._stored_metrics[train_eval]:
- count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
- for metric in ["rewards", "logps", "logits"]:
- logs[f"{prefix}{metric}/{split}"] = (
- torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
- / count_sum
- )
- # delete obsolete metric
- del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
- del self._stored_metrics[train_eval][f"count/{split}"]
- # calculate reward margin
- if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
- logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
- # Add averaged stored metrics to logs
- for key, metrics in self._stored_metrics[train_eval].items():
- logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
- del self._stored_metrics[train_eval]
- return super().log(logs, start_time)
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothKTOTrainer(_UnslothKTOTrainer):
- """
-
- Initialize KTOTrainer.
-
- Args:
- model ([`~transformers.PreTrainedModel`]):
- The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
- ref_model ([`PreTrainedModelWrapper`]):
- Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
- and loss. If no reference model is provided, the trainer will create a reference model with the same
- architecture as the model to be optimized.
- args ([`KTOConfig`]):
- The arguments to use for training.
- train_dataset ([`~datasets.Dataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`]):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- data_collator ([`~transformers.DataCollator`], *optional*):
- The data collator to use for training. If None is specified, the default data collator
- ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
- sequences in the batch, given a dataset of paired sequences.
- model_init (`Callable[[], transformers.PreTrainedModel]`):
- The model initializer to use for training. If None is specified, the default model initializer will be
- used.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
- peft_config (`dict`, defaults to `None`):
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
- a PEFT model.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
- metric values.
- model_adapter_name (`str`, defaults to `None`):
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
- ref_adapter_name (`str`, defaults to `None`):
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
-
- """
- def __init__(
- self,
- model = None,
- ref_model = None,
- args = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- data_collator = None,
- model_init = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- compute_metrics = None,
- model_adapter_name = None,
- ref_adapter_name = None,
- **kwargs
- ):
- if args is None: args = UnslothKTOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('kto_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- ref_model = ref_model,
- args = args,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- data_collator = data_collator,
- model_init = model_init,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,
- compute_metrics = compute_metrics,
- model_adapter_name = model_adapter_name,
- ref_adapter_name = ref_adapter_name,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothNashMDTrainer.py b/unsloth_compiled_cache/UnslothNashMDTrainer.py
deleted file mode 100644
index b44f278..0000000
--- a/unsloth_compiled_cache/UnslothNashMDTrainer.py
+++ /dev/null
@@ -1,1364 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothNashMDConfig(NashMDConfig):
- """
-
- Configuration class for the [`NashMDTrainer`].
-
- Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
-
- Parameters:
- mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
- Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
- mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
- epochs.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- reward_model_path = None,
- judge = None,
- max_new_tokens = 64,
- max_length = 512,
- temperature = 0.9,
- top_p = 1.0,
- top_k = None,
- min_p = None,
- repetition_penalty = 1.0,
- generation_kwargs = {},
- use_transformers_paged = False,
- cache_implementation = None,
- missing_eos_penalty = None,
- loss_type = 'sigmoid',
- disable_dropout = True,
- use_vllm = False,
- vllm_model_impl = 'vllm',
- vllm_guided_decoding_regex = None,
- vllm_gpu_memory_utilization = 0.55,
- vllm_mode = 'colocate',
- vllm_server_base_url = None,
- vllm_server_host = '0.0.0.0',
- vllm_server_port = 8000,
- vllm_server_timeout = 240.0,
- vllm_tensor_parallel_size = 1,
- ds3_gather_for_generation = True,
- model_init_kwargs = None,
- reward_weights = None,
- dataset_num_proc = None,
- gpu_memory_utilization = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- reward_model_path = reward_model_path,
- judge = judge,
- max_new_tokens = max_new_tokens,
- max_length = max_length,
- temperature = temperature,
- top_p = top_p,
- top_k = top_k,
- min_p = min_p,
- repetition_penalty = repetition_penalty,
- generation_kwargs = generation_kwargs,
- use_transformers_paged = use_transformers_paged,
- cache_implementation = cache_implementation,
- missing_eos_penalty = missing_eos_penalty,
- loss_type = loss_type,
- disable_dropout = disable_dropout,
- use_vllm = use_vllm,
- vllm_model_impl = vllm_model_impl,
- vllm_guided_decoding_regex = vllm_guided_decoding_regex,
- vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
- vllm_mode = vllm_mode,
- vllm_server_base_url = vllm_server_base_url,
- vllm_server_host = vllm_server_host,
- vllm_server_port = vllm_server_port,
- vllm_server_timeout = vllm_server_timeout,
- vllm_tensor_parallel_size = vllm_tensor_parallel_size,
- ds3_gather_for_generation = ds3_gather_for_generation,
- model_init_kwargs = model_init_kwargs,
- reward_weights = reward_weights,
- dataset_num_proc = dataset_num_proc,
- gpu_memory_utilization = gpu_memory_utilization,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothNashMDTrainer(OnlineDPOTrainer):
- """"""
-
- _tag_names = ["trl", "nash-md"]
- _name = "Nash-MD"
- _paper = {
- "title": "Nash Learning from Human Feedback",
- "id": "2312.00886",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @inproceedings{munos2024nash,
- title = {{Nash Learning from Human Feedback}},
- author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
- year = 2024,
- booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
- publisher = {OpenReview.net},
- url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
- }"""),
- }
-
- def __init__(
- self,
- model: Union[PreTrainedModel, nn.Module] = None,
- ref_model: Union[PreTrainedModel, nn.Module] = None,
- reward_funcs: Union[PreTrainedModel, nn.Module, None] = None,
- judge: Optional[BasePairwiseJudge] = None,
- args: Optional[NashMDConfig] = None,
- data_collator: Optional[Callable] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- peft_config: Optional[dict] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- # Deprecated parameters
- reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
- ) -> None:
- super().__init__(
- model=model,
- ref_model=ref_model,
- reward_funcs=reward_funcs,
- judge=judge,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- reward_processing_classes=processing_class,
- peft_config=peft_config,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- reward_model=reward_model,
- )
-
- self._mixture_coef = self.args.mixture_coef
-
- # Overwrite the stats dictionary to include NashMD specific statistics
- self.stats = {
- # Remove "non_score_reward", "rlhf_reward", "scores_margin"
- # Add "mixture_coef"
- "loss/kl": [],
- "objective/entropy": [],
- "loss/score": [],
- "rewards/probabilities": [],
- "rewards/accuracies": [],
- "rewards/margins": [],
- "logps/chosen": [],
- "logps/rejected": [],
- "val/model_contain_eos_token": [],
- "val/ref_contain_eos_token": [],
- "beta": [],
- "mixture_coef": [],
- }
- if self.reward_funcs is not None:
- if len(self.reward_funcs) != 1:
- raise ValueError("NashMDTrainer only supports one reward function/model.")
- self.reward_funcs = self.reward_funcs[0]
- self.stats["rewards/chosen"] = []
- self.stats["rewards/rejected"] = []
-
- @property
- def mixture_coef(self):
- if isinstance(self._mixture_coef, list):
- epoch = self.state.epoch
- return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
- else:
- return self._mixture_coef
-
- def _generate_completions(self, model, prompts):
- # Generate completions from the policy model.
- with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
- model_output = unwrapped_policy_for_gen_ctx.generate(
- input_ids=prompts["input_ids"],
- attention_mask=prompts["attention_mask"],
- generation_config=self.generation_config,
- )
-
- # Get the DDP/FSDP unwrapped version of the main model.
- # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
- policy_model_for_gmw = self.accelerator.unwrap_model(model)
-
- # Determine the correct reference model for GeometricMixtureWrapper.
- # This also needs to be DDP/FSDP unwrapped.
- ref_model_for_gmw: torch.nn.Module
- if self.ref_model is None:
- # No explicit ref_model is provided.
- # Use the base of the main `model` if it's a PEFT model.
- # policy_model_for_gmw is already DDP-unwrapped.
- if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
- ref_model_for_gmw = policy_model_for_gmw.get_base_model()
- else:
- # Not a PEFT model (or PEFT not available), or already a base model.
- # Use the DDP-unwrapped policy model itself as the reference.
- ref_model_for_gmw = policy_model_for_gmw
- else:
- # An explicit ref_model is provided. Unwrap it for DDP/FSDP.
- ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)
-
- # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
- with torch.no_grad(): # Ensure no_grad context for mixture model generation
- mixture_model = GeometricMixtureWrapper(
- model=policy_model_for_gmw,
- ref_model=ref_model_for_gmw,
- generation_config=self.generation_config,
- mixture_coef=self.mixture_coef,
- device=self.accelerator.device,
- )
-
- mixture_output = mixture_model.generate(
- input_ids=prompts["input_ids"],
- attention_mask=prompts["attention_mask"],
- generation_config=self.generation_config,
- )
-
- return model_output, mixture_output
-
- def _process_completions(self, model_output, mixture_output, prompts):
- context_length = prompts["input_ids"].shape[1]
-
- # Process model completions
- model_completion_ids = model_output[:, context_length:]
- model_completion_ids, model_completion_mask = truncate_right(
- model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
- )
- model_data = {
- "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
- "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
- "raw": prompts["raw"],
- }
-
- # Process reference model completions
- mixture_completion_ids = mixture_output[:, context_length:]
- mixture_completion_ids, mixture_completion_mask = truncate_right(
- mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
- )
- mixture_data = {
- "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
- "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
- "raw": prompts["raw"],
- }
-
- return model_data, mixture_data
-
- def _compute_rewards(self, model_data, mixture_data, context_length):
- with torch.no_grad():
- _, model_scores, _ = get_reward(
- self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length
- )
- _, mixture_scores, _ = get_reward(
- self.reward_funcs, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
- )
-
- # Apply EOS penalty if needed
- if self.args.missing_eos_penalty is not None:
- model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
- mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
- model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
- mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
-
- return model_scores, mixture_scores
-
- def _compute_judge(self, model_data, mixture_data, context_length):
- prompts = model_data["raw"]
- model_data_completions = self.processing_class.batch_decode(
- model_data["input_ids"][:, context_length:], skip_special_tokens=True
- )
- model_data_completions = [completion.strip() for completion in model_data_completions]
-
- mixture_data_completions = self.processing_class.batch_decode(
- mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
- )
- mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
- if is_conversational({"prompt": prompts[0]}):
- model_data_completions = [
- [{"role": "assistant", "content": completion}] for completion in model_data_completions
- ]
- environment = jinja2.Environment()
- template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
- prompts = [template.render(messages=message) for message in prompts]
- model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
-
- mixture_data_completions = [
- [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
- ]
- mixture_data_completions = [
- template.render(messages=completion) for completion in mixture_data_completions
- ]
-
- probability = self.judge.judge(
- prompts,
- list(zip(model_data_completions, mixture_data_completions)),
- return_scores=True,
- )
- return torch.tensor(probability, device=model_data["input_ids"].device)
-
- def _compute_logprobs(self, model, model_data, context_length):
- def compute_logprobs_for_data(m, data):
- output = m(data["input_ids"], attention_mask=data["attention_mask"])
- logits = output.logits[:, context_length - 1 : -1]
- token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
- return token_logprobs
-
- # Compute logprobs for model completions under the model
- model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
-
- # Compute logprobs of model completions under the reference model
- with torch.no_grad():
- if self.ref_model is None:
- with model.disable_adapter():
- ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
- else:
- ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
-
- # Mask padding tokens
- model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
- model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
- ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
-
- return (model_logprobs_model_data, ref_logprobs_model_data)
-
- def _compute_losses(
- self,
- model_logprobs_model_data,
- ref_logprobs_model_data,
- probability,
- ):
- # reinforce score where 0.5 is a control variate
- score = (probability - 0.5) * model_logprobs_model_data.sum(1)
-
- # kl divergence via reinforce
- with torch.no_grad():
- log_ratio = model_logprobs_model_data - ref_logprobs_model_data
- kl_div_log = log_ratio.sum(1)
- kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
-
- # final loss
- loss = self.beta * kl_div_loss - score
-
- return loss.mean(), score, kl_div_log
-
- def _log_statistics(
- self,
- model_data,
- mixture_data,
- model_logprobs_model_data,
- ref_logprobs_model_data,
- probability,
- score,
- kl_div,
- context_length,
- model_scores=None,
- mixture_scores=None,
- ):
- # Helper function to gather and compute mean
- def gather_mean(tensor):
- return self.accelerator.gather_for_metrics(tensor).mean().item()
-
- # Log score
- self.stats["loss/score"].append(gather_mean(score))
- # Log KL divergence
- self.stats["loss/kl"].append(gather_mean(kl_div))
-
- # Log logprobs
- model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
- ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
-
- self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
- self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
-
- # Log rewards
- if self.reward_funcs is not None:
- self.stats["rewards/chosen"].append(gather_mean(model_scores))
- self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
-
- # Log probabilities
- self.stats["rewards/probabilities"].append(gather_mean(probability))
-
- # Calculate entropy for model data
- entropy_model_data = -model_logprobs_model_data.sum(1)
- self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
-
- # Calculate margins
- margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
- self.stats["rewards/margins"].append(gather_mean(margin))
-
- # Calculate accuracy
- accuracy = (margin > 0).float()
- self.stats["rewards/accuracies"].append(gather_mean(accuracy))
-
- # Log EOS token statistics
- model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
- mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
- self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
- self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
-
- # Log beta and mixture coef
- self.stats["beta"].append(self.beta)
- self.stats["mixture_coef"].append(self.mixture_coef)
-
- def training_step(
- self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
- ) -> torch.Tensor:
- model.train()
-
- # Apply chat template and tokenize the input
- batch_size = len(next(iter(inputs.values())))
- prompts = inputs["prompt"]
- inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
- inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
- inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
- inputs = self.data_collator(inputs)
-
- # need the prompt_ only
- inputs = self._prepare_inputs(inputs)
- context_length = inputs["prompt_input_ids"].shape[1]
- prompts = {
- "input_ids": inputs["prompt_input_ids"],
- "attention_mask": inputs["prompt_attention_mask"],
- "raw": prompts,
- }
- del inputs
-
- # Sample completions from both the model and the reference model
- model_output, mixture_output = self._generate_completions(model, prompts)
-
- # Process model completions
- model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
-
- # Compute rewards
- if self.reward_funcs is not None:
- model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
- # probability of the model data vs the mixture data
- probability = F.sigmoid(model_scores - mixture_scores)
- else:
- model_scores, mixture_scores = None, None
- probability = self._compute_judge(model_data, mixture_data, context_length)
-
- # Compute logprobs
- model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
-
- # Compute loss
- loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
-
- # Log everything
- self._log_statistics(
- model_data,
- mixture_data,
- model_logprobs_model_data.detach(),
- ref_logprobs_model_data,
- probability,
- score.detach(),
- kl_div.detach(),
- context_length,
- model_scores,
- mixture_scores,
- )
-
- if (
- self.args.torch_empty_cache_steps is not None
- and self.state.global_step % self.args.torch_empty_cache_steps == 0
- ):
- empty_cache()
-
- kwargs = {}
- # For LOMO optimizers you need to explicitly use the learning rate
- if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
- kwargs["learning_rate"] = self._get_learning_rate()
-
- if self.args.n_gpu > 1:
- loss = loss.mean() # mean() to average on multi-gpu parallel training
-
- self.accelerator.backward(loss, **kwargs)
-
- return loss.detach() / self.args.gradient_accumulation_steps
-class UnslothNashMDTrainer(_UnslothNashMDTrainer):
- """
-
- Trainer for the Nash-MD method.
-
- It is implemented as a subclass of [`OnlineDPOTrainer`].
-
- Args:
- model ([`~transformers.PreTrainedModel`]):
- The model to train, preferably an `AutoModelForCausalLM`.
- ref_model ([`PreTrainedModelWrapper`]):
- Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
- and loss. If no reference model is provided, the trainer will create a reference model with the same
- architecture as the model to be optimized.
- reward_funcs ([`~transformers.PreTrainedModel`]):
- The reward model to score completions with, preferably an
- [`~transformers.AutoModelForSequenceClassification`].
- judge ([`BasePairwiseJudge`]):
- The judge to use for pairwise comparison of model completions.
- args ([`NashMDConfig`]):
- The NashMD config arguments to use for training.
- data_collator ([`~transformers.DataCollator`]):
- The data collator to use for training. If None is specified, the default data collator
- ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
- sequences in the batch, given a dataset of paired sequences.
- train_dataset ([`~datasets.Dataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`]):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- peft_config (`dict`):
- The peft config to use for training.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
- metric values.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
-
- reward_model:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
-
-
-
- """
- def __init__(
- self,
- model = None,
- ref_model = None,
- reward_funcs = None,
- judge = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- peft_config = None,
- compute_metrics = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- reward_model = None,
- **kwargs
- ):
- if args is None: args = UnslothNashMDConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('nash_md_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- ref_model = ref_model,
- reward_funcs = reward_funcs,
- judge = judge,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- peft_config = peft_config,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- reward_model = reward_model,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
diff --git a/unsloth_compiled_cache/UnslothORPOTrainer.py b/unsloth_compiled_cache/UnslothORPOTrainer.py
deleted file mode 100644
index 1ed2b5a..0000000
--- a/unsloth_compiled_cache/UnslothORPOTrainer.py
+++ /dev/null
@@ -1,1884 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothORPOConfig(ORPOConfig):
- """
-
- Configuration class for the [`ORPOTrainer`].
-
- This class includes only the parameters that are specific to ORPO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
- to use the default data collator.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
- max_completion_length (`int`, *optional*):
- Maximum length of the completion. This argument is required if you want to use the default data collator
- and your model is an encoder-decoder.
- beta (`float`, *optional*, defaults to `0.1`):
- Parameter controlling the relative ratio loss weight in the ORPO loss. In the
- [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the
- [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model.
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
- Label pad token id. This argument is required if you want to use the default data collator.
- padding_value (`int`, *optional*):
- Padding value to use. If `None`, the padding value of the tokenizer is used.
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
- This argument is required if you want to use the default data collator.
- generate_during_eval (`bool`, *optional*, defaults to `False`):
- If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
- is_encoder_decoder (`bool`, *optional*):
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
- you need to specify if the model returned by the callable is an encoder-decoder model.
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
- string.
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- max_length = 1024,
- max_prompt_length = 512,
- max_completion_length = None,
- beta = 0.1,
- disable_dropout = True,
- label_pad_token_id = -100,
- padding_value = None,
- truncation_mode = 'keep_end',
- generate_during_eval = False,
- is_encoder_decoder = None,
- model_init_kwargs = None,
- dataset_num_proc = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- max_length = max_length,
- max_prompt_length = max_prompt_length,
- max_completion_length = max_completion_length,
- beta = beta,
- disable_dropout = disable_dropout,
- label_pad_token_id = label_pad_token_id,
- padding_value = padding_value,
- truncation_mode = truncation_mode,
- generate_during_eval = generate_during_eval,
- is_encoder_decoder = is_encoder_decoder,
- model_init_kwargs = model_init_kwargs,
- dataset_num_proc = dataset_num_proc,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothORPOTrainer(BaseTrainer):
- r""""""
-
- _tag_names = ["trl", "orpo"]
- _name = "ORPO"
- _paper = {
- "title": "ORPO: Monolithic Preference Optimization without Reference Model",
- "id": "2403.07691",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{hong2024orpo,
- title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
- author = {Jiwoo Hong and Noah Lee and James Thorne},
- year = 2024,
- eprint = {arXiv:2403.07691}
- }"""),
- }
-
- def __init__(
- self,
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
- args: Optional[ORPOConfig] = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Dataset] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional[dict] = None,
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
- ):
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if args.model_init_kwargs is None:
- model_init_kwargs = {}
- elif not isinstance(model, str):
- raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
- else:
- model_init_kwargs = args.model_init_kwargs
- dtype = model_init_kwargs.get("dtype")
- if dtype is not None:
- # Convert to `torch.dtype` if an str is passed
- if isinstance(dtype, str) and dtype != "auto":
- dtype = getattr(torch, dtype)
- if dtype != "auto" and not isinstance(dtype, torch.dtype):
- raise ValueError(
- f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
- )
- model_init_kwargs["dtype"] = dtype
-
- if isinstance(model, str):
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
-
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
- # has been called in order to properly call autocast if needed.
- self._peft_has_been_casted_to_bf16 = False
-
- if not is_peft_available() and peft_config is not None:
- raise ValueError(
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
- )
- elif is_peft_available() and peft_config is not None:
- # if model is a peft model and we have a peft_config, we merge and unload it first
- if isinstance(model, PeftModel):
- model = model.merge_and_unload()
-
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
- _support_gc_kwargs = hasattr(
- args, "gradient_checkpointing_kwargs"
- ) and "gradient_checkpointing_kwargs" in list(
- inspect.signature(prepare_model_for_kbit_training).parameters
- )
-
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
- if _support_gc_kwargs:
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- # get peft model with the given config
- model = model
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(model)
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
- self._peft_has_been_casted_to_bf16 = True
-
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
- # to explicitly have `requires_grad=True`, otherwise training will either silently
- # fail or completely fail.
- elif args.gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
- raise ValueError(
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
- " Please install `wandb` or `comet-ml` to resolve."
- )
-
- if model is not None:
- self.is_encoder_decoder = model.config.is_encoder_decoder
- elif args.is_encoder_decoder is None:
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
- else:
- self.is_encoder_decoder = args.is_encoder_decoder
-
- if self.is_encoder_decoder:
- self.decoder_start_token_id = model.config.decoder_start_token_id
- self.pad_token_id = model.config.pad_token_id
-
- if processing_class is None:
- raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
- if args.max_length is None:
- logger.warning(
- "`max_length` is not set in the ORPOConfig's init"
- " it will default to `512` by default, but you should do it yourself in the future.",
- )
- max_length = 512
- else:
- max_length = args.max_length
- if args.max_prompt_length is None:
- logger.warning(
- "`max_prompt_length` is not set in the ORPOConfig's init"
- " it will default to `128` by default, but you should do it yourself in the future.",
- )
- max_prompt_length = 128
- else:
- max_prompt_length = args.max_prompt_length
-
- if args.max_completion_length is None and self.is_encoder_decoder:
- logger.warning(
- "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
- " it will default to `128` by default, but you should do it yourself in the future.",
- )
- self.max_completion_length = 128
- else:
- self.max_completion_length = args.max_completion_length
-
- if data_collator is None:
- data_collator = DPODataCollatorWithPadding(
- pad_token_id=processing_class.pad_token_id,
- label_pad_token_id=args.label_pad_token_id,
- is_encoder_decoder=self.is_encoder_decoder,
- )
-
- if args.remove_unused_columns:
- args.remove_unused_columns = False
- # warn users
- logger.warning(
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
- " we have set it for you, but you should do it yourself in the future.",
- )
-
- self.use_dpo_data_collator = True
- else:
- self.use_dpo_data_collator = False
-
- # Disable dropout in the model and reference model
- if args.disable_dropout:
- disable_dropout_in_model(model)
-
- self.max_length = max_length
- self.generate_during_eval = args.generate_during_eval
- self.label_pad_token_id = args.label_pad_token_id
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
- self.max_prompt_length = max_prompt_length
- self.truncation_mode = args.truncation_mode
- self.processing_class = processing_class
-
- self.beta = args.beta
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
- logger.warning(
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
- "loss.",
- )
-
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
-
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
- # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
- # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
- # that the warning has already been issued.
- model.warnings_issued["estimate_tokens"] = True
-
- # Compute that only on the main process for faster data processing.
- # see: https://github.com/huggingface/trl/pull/1255
- with PartialState().main_process_first():
- # Extract the prompt if needed, and apply the chat template if needed
- train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
- train_dataset = train_dataset.map(
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
- )
- train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
- if eval_dataset is not None:
- eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
- eval_dataset = eval_dataset.map(
- maybe_apply_chat_template,
- fn_kwargs={"tokenizer": processing_class},
- num_proc=args.dataset_num_proc,
- )
- eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- model_init=model_init,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
- # self.model_accepts_loss_kwargs to False to enable scaling.
- self.model_accepts_loss_kwargs = False
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- if not hasattr(self, "accelerator"):
- raise AttributeError(
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
- )
-
- def build_tokenized_answer(self, prompt, answer):
- """
- Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
- b)[len(enc(a)):]`. Reference:
- https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
- """
-
- full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
- prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
-
- answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
- answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
-
- # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
- full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
-
- # Prepare input tokens for token by token comparison
- full_input_ids = np.array(full_tokenized["input_ids"])
-
- if len(full_input_ids) != len(full_concat_input_ids):
- raise ValueError("Prompt input ids and answer input ids should have the same length.")
-
- # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
- # can be merged together when tokenizing prompt+answer. This could result
- # on the last token from the prompt being different when tokenized on its own
- # vs when done as prompt+answer.
- response_token_ids_start_idx = len(prompt_input_ids)
-
- # If tokenized prompt is different than both prompt+answer, then it means the
- # last token has changed due to merging.
- if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
- response_token_ids_start_idx -= 1
-
- prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
- prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
-
- if len(prompt_input_ids) != len(prompt_attention_mask):
- raise ValueError("Prompt input ids and attention mask should have the same length.")
-
- answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
- answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
-
- return dict(
- prompt_input_ids=prompt_input_ids,
- prompt_attention_mask=prompt_attention_mask,
- input_ids=answer_input_ids,
- attention_mask=answer_attention_mask,
- )
-
- def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
- """Tokenize a single row from a ORPO specific dataset.
-
- At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
- chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
- we truncate the chosen/rejected.
-
- We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
- of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens.
- """
- batch = {}
- prompt = feature["prompt"]
- chosen = feature["chosen"]
- rejected = feature["rejected"]
-
- if not self.is_encoder_decoder:
- # Check issues below for more details
- # 1. https://github.com/huggingface/trl/issues/907
- # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
- # 3. https://github.com/LianjiaTech/BELLE/issues/337
-
- if not isinstance(prompt, str):
- raise ValueError(f"prompt should be an str but got {type(prompt)}")
- prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
- prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
-
- if not isinstance(chosen, str):
- raise ValueError(f"chosen should be an str but got {type(chosen)}")
- chosen_tokens = self.build_tokenized_answer(prompt, chosen)
-
- if not isinstance(rejected, str):
- raise ValueError(f"rejected should be an str but got {type(rejected)}")
- rejected_tokens = self.build_tokenized_answer(prompt, rejected)
-
- # Last prompt token might get merged by tokenizer and
- # it should not be included for generation if that happens
- prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
-
- chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
- rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
- prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
-
- for k, v in prompt_tokens.items():
- prompt_tokens[k] = v[:prompt_len_input_ids]
-
- # Make sure prompts only have one different token at most an
- # and length only differs by 1 at most
- num_diff_tokens = sum(
- a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
- )
- num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
- if num_diff_tokens > 1 or num_diff_len > 1:
- raise ValueError(
- "Chosen and rejected prompt_input_ids might only differ on the "
- "last token due to tokenizer merge ops."
- )
-
- # add BOS token to head of prompt. Avoid adding if it's already there
- prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
- self.processing_class.bos_token_id,
- prompt_len_input_ids,
- prompt_tokens,
- chosen_prompt_len_input_ids,
- chosen_tokens,
- rejected_prompt_len_input_ids,
- rejected_tokens,
- )
-
- # add EOS token to end of answer. Avoid adding if it's already there
- chosen_tokens, rejected_tokens = add_eos_token_if_needed(
- self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
- )
-
- longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
-
- # if combined sequence is too long, truncate the prompt
- for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
- if self.truncation_mode == "keep_start":
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
- answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
- elif self.truncation_mode == "keep_end":
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
- answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
- else:
- raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
-
- # if that's still too long, truncate the response
- for answer_tokens in [chosen_tokens, rejected_tokens]:
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
- for k in ["input_ids", "attention_mask"]:
- answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
-
- # Create labels
- chosen_sequence_tokens = {
- k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
- }
- rejected_sequence_tokens = {
- k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
- }
- chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
- chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
- self.label_pad_token_id
- ] * len(chosen_tokens["prompt_input_ids"])
- rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
- rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
- self.label_pad_token_id
- ] * len(rejected_tokens["prompt_input_ids"])
-
- for k, toks in {
- "chosen_": chosen_sequence_tokens,
- "rejected_": rejected_sequence_tokens,
- "": prompt_tokens,
- }.items():
- for type_key, tokens in toks.items():
- if type_key == "token_type_ids":
- continue
- batch[f"{k}{type_key}"] = tokens
-
- else:
- chosen_tokens = self.processing_class(
- chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
- )
- rejected_tokens = self.processing_class(
- rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
- )
- prompt_tokens = self.processing_class(
- prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
- )
-
- batch["chosen_labels"] = chosen_tokens["input_ids"]
- batch["rejected_labels"] = rejected_tokens["input_ids"]
- batch["prompt_input_ids"] = prompt_tokens["input_ids"]
- batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
-
- if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
- batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
- labels=torch.tensor(batch["rejected_labels"])
- )
- batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
- labels=torch.tensor(batch["chosen_labels"])
- )
-
- if is_torch_xla_available():
- # Pad the sequences to global max_length to avoid TorchXLA recompilation
- for k in batch:
- if "labels" in k or self.is_encoder_decoder:
- pad_value = self.label_pad_token_id
- elif k.endswith("_input_ids"):
- pad_value = self.padding_value
- elif k.endswith("_attention_mask"):
- pad_value = 0
- batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
- return batch
-
- @staticmethod
- def concatenated_inputs(
- batch: dict[str, Union[list, torch.LongTensor]],
- is_encoder_decoder: bool = False,
- label_pad_token_id: int = -100,
- padding_value: int = 0,
- device: Optional[torch.device] = None,
- ) -> dict[str, torch.LongTensor]:
- """Concatenate the chosen and rejected inputs into a single tensor.
-
- Args:
- batch:
- A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
- of shape (batch_size, sequence_length).
- is_encoder_decoder:
- Whether the model is an encoder-decoder model.
- label_pad_token_id:
- The label pad token id.
- padding_value:
- The padding value to use for the concatenated inputs_ids.
- device:
- The device for the concatenated inputs.
-
- Returns:
- A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
- """
- concatenated_batch = {}
-
- if is_encoder_decoder:
- max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
- else:
- max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
-
- for k in batch:
- if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
- if "labels" in k or is_encoder_decoder:
- pad_value = label_pad_token_id
- elif k.endswith("_input_ids"):
- pad_value = padding_value
- elif k.endswith("_attention_mask"):
- pad_value = 0
- concatenated_key = k.replace("chosen", "concatenated")
- concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
- for k in batch:
- if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
- if "labels" in k or is_encoder_decoder:
- pad_value = label_pad_token_id
- elif k.endswith("_input_ids"):
- pad_value = padding_value
- elif k.endswith("_attention_mask"):
- pad_value = 0
- concatenated_key = k.replace("rejected", "concatenated")
- concatenated_batch[concatenated_key] = torch.cat(
- (
- concatenated_batch[concatenated_key],
- pad_to_length(batch[k], max_length, pad_value=pad_value),
- ),
- dim=0,
- ).to(device=device)
-
- if is_encoder_decoder:
- concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
- concatenated_batch["concatenated_attention_mask"] = (
- batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
- )
-
- return concatenated_batch
-
- def odds_ratio_loss(
- self,
- policy_chosen_logps: torch.FloatTensor,
- policy_rejected_logps: torch.FloatTensor,
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
-
- Args:
- policy_chosen_logps:
- Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
- policy_rejected_logps:
- Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-
- Returns:
- A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO
- loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
- the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the
- rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes.
- """
-
- # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
- log_odds = (policy_chosen_logps - policy_rejected_logps) - (
- torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
- )
- ratio = F.logsigmoid(log_odds)
- losses = self.beta * ratio
-
- chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
- rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
-
- return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
-
- @staticmethod
- def get_batch_logps(
- logits: torch.FloatTensor,
- labels: torch.LongTensor,
- average_log_prob: bool = False,
- label_pad_token_id: int = -100,
- is_encoder_decoder: bool = False,
- ) -> torch.FloatTensor:
- """Compute the log probabilities of the given labels under the given logits.
-
- Args:
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
- labels:
- Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are
- ignored. Shape: (batch_size, sequence_length)
- average_log_prob:
- If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
- log probabilities of the (non-masked) tokens.
- label_pad_token_id: The label pad token id.
- is_encoder_decoder: Whether the model is an encoder-decoder model.
-
- Returns:
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
- given logits.
- """
- if logits.shape[:-1] != labels.shape:
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
-
- if not is_encoder_decoder:
- labels = labels[:, 1:].clone()
- logits = logits[:, :-1, :]
- loss_mask = labels != label_pad_token_id
-
- # dummy token; we'll ignore the losses on these tokens later
- labels = torch.where(labels == label_pad_token_id, 0, labels)
-
- per_token_logps = selective_log_softmax(logits, labels)
-
- if average_log_prob:
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
- else:
- return (per_token_logps * loss_mask).sum(-1)
-
- def concatenated_forward(
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
- """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
-
- We do this to avoid doing two forward passes, because it's faster for FSDP.
- """
- concatenated_batch = self.concatenated_inputs(
- batch,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- padding_value=self.padding_value,
- device=self.accelerator.device,
- )
- len_chosen = batch["chosen_labels"].shape[0]
-
- model_kwargs = (
- {
- "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
- }
- if self.is_encoder_decoder
- else {}
- )
-
- if self.aux_loss_enabled:
- model_kwargs["output_router_logits"] = True
-
- outputs = model(
- concatenated_batch["concatenated_input_ids"],
- attention_mask=concatenated_batch["concatenated_attention_mask"],
- use_cache=False,
- **model_kwargs,
- )
- all_logits = outputs.logits
-
- def cross_entropy_loss(logits, labels):
- if not self.is_encoder_decoder:
- # Shift so that tokens < n predict n
- logits = logits[..., :-1, :].contiguous()
- labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- logits = logits.view(-1, logits.shape[-1])
- labels = labels.view(-1)
- # Enable model parallelism
- labels = labels.to(logits.device)
- loss = loss_fct(logits, labels)
- return loss
-
- if self.is_encoder_decoder:
- labels = concatenated_batch["concatenated_labels"].clone()
- else:
- labels = concatenated_batch["concatenated_input_ids"].clone()
- attention_mask = concatenated_batch["concatenated_attention_mask"]
- labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
- # orpo chosen nll loss is computed over the full prompt and response
- chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
-
- all_logps = self.get_batch_logps(
- all_logits,
- concatenated_batch["concatenated_labels"],
- average_log_prob=True,
- is_encoder_decoder=self.is_encoder_decoder,
- label_pad_token_id=self.label_pad_token_id,
- )
-
- chosen_logps = all_logps[:len_chosen]
- rejected_logps = all_logps[len_chosen:]
-
- if not self.is_encoder_decoder:
- chosen_logits = all_logits[:len_chosen, :-1, :]
- rejected_logits = all_logits[len_chosen:, :-1, :]
- else:
- chosen_logits = all_logits[:len_chosen]
- rejected_logits = all_logits[len_chosen:]
-
- if self.aux_loss_enabled:
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
-
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
-
- def get_batch_loss_metrics(
- self,
- model,
- batch: dict[str, Union[list, torch.LongTensor]],
- train_eval: Literal["train", "eval"] = "train",
- ):
- """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
- metrics = {}
-
- forward_output = self.concatenated_forward(model, batch)
- (
- policy_chosen_logps,
- policy_rejected_logps,
- policy_chosen_logits,
- policy_rejected_logits,
- policy_nll_loss,
- ) = forward_output[:5]
- if self.aux_loss_enabled:
- aux_loss = forward_output[5]
-
- losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
- policy_chosen_logps, policy_rejected_logps
- )
- # full ORPO loss
- loss = policy_nll_loss - losses.mean()
-
- reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
- prefix = "eval_" if train_eval == "eval" else ""
- metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
- metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
- metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
- metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
- chosen_rewards - rejected_rewards
- ).mean()
- metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
- metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
- metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
- policy_rejected_logits.detach().mean()
- ).mean()
- metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
- policy_chosen_logits.detach().mean()
- ).mean()
- metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
- metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
- metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
- if is_torch_xla_available():
- xm.mark_step() # needed because .item() calls
- for k, v in metrics.items():
- metrics[k] = v.item()
- if self.aux_loss_enabled:
- loss += self.aux_loss_coef * aux_loss
-
- return loss, metrics
-
- def compute_loss(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- return_outputs=False,
- num_items_in_batch=None,
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
- compute_loss_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with compute_loss_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
-
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
- loss = loss.to(self.args.device)
-
- # force log the metrics
- self.store_metrics(metrics, train_eval="train")
-
- if return_outputs:
- return (loss, metrics)
- return loss
-
- def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
- """Generate samples from the model and reference model for the given batch of inputs."""
-
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
- # the torch amp context manager as some hidden states are silently casted to full precision.
- generate_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with generate_context_manager:
- policy_output = model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.processing_class.pad_token_id,
- )
-
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
-
- return policy_output_decoded
-
- def prediction_step(
- self,
- model: Union[PreTrainedModel, nn.Module],
- inputs: dict[str, Union[torch.Tensor, Any]],
- prediction_loss_only: bool,
- ignore_keys: Optional[list[str]] = None,
- ):
- if not self.use_dpo_data_collator:
- logger.warning(
- "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
- "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
- )
- if ignore_keys is None:
- if hasattr(model, "config"):
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
- else:
- ignore_keys = []
-
- prediction_context_manager = (
- autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
- )
-
- with torch.no_grad(), prediction_context_manager:
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
-
- # force log the metrics
- self.store_metrics(metrics, train_eval="eval")
-
- if prediction_loss_only:
- return (loss.detach(), None, None)
-
- # logits for the chosen and rejected samples from model
- logits_dict = {
- "eval_logits/chosen": metrics["eval_logits/chosen"],
- "eval_logits/rejected": metrics["eval_logits/rejected"],
- }
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
- logits = torch.tensor(logits, device=self.accelerator.device)
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
- return (loss.detach(), logits, labels)
-
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
- for key, value in metrics.items():
- self._stored_metrics[train_eval][key].append(value)
-
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: Optional[bool] = None,
- ignore_keys: Optional[list[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
- `Trainer.evaluate()` and `Trainer.predict()`.
-
- Works both with or without labels.
- """
-
- # Sample and save to game log if requested (for one batch to save time)
- if self.generate_during_eval:
- # Generate random indices within the range of the total number of samples
- num_samples = len(dataloader.dataset)
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
-
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
- random_batch_dataset = dataloader.dataset.select(random_indices)
- random_batch = self.data_collator(random_batch_dataset)
- random_batch = self._prepare_inputs(random_batch)
-
- policy_output_decoded = self.generate_from_model(self.model, random_batch)
-
- table = pd.DataFrame(
- columns=["Prompt", "Policy"],
- data=[
- [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
- ],
- )
- if "wandb" in self.args.report_to:
- wandb.log({"game_log": wandb.Table(data=table)})
-
- if "comet_ml" in self.args.report_to:
- log_table_to_comet_experiment(
- name="game_log.csv",
- table=table,
- )
-
- # Base evaluation
- initial_output = super().evaluation_loop(
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
- )
-
- return initial_output
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- """
- Log `logs` on the various objects watching training, including stored metrics.
-
- Args:
- logs (`dict[str, float]`):
- The values to log.
- start_time (`float`, *optional*):
- Start time of the training.
- """
- # logs either has 'loss' or 'eval_loss'
- train_eval = "train" if "loss" in logs else "eval"
- # Add averaged stored metrics to logs
- for key, metrics in self._stored_metrics[train_eval].items():
- logs[key] = torch.tensor(metrics).mean().item()
- del self._stored_metrics[train_eval]
- return super().log(logs, start_time)
-
- def _shift_right(self, input_ids):
- if self.decoder_start_token_id is None:
- raise ValueError(
- "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
- )
-
- # shift inputs to the right
- if is_torch_fx_proxy(input_ids):
- # Item assignment is not supported natively for proxies.
- shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
- shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
- else:
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
- shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
- shifted_input_ids[..., 0] = self.decoder_start_token_id
-
- if self.pad_token_id is None:
- raise ValueError("model.config.pad_token_id has to be defined.")
- # replace possible -100 values in labels by `pad_token_id`
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
-
- return shifted_input_ids
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothORPOTrainer(_UnslothORPOTrainer):
- """
-
- Initialize ORPOTrainer.
-
- Args:
- model ([`~transformers.PreTrainedModel`]):
- The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
- args ([`ORPOConfig`]):
- The ORPO config arguments to use for training.
- data_collator ([`~transformers.DataCollator`]):
- The data collator to use for training. If None is specified, the default data collator
- ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
- sequences in the batch, given a dataset of paired sequences.
- train_dataset ([`~datasets.Dataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`]):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- model_init (`Callable[[], transformers.PreTrainedModel]`):
- The model initializer to use for training. If None is specified, the default model initializer will be
- used.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
- peft_config (`dict`, defaults to `None`):
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
- a PEFT model.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
- metric values.
-
- """
- def __init__(
- self,
- model = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- model_init = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- compute_metrics = None,
- **kwargs
- ):
- if args is None: args = UnslothORPOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('orpo_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- model_init = model_init,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,
- compute_metrics = compute_metrics,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py b/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
deleted file mode 100644
index 2921887..0000000
--- a/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
+++ /dev/null
@@ -1,2467 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.online_dpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BasePairwiseJudge, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FSDP, GenerationConfig, IterableDataset, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardFunc, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, VLLMClient, apply_chat_template, broadcast_object_list, create_reference_model, disable_dropout_in_model, empty_cache, ensure_master_addr_port, gather_object, is_conversational, is_flash_attn_2_available, is_peft_model, is_vllm_available, jinja2, logger, logging, maybe_apply_chat_template, nn, nullcontext, os, pad, prepare_deepspeed, prepare_fsdp, profiling_context, re, seed_worker, textwrap, torch, truncate_right, unwrap_model_for_generation, version, warnings, wraps, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalPrediction, F, GenerationConfig, IterableDataset, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, OnlineDPOConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardFunc, Trainer, TrainerCallback, Union, VLLMClient, create_reference_model, disable_dropout_in_model, ensure_master_addr_port, is_vllm_available, logger, nn, os, pad, prepare_deepspeed, prepare_fsdp, re, torch, version, warnings, F, apply_chat_template, is_conversational, re, F, FSDP, is_peft_model, nn, nullcontext, os, re, version, F, Optional, PreTrainedModel, Trainer, logger, os, re, torch, F, FSDP, nn, os, re, F, FSDP, nn, re, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-def vLLMSamplingParams(**kwargs):
- from vllm import SamplingParams
-
- sampling_params = SamplingParams(**kwargs)
- sampling_params._set_kwargs = kwargs
- return sampling_params
-@dataclass
-class UnslothOnlineDPOConfig(OnlineDPOConfig):
- """
-
- Configuration class for the [`OnlineDPOTrainer`].
-
- This class includes only the parameters that are specific to Online DPO training. For a full list of training
- arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
- class may differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- reward_model_path (`str`, *optional*):
- Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
- judge (`str`, *optional*):
- Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
- max_new_tokens (`int`, *optional*, defaults to `64`):
- Maximum number of tokens to generate per completion.
- max_length (`int`, *optional*, defaults to `256`):
- Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
- sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
- possible.
- temperature (`float`, *optional*, defaults to `0.9`):
- Temperature for sampling. The higher the temperature, the more random the completions.
- missing_eos_penalty (`float`, *optional*):
- Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage to
- generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
- value. This parameter only works when using `reward_funcs` and not when using `judge`.
- beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
- reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
- the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
- selected for each new epoch and the last β is used for the rest of the epochs.
- loss_type (`str`, *optional*, defaults to `"sigmoid"`):
- Type of loss to use. Possible values are:
-
- - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
- - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
-
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Since OnlineDPO does not involve
- dataset preparation, you can safely remove it.
-
-
-
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model and reference model.
-
- > Parameters that control generation
-
- top_p (`float`, *optional*, defaults to `1.0`):
- Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
- `1.0` to consider all tokens.
- top_k (`int`, *optional*):
- Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
- disabled and all tokens are considered.
- min_p (`float`, *optional*):
- Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
- value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
- repetition_penalty (`float`, *optional*, defaults to `1.0`):
- Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
- Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
- tokens.
- use_transformers_paged (`bool`, *optional*, defaults to `False`):
- Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers`
- paged implementation will be used for generation instead of the default padded implementation. This
- parameter is only effective when `use_vllm` is set to `False`.
- cache_implementation (`str`, *optional*):
- Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
- generation_kwargs (`dict[str, Any]`, *optional*):
- Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
- `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
- generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
- with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
-
- > Parameters that control generation acceleration powered by vLLM
-
- use_vllm (`bool`, *optional*, defaults to `False`):
- Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation
- instead of the default model.generate(). Requires `vllm` to be installed.
- vllm_model_impl (`str`, *optional*, defaults to `"vllm"`):
- Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
- the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
- implementation.
- vllm_mode (`str`, *optional*, defaults to `"server"`):
- Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or
- `"colocate"`.
-
- - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM
- server is running (start with `trl vllm-serve`).
- - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
- separate server but may cause resource contention with training.
- vllm_guided_decoding_regex (`str`, *optional*):
- Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
-
- > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
-
- vllm_server_base_url (`str`, *optional*):
- Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
- `vllm_server_port` are ignored.
- vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
- Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
- vllm_server_port (`int`, *optional*, defaults to `8000`):
- Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
- vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
- Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
- timeout, a `ConnectionError` is raised.
-
- > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
-
- vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.55`):
- Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
- `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
- launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
- vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
- Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
- `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
- launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
-
- > Other parameters
-
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
- capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
- with vLLM generation.
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
- string.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- reward_model_path = None,
- judge = None,
- max_new_tokens = 64,
- max_length = 512,
- temperature = 0.9,
- top_p = 1.0,
- top_k = None,
- min_p = None,
- repetition_penalty = 1.0,
- generation_kwargs = {},
- use_transformers_paged = False,
- cache_implementation = None,
- missing_eos_penalty = None,
- loss_type = 'sigmoid',
- disable_dropout = True,
- use_vllm = False,
- vllm_model_impl = 'vllm',
- vllm_guided_decoding_regex = None,
- vllm_gpu_memory_utilization = 0.55,
- vllm_mode = 'colocate',
- vllm_server_base_url = None,
- vllm_server_host = '0.0.0.0',
- vllm_server_port = 8000,
- vllm_server_timeout = 240.0,
- vllm_tensor_parallel_size = 1,
- ds3_gather_for_generation = True,
- model_init_kwargs = None,
- reward_weights = None,
- dataset_num_proc = None,
- gpu_memory_utilization = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- reward_model_path = reward_model_path,
- judge = judge,
- max_new_tokens = max_new_tokens,
- max_length = max_length,
- temperature = temperature,
- top_p = top_p,
- top_k = top_k,
- min_p = min_p,
- repetition_penalty = repetition_penalty,
- generation_kwargs = generation_kwargs,
- use_transformers_paged = use_transformers_paged,
- cache_implementation = cache_implementation,
- missing_eos_penalty = missing_eos_penalty,
- loss_type = loss_type,
- disable_dropout = disable_dropout,
- use_vllm = use_vllm,
- vllm_model_impl = vllm_model_impl,
- vllm_guided_decoding_regex = vllm_guided_decoding_regex,
- vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
- vllm_mode = vllm_mode,
- vllm_server_base_url = vllm_server_base_url,
- vllm_server_host = vllm_server_host,
- vllm_server_port = vllm_server_port,
- vllm_server_timeout = vllm_server_timeout,
- vllm_tensor_parallel_size = vllm_tensor_parallel_size,
- ds3_gather_for_generation = ds3_gather_for_generation,
- model_init_kwargs = model_init_kwargs,
- reward_weights = reward_weights,
- dataset_num_proc = dataset_num_proc,
- gpu_memory_utilization = gpu_memory_utilization,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothOnlineDPOTrainer(BaseTrainer):
- r""""""
-
- _tag_names = ["trl", "online-dpo"]
- _name = "Online DPO"
- _paper = {
- "title": "Direct Language Model Alignment from Online AI Feedback",
- "id": "2402.04792",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{guo2024direct,
- title = {{Direct Language Model Alignment from Online AI Feedback}},
- author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
- year = 2024,
- eprint = {arXiv:2402.04792}
- }"""),
- }
-
- def __init__(
- self,
- model: Union[PreTrainedModel, nn.Module, str],
- ref_model: Union[PreTrainedModel, nn.Module, None] = None,
- reward_funcs: Optional[Union[RewardFunc, list[RewardFunc]]] = None,
- judge: Optional[BasePairwiseJudge] = None,
- args: Optional[OnlineDPOConfig] = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
- processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
- reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
- peft_config: Optional["PeftConfig"] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- # Deprecated parameters
- reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
- reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
- ) -> None:
-
- if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):
- if (getattr(args, 'use_vllm', False) == False):
- args.use_vllm = True
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if ref_model is model:
- raise ValueError(
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
- "same as `model`, either omit the `ref_model` argument or pass `None`."
- )
-
- self.ref_model = ref_model
-
- # Handle deprecated parameters for backward compatibility
- if reward_model is not None:
- warnings.warn(
- "The `reward_model` parameter is deprecated and will be removed in version 0.25.0. "
- "Please use `reward_funcs` instead. For example, change `reward_model=model` to `reward_funcs=model`.",
- )
- # Convert old reward_model to new reward_funcs format
- if reward_funcs is None:
- reward_funcs = reward_model
- else:
- warnings.warn(
- "Both `reward_model` and `reward_funcs` are provided. Using `reward_funcs` and ignoring "
- "`reward_model`.",
- )
-
- if reward_processing_class is not None:
- warnings.warn(
- "The `reward_processing_class` parameter is deprecated and will be removed in version 0.25.0. "
- "Please use `reward_processing_classes` instead. For example, change "
- "`reward_processing_class=tokenizer` to `reward_processing_classes=tokenizer`.",
- )
- # Convert old reward_processing_class to new reward_processing_classes format
- if reward_processing_classes is None:
- reward_processing_classes = reward_processing_class
- else:
- warnings.warn(
- "Both `reward_processing_class` and `reward_processing_classes` are provided. Using "
- "`reward_processing_classes` and ignoring `reward_processing_class`.",
- )
-
- # Validate reward configuration - must have exactly one of: judge, or reward_funcs
- reward_configs = sum(x is not None for x in [judge, reward_funcs])
- if reward_configs == 0:
- raise ValueError("One of `judge` or `reward_funcs` must be provided.")
- elif reward_configs > 1:
- if judge is not None:
- logger.warning(
- "Both `judge` and `reward_funcs` are provided. Using `judge` and ignoring `reward_funcs`.",
- UserWarning,
- )
- reward_funcs = None
- self.judge = judge
-
- # Handle reward_funcs
- if reward_funcs is not None:
- if not isinstance(reward_funcs, list):
- reward_funcs = [reward_funcs]
- self.reward_func_names = []
-
- # Process reward functions [convert strings to models, collect names]
- model_init_kwargs = args.model_init_kwargs or {}
- for i, reward_func in enumerate(reward_funcs):
- if isinstance(reward_func, str):
- # Load model from string path
- reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
- reward_func, num_labels=1, **model_init_kwargs
- )
- if isinstance(reward_funcs[i], nn.Module):
- self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
- else:
- self.reward_func_names.append(reward_funcs[i].__name__)
- self.reward_funcs = reward_funcs
-
- # Handle reward processing classes for reward_funcs
- if reward_processing_classes is None:
- reward_processing_classes = [None] * len(reward_funcs)
- elif not isinstance(reward_processing_classes, list):
- reward_processing_classes = [reward_processing_classes]
- else:
- if len(reward_processing_classes) != len(reward_funcs):
- raise ValueError(
- "The number of reward processing classes must match the number of reward functions."
- )
-
- self.reward_processing_classes = []
- for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs):
- if isinstance(reward_func, PreTrainedModel):
- if reward_processing_class_i is None:
- reward_processing_class_i = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
- if reward_processing_class_i.pad_token_id is None:
- reward_processing_class_i.pad_token = reward_processing_class_i.eos_token
- # Set pad token ID on reward model config
- reward_func.config.pad_token_id = reward_processing_class_i.pad_token_id
- self.reward_processing_classes.append(reward_processing_class_i)
- else:
- self.reward_funcs = None
- self.reward_func_names = []
- self.reward_processing_classes = []
-
- # Handle reward_weights
- if reward_funcs is not None:
- if args.reward_weights is not None:
- if len(args.reward_weights) != len(self.reward_funcs):
- raise ValueError(
- f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
- f"functions ({len(self.reward_funcs)})"
- )
- self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
- else:
- self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32)
- else:
- self.reward_weights = None
-
- if args.missing_eos_penalty is not None and reward_funcs is None and judge is None:
- # Check if this is the old reward_model case
- if reward_model is not None:
- logger.warning(
- "The `missing_eos_penalty` parameter is deprecated when used with the deprecated `reward_model` parameter. "
- "Please use `reward_funcs` instead of `reward_model` to continue using this feature.",
- FutureWarning,
- stacklevel=2,
- )
- else:
- raise ValueError("`missing_eos_penalty` is only supported when `reward_funcs` is provided.")
-
- if args is None:
- raise ValueError("`args` must be provided.")
-
- # Check that the processing_class is provided
- if processing_class is None:
- raise ValueError("`processing_class` must be provided.")
-
- model_init_kwargs = args.model_init_kwargs or {}
- if isinstance(model, str):
- model_id = model
-
- # Handle dtype in model_init_kwargs
- dtype = model_init_kwargs.get("dtype")
- if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
- pass
- elif isinstance(dtype, str):
- dtype = getattr(torch, dtype)
- model_init_kwargs["dtype"] = dtype
- else:
- raise ValueError(
- "Invalid `dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string "
- f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
- )
-
- model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
- else:
- if args.model_init_kwargs is not None:
- raise ValueError(
- "You passed `model_init_kwargs` to the `OnlineDPOConfig`, but your model is already instantiated. "
- "This argument can only be used when the `model` argument is a string."
- )
- self.is_encoder_decoder = model.config.is_encoder_decoder
- self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()
-
- if False:
- pass
-
- # Enable gradient checkpointing if requested
- if args.gradient_checkpointing:
- model = self._enable_gradient_checkpointing(model, args)
-
- # Disable dropout in the model and reference model
- if args.disable_dropout:
- disable_dropout_in_model(model)
- if self.ref_model is not None:
- disable_dropout_in_model(self.ref_model)
-
- # Handle the ref_model
- # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
- # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
- # the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
- if ref_model is None: # No ref model provided, the most common case
- if False:
- self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
- else:
- self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
- else: # rare case, the user provided a ref model
- self.ref_model = ref_model
- self.ref_model.eval()
-
- # Disable the gradient and set the reward model in eval mode
- if reward_funcs is not None:
- for reward_func in reward_funcs:
- if isinstance(reward_func, PreTrainedModel):
- reward_func.eval()
-
- self.max_length = args.max_length
-
- self.stats = {
- "objective/kl": [],
- "objective/entropy": [],
- "objective/non_score_reward": [],
- "rewards/chosen": [],
- "rewards/rejected": [],
- "rewards/accuracies": [],
- "rewards/margins": [],
- "logps/chosen": [],
- "logps/rejected": [],
- "val/contain_eos_token": [],
- "beta": [],
- }
- if self.reward_funcs is not None:
- self.stats["objective/rlhf_reward"] = []
- self.stats["objective/scores_margin"] = []
- self.stats["objective/scores"] = []
-
- # Store generation parameters for later use
- self.use_vllm = args.use_vllm
- self.num_generations = 2 # Generate 2 completions per prompt for Online DPO
- self.temperature = args.temperature
- self.top_p = args.top_p
- self.top_k = args.top_k
- self.min_p = args.min_p
- self.repetition_penalty = args.repetition_penalty
- self.use_transformers_paged = args.use_transformers_paged
- self.vllm_mode = args.vllm_mode if args.use_vllm else None
- self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization
- self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size
- self.vllm_model_impl = args.vllm_model_impl
-
- # Handle pad token for processors or tokenizers
- if isinstance(processing_class, ProcessorMixin):
- tokenizer = processing_class.tokenizer
- elif isinstance(processing_class, PreTrainedTokenizerBase):
- tokenizer = processing_class
- else:
- raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
-
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
-
- self.pad_token = tokenizer.pad_token
- self.pad_token_id = tokenizer.pad_token_id
- self.eos_token_id = tokenizer.eos_token_id
-
- # Vision tokens for VLM support
- self.image_token_id = getattr(processing_class, "image_token_id", None)
- self.vision_start_token_id = getattr(processing_class, "vision_start_token_id", None)
- self.vision_end_token_id = getattr(processing_class, "vision_end_token_id", None)
- # Get the image token string for token collapsing
- self.image_token = None
- if self.image_token_id is not None:
- self.image_token = tokenizer.decode([self.image_token_id])
-
- # Define the collator if not provided
- if data_collator is None:
- data_collator = DPODataCollatorWithPadding(pad_token_id=self.pad_token_id)
-
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
- # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
- # that the warning has already been issued.
- model.warnings_issued["estimate_tokens"] = True
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- self._beta = args.beta
-
- # Set up generation configuration and vLLM after super[].__init__
- if self.use_vllm:
- if not is_vllm_available():
- raise ImportError(
- "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
- "`pip install trl[vllm]` to use it."
- )
-
- if self.vllm_mode == "server":
- if self.accelerator.is_main_process:
- if args.vllm_server_base_url is not None:
- base_url = args.vllm_server_base_url
- else:
- base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
- self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
- self.vllm_client.init_communicator(device=torch.cuda.current_device())
- else:
- self.vllm_client = None
- elif self.vllm_mode == "colocate":
- vllm_kwargs = {
- "model": model.name_or_path,
- "tensor_parallel_size": self.vllm_tensor_parallel_size,
- "gpu_memory_utilization": self.vllm_gpu_memory_utilization,
- "model_impl": self.vllm_model_impl,
- "max_num_seqs": self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size,
- "max_model_len": args.max_length + args.max_new_tokens,
- "distributed_executor_backend": "external_launcher",
- "seed": self.accelerator.process_index // self.vllm_tensor_parallel_size,
- "max_num_batched_tokens": 4096,
- }
- os.environ["RANK"] = str(self.accelerator.process_index)
- os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index)
- os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes)
- ensure_master_addr_port()
-
- self.llm = model.vllm_engine
- else:
- raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
- self.guided_decoding_regex = args.vllm_guided_decoding_regex
- self._last_loaded_step = -1
- generation_params = {
- "n": 2,
- "repetition_penalty": self.repetition_penalty,
- "temperature": self.temperature,
- "top_p": self.top_p,
- "top_k": -1 if self.top_k is None else self.top_k,
- "min_p": 0.0 if self.min_p is None else self.min_p,
- "max_tokens": args.max_new_tokens,
- "detokenize": False,
- }
- if args.generation_kwargs is not None:
- generation_params.update(args.generation_kwargs)
- if self.guided_decoding_regex:
- generation_params["guided_decoding"] = GuidedDecodingParams(regex=self.guided_decoding_regex)
- self.generation_config = SamplingParams(**generation_params)
- self.accelerator.wait_for_everyone()
- else:
- # Set up transformers generation config
- generation_kwargs = {
- "max_new_tokens": args.max_new_tokens,
- "do_sample": True,
- "pad_token_id": self.pad_token_id,
- "bos_token_id": tokenizer.bos_token_id,
- "eos_token_id": self.eos_token_id,
- "temperature": self.temperature,
- "top_k": self.top_k,
- "top_p": self.top_p,
- "repetition_penalty": self.repetition_penalty,
- "use_cache": True if not self.args.gradient_checkpointing else False,
- }
- # Add min_p if supported
- if self.min_p is not None:
- generation_kwargs["min_p"] = self.min_p
- if args.generation_kwargs is not None:
- generation_kwargs.update(args.generation_kwargs)
- # Remove None values
- generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
- self.generation_config = GenerationConfig(**generation_kwargs)
-
- if self.ref_model is not None:
- if self.is_deepspeed_enabled:
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
- elif self.is_fsdp_enabled:
- self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
- else:
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
- if self.reward_funcs is not None:
- for i, reward_func in enumerate(self.reward_funcs):
- if isinstance(reward_func, PreTrainedModel):
- if self.is_deepspeed_enabled:
- self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
- else:
- # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
- self.reward_funcs[i] = self.accelerator.prepare_model(
- reward_func, evaluation_mode=True, device_placement=True
- )
-
- @property
- def beta(self):
- if isinstance(self._beta, list):
- epoch = self.state.epoch
- return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
- else:
- return self._beta
-
- @staticmethod
- def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
- """Tokenize a single row from a DPO specific dataset."""
- if not is_encoder_decoder:
- batch = tokenizer(feature["prompt"], add_special_tokens=False)
- # Add BOS token to head of prompt. Avoid adding if it's already there
- if tokenizer.bos_token_id is not None:
- prompt_len_input_ids = len(batch["input_ids"])
- if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
- batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
- batch["attention_mask"] = [1] + batch["attention_mask"]
- else:
- batch = tokenizer(feature["prompt"], add_special_tokens=True)
- batch = {f"prompt_{key}": value for key, value in batch.items()}
- return batch
-
- # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
- @wraps(Trainer.get_train_dataloader)
- def get_train_dataloader(self) -> DataLoader:
- if self.train_dataset is None:
- raise ValueError("Trainer: training requires a train_dataset.")
-
- train_dataset = self.train_dataset
- data_collator = self.data_collator
- dataloader_params = {
- "batch_size": self._train_batch_size,
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
-
- if not isinstance(train_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_train_sampler()
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["worker_init_fn"] = seed_worker
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
- return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
-
- # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
- @wraps(Trainer.get_eval_dataloader)
- def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
- if eval_dataset is None and self.eval_dataset is None:
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
-
- # If we have persistent workers, don't do a fork bomb especially as eval datasets
- # don't change during training
- dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
- if (
- hasattr(self, "_eval_dataloaders")
- and dataloader_key in self._eval_dataloaders
- and self.args.dataloader_persistent_workers
- ):
- return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
-
- eval_dataset = (
- self.eval_dataset[eval_dataset]
- if isinstance(eval_dataset, str)
- else eval_dataset
- if eval_dataset is not None
- else self.eval_dataset
- )
- data_collator = self.data_collator
-
- dataloader_params = {
- "batch_size": self.args.eval_batch_size,
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
-
- if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
- # accelerator.free_memory() will destroy the references, so
- # we need to store the non-prepared version
- eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
- if self.args.dataloader_persistent_workers:
- if hasattr(self, "_eval_dataloaders"):
- self._eval_dataloaders[dataloader_key] = eval_dataloader
- else:
- self._eval_dataloaders = {dataloader_key: eval_dataloader}
-
- return self.accelerator.prepare(eval_dataloader)
-
- def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPOConfig) -> PreTrainedModel:
- """Enables gradient checkpointing for the model."""
- # Ensure use_cache is disabled
- model.config.use_cache = False
-
- # Enable gradient checkpointing on the base model for PEFT
- if is_peft_model(model):
- model.base_model.gradient_checkpointing_enable()
- # Enable gradient checkpointing for non-PEFT models
- else:
- model.gradient_checkpointing_enable()
-
- gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
- use_reentrant = (
- "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
- )
-
- if use_reentrant:
- model.enable_input_require_grads()
-
- return model
-
- def _generate_vllm(self, prompts, images=None):
- eos_token_id = self.eos_token_id
- pad_token_id = self.pad_token_id
-
- # Generate completion_ids and prompt_ids based on mode
- if self.vllm_mode == "server":
- completion_ids, prompt_ids = self._generate_vllm_server(prompts, images)
- elif self.vllm_mode == "colocate":
- completion_ids, prompt_ids = self._generate_vllm_colocate(prompts, images)
-
- # Shared padding, masking, and tensor conversion logic
- max_prompt_length = max(len(ids) for ids in prompt_ids)
- prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
- prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
- max_tokens = self.generation_config.max_tokens
- completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
- completion_ids = [
- ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
- for ids in completion_ids
- ]
- completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
-
- # Convert to tensors
- prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
- prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
- completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
- completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
-
- return prompt_ids, prompt_mask, completion_ids, completion_mask
-
- def _generate_vllm_server(self, prompts, images=None):
- """Generate completions using vLLM server mode"""
- has_images = images is not None
-
- # Update vLLM server weights if needed
- if hasattr(self, "_last_loaded_step") and self.state.global_step != self._last_loaded_step:
- self._move_model_to_vllm()
- self._last_loaded_step = self.state.global_step
- elif not hasattr(self, "_last_loaded_step"):
- self._move_model_to_vllm()
- self._last_loaded_step = self.state.global_step
-
- # Apply chat template if conversational
- if is_conversational({"prompt": prompts[0]}):
- prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts]
- else:
- prompts_text = prompts
- # Gather all prompts to main process
- all_prompts = gather_object(prompts_text)
- if has_images:
- all_images = gather_object(images)
-
- if self.accelerator.is_main_process:
- # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
- # num_generations outputs for each one. This is faster than generating outputs for each duplicate
- # prompt individually.
- ordered_set_of_prompts = all_prompts[:: self.num_generations]
- if has_images:
- ordered_set_of_images = all_images[:: self.num_generations]
- else:
- ordered_set_of_images = None
- completion_ids = self.vllm_client.generate(
- prompts=ordered_set_of_prompts,
- images=ordered_set_of_images,
- n=self.num_generations,
- repetition_penalty=self.repetition_penalty,
- temperature=self.temperature,
- top_p=self.top_p,
- top_k=-1 if self.top_k is None else self.top_k,
- min_p=0.0 if self.min_p is None else self.min_p,
- max_tokens=self.generation_config.max_tokens,
- guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None,
- generation_kwargs=self.args.generation_kwargs,
- )
- # Flatten: each prompt generates 2 completions
- completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
- else:
- completion_ids = [None] * (len(all_prompts) * 2)
-
- # Broadcast completions to all processes
- completion_ids = broadcast_object_list(completion_ids, from_process=0)
-
- # Each process takes its slice
- process_slice = slice(
- self.accelerator.process_index * len(prompts) * 2,
- (self.accelerator.process_index + 1) * len(prompts) * 2,
- )
- completion_ids = completion_ids[process_slice]
-
- # Create prompt_ids by tokenizing locally
- prompt_inputs = self.processing_class(
- text=prompts_text,
- return_tensors="pt",
- padding=True,
- padding_side="left",
- add_special_tokens=False,
- )
- prompt_ids = []
- for prompt_tokens in prompt_inputs["input_ids"]:
- prompt_ids.extend([prompt_tokens.tolist(), prompt_tokens.tolist()]) # 2 copies for 2 completions
- return completion_ids, prompt_ids
-
- def _generate_vllm_colocate(self, prompts, images=None):
- """Generate completions using vLLM colocate mode"""
- # Update model weights if needed - only after gradient accumulation completes
- if self.state.global_step != self._last_loaded_step:
- self._move_model_to_vllm()
- self._last_loaded_step = self.state.global_step
-
- # Apply chat template if conversational
- if is_conversational({"prompt": prompts[0]}):
- prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts]
- else:
- prompts_text = prompts
-
- # Prepare vLLM inputs with images if available
- if images is not None:
- vllm_inputs = []
- for prompt, image in zip(prompts_text, images):
- if image is not None:
- vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
- else:
- vllm_inputs.append(prompt)
- else:
- vllm_inputs = prompts_text
-
- outputs = self.llm.generate(vllm_inputs, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
-
- completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
- prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
-
- return completion_ids, prompt_ids
-
- def _move_model_to_vllm(self):
- """Synchronize model weights to vLLM server with support for PEFT, DeepSpeed, and FSDP"""
- # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
- deepspeed_plugin = self.accelerator.state.deepspeed_plugin
- zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
- if zero_stage_3:
- import deepspeed
-
- gather_if_zero3 = deepspeed.zero.GatheredParameters
- else:
- gather_if_zero3 = nullcontext
-
- if is_peft_model(self.model):
- # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
- # merging adapters in a sharded manner is not supported.
- # TODO: does this work with FSDP?
- with gather_if_zero3(list(self.model.parameters())):
- self.model.merge_adapter()
-
- # Update vLLM weights while parameters are gathered
- if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
- # Update vLLM weights while parameters are gathered
- # For PEFT with FSDP we need to use the memory efficient post-order traversal
- fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
- fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
- if fsdp_version == 1:
- # use memory-efficient post-order traversal for FSDP
- self._sync_fsdp1_params_to_vllm(self.model)
- elif fsdp_version == 2:
- self._sync_fsdp2_params_to_vllm(self.model)
- else:
- # DeepSpeed ZeRO-3 with PEFT
- for name, param in self.model.named_parameters():
- # When using PEFT, we need to recover the original parameter name and discard some parameters
- name = name.removeprefix("base_model.model.").replace(".base_layer", "")
- if self.model.prefix in name:
- continue
- # When module to save, remove its prefix and discard the original module
- if "original_module" in name:
- continue
- name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."])
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(name, param.data)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
- # Unmerge adapters while parameters are still gathered
- self.model.unmerge_adapter()
- # Parameters will automatically be repartitioned when exiting the context
- else:
- # For non-PEFT models, simply gather (if needed) and update each parameter individually.
- if self.is_fsdp_enabled:
- fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
- fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
- if fsdp_version == 1:
- self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP
- elif fsdp_version == 2:
- self._sync_fsdp2_params_to_vllm(self.model)
- else:
- for name, param in self.model.named_parameters():
- name = self._fix_param_name_to_vllm(name)
- with gather_if_zero3([param]):
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(name, param.data)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- # Reset cache on vLLM
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.reset_prefix_cache()
- elif self.vllm_mode == "colocate":
- self.llm.reset_prefix_cache()
-
- def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
- """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
- # For FSDP1, we need to recurse into children and also use summon_full_params
- if visited is None:
- visited = set()
- for child_name, child_module in module.named_children():
- child_prefix = f"{prefix}.{child_name}" if prefix else child_name
- self._sync_fsdp1_params_to_vllm(
- child_module, prefix=child_prefix, visited=visited
- ) # recurse into the child
-
- if isinstance(module, FSDP):
- with FSDP.summon_full_params(module, recurse=False, writeback=False):
- for param_name, param in module.named_parameters():
- full_name = f"{prefix}.{param_name}" if prefix else param_name
- full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])
-
- if full_name in visited:
- continue # skip FSDP subtrees already traversed
- visited.add(full_name)
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(full_name, param.data)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
- # For FSDP2, module already covers all parameters, so no need for recursion
- for name, param in module.items():
- if param.is_cpu:
- param = param.to(torch.device("cuda"))
- param = param.full_tensor()
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(name, param)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
- """Clean parameter names for vLLM compatibility"""
- extra_prefixes = extra_prefixes or []
- prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
- for prefix in prefixes:
- name = name.replace(prefix, "")
- return name
-
- def process_vision_row(
- self, features: dict[str, Union[list, torch.Tensor]], processing_class=None
- ) -> dict[str, list[int]]:
- """
- Process a vision row for VLM models (adapted from DPO trainer)
- """
- processor = processing_class or self.processing_class
- processed_features = processor(images=[features["image"]], text=features["prompt"], add_special_tokens=False)
-
- prompt_input_ids = processed_features["input_ids"][0]
-
- # Create the output dict with required fields
- output = {
- "prompt_input_ids": prompt_input_ids,
- "prompt_attention_mask": processed_features["attention_mask"][0],
- }
-
- # Add vision-specific fields
- if "pixel_values" in processed_features:
- output["pixel_values"] = processed_features["pixel_values"][0]
- if "pixel_attention_mask" in processed_features:
- output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
- if "image_sizes" in processed_features:
- output["image_sizes"] = processed_features["image_sizes"][0]
-
- return output
-
- def _generate(self, model, prompts, images=None):
- """Generate completions using the model"""
- device = next(model.parameters()).device
- eos_token_id = self.eos_token_id
- pad_token_id = self.pad_token_id
-
- # Apply chat template and tokenize the input
- inputs = [{"prompt": prompt} for prompt in prompts]
-
- # Add images if provided (VLM support)
- if images is not None:
- for i, image in enumerate(images):
- inputs[i]["image"] = image
-
- # Apply chat template to get text prompts
- prompts_text = [maybe_apply_chat_template(x, self.processing_class)["prompt"] for x in inputs]
-
- # Handle image token collapsing/removal
- # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
- # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
- # image size. We need to handle this properly.
- if self.image_token is not None and images is not None:
- escaped_img_token = re.escape(self.image_token)
- # Search for the image token in the chat template
- if hasattr(self.processing_class, "chat_template") and self.processing_class.chat_template:
- if re.search(escaped_img_token, self.processing_class.chat_template):
- # Collapse repeated image tokens back into a single token
- prompts_text = [
- re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
- ]
- else:
- # If the chat template doesn't use the image token, remove all instances
- if self.vision_end_token_id is not None:
- escaped_eoi_token = re.escape(
- self.processing_class.tokenizer.decode([self.vision_end_token_id])
- )
- prompts_text = [
- re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
- ]
- else:
- # If vision_end_token_id is None, just remove the image tokens
- prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
-
- # Prepare kwargs for processing class
- kwargs = {}
- if images is not None:
- kwargs = {"images": [[img] for img in images]}
-
- # Process inputs using the processing class (handles both VLM and LLM)
- prompt_inputs = self.processing_class(
- text=prompts_text,
- return_tensors="pt",
- padding=True,
- padding_side="left",
- add_special_tokens=False,
- **kwargs,
- )
-
- prompt_inputs = {k: v.to(device) for k, v in prompt_inputs.items()}
- # Convert vision inputs to model's dtype for proper computation
- if "pixel_values" in prompt_inputs:
- # Handle DataParallel wrapped models
- model_dtype = getattr(model, "dtype", None)
- if model_dtype is None and hasattr(model, "module"):
- model_dtype = model.module.dtype
- if model_dtype is not None:
- prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].to(model_dtype)
-
- # Sample 2 completions per prompt of size `max_new_tokens` from the model
- prompt_ids = prompt_inputs["input_ids"].repeat(2, 1)
- prompt_mask = prompt_inputs["attention_mask"].repeat(2, 1)
-
- # Prepare vision inputs if available
- vision_generation_kwargs = {}
- if self.is_vision_model and images is not None:
- if "pixel_values" in prompt_inputs:
- vision_generation_kwargs["pixel_values"] = prompt_inputs["pixel_values"].repeat(2, 1, 1, 1)
- if "pixel_attention_mask" in prompt_inputs:
- vision_generation_kwargs["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"].repeat(2, 1)
- if "image_sizes" in prompt_inputs:
- vision_generation_kwargs["image_sizes"] = prompt_inputs["image_sizes"].repeat(2, 1)
- if "image_grid_thw" in prompt_inputs:
- vision_generation_kwargs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(2, 1)
-
- if self.use_transformers_paged:
- previous_attn = self.model_wrapped.config._attn_implementation
-
- if is_flash_attn_2_available():
- self.model_wrapped.config._attn_implementation = "paged_attention"
- else:
- self.model_wrapped.config._attn_implementation = "sdpa_paged"
- with (
- profiling_context(self, "transformers.generate_batch"),
- unwrap_model_for_generation(
- model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model,
- torch.no_grad(),
- FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
- ):
- # Cast to the appropriate dtype based on training configuration
- if self.args.bf16:
- unwrapped_model.to(torch.bfloat16)
- elif self.args.fp16:
- unwrapped_model.to(torch.float16)
- with torch.inference_mode():
- all_outputs = unwrapped_model.generate_batch(
- prompt_ids.tolist(),
- generation_config=self.generation_config,
- progress_bar=False,
- )
- unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
- completion_ids = [output.generated_tokens for output in all_outputs.values()]
- completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
- completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
- prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
- # Restore the original attention implementation, training mode
- self.model_wrapped.config._attn_implementation = previous_attn
-
- # Extract completion_ids and create completion_mask
- prompt_length = prompt_ids.size(1)
- completion_ids = prompt_completion_ids[:, prompt_length:]
- completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
-
- return prompt_ids, prompt_mask, completion_ids, completion_mask
- else:
- # Regular generation path
- with (
- profiling_context(self, "transformers.generate"),
- unwrap_model_for_generation(
- model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model,
- torch.no_grad(),
- FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
- ):
- # Setup cache implementation if specified
- if self.args.cache_implementation is not None:
- unwrapped_model.generation_config.cache_implementation = self.args.cache_implementation
-
- # Standard generation
- output = unwrapped_model.generate(
- input_ids=prompt_ids,
- attention_mask=prompt_mask,
- generation_config=self.generation_config,
- **vision_generation_kwargs,
- )
-
- completion_ids = output[:, prompt_ids.size(1) :]
- completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
-
- return prompt_ids, prompt_mask, completion_ids, completion_mask
-
- def _calculate_rewards_from_functions(self, prompts, completions, completion_ids_list, **reward_kwargs):
- """
- Calculate rewards using reward functions
- """
- device = self.accelerator.device
- rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
-
- # Add trainer state to reward kwargs for dynamic reward shaping
- reward_kwargs["trainer_state"] = self.state
-
- for i, (reward_func, reward_processing_class) in enumerate(
- zip(self.reward_funcs, self.reward_processing_classes)
- ):
- if isinstance(reward_func, nn.Module): # Model-based reward function
- # Handle conversational vs text input
- if is_conversational({"prompt": prompts[0]}):
- messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
- texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
- else:
- texts = [p + c for p, c in zip(prompts, completions)]
-
- # Tokenize and get reward scores
- reward_inputs = reward_processing_class(
- text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
- )
- reward_inputs = {k: v.to(device) for k, v in reward_inputs.items()}
-
- with torch.inference_mode():
- rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
- else:
- # Custom reward function
- output_reward_func = reward_func(
- prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
- )
- # Convert None values to NaN
- output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
- rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
-
- # Weight and sum across all reward functions
- if self.reward_weights is not None:
- total_rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
- else:
- total_rewards = rewards_per_func.nansum(dim=1)
-
- return total_rewards
-
- def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs=None):
- # Get the number of tokens to truncate from prompt
- num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
-
- # Truncate left to avoid oom
- prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
- prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
-
- # Concat the prompt and completion
- prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
- prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
-
- # Prepare model kwargs with vision inputs if available
- model_kwargs = {"attention_mask": prompt_completion_mask}
- if vision_inputs is not None:
- if "pixel_values" in vision_inputs:
- model_kwargs["pixel_values"] = vision_inputs["pixel_values"]
- if "pixel_attention_mask" in vision_inputs:
- model_kwargs["pixel_attention_mask"] = vision_inputs["pixel_attention_mask"]
- if "image_sizes" in vision_inputs:
- model_kwargs["image_sizes"] = vision_inputs["image_sizes"]
- if "image_grid_thw" in vision_inputs:
- model_kwargs["image_grid_thw"] = vision_inputs["image_grid_thw"]
-
- # Get the logprobs of the completions from the model
- output = model(prompt_completion_ids, **model_kwargs)
-
- # There is 1 offset, because the model predicts the next token
- prompt_len = prompt_ids.size(1)
- start_idx = prompt_len - 1 if prompt_len > 0 else 0
- # Only slice off the last logit when we have a prompt, otherwise we need all logits
- end_idx = -1 if prompt_len > 0 else None
- logits = output.logits[:, start_idx:end_idx]
-
- # Take the completion tokens logprob
- logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
- return logprobs
-
- def training_step(
- self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
- ) -> torch.Tensor:
- model.train()
-
- prompts = inputs["prompt"]
- batch_size = len(prompts)
-
- # Handle images for VLM support
- has_images = "image" in inputs
- images = None
- if has_images:
- images = inputs["image"]
- # Convert conversational prompts to include image tokens
- for prompt in prompts:
- if isinstance(prompt, list):
- for message in prompt:
- if not isinstance(message, dict):
- continue
- content = message.get("content")
- role = message.get("role")
- if isinstance(content, str):
- if role == "user":
- message["content"] = [{"type": "image"}, {"type": "text", "text": content}]
- elif role == "system":
- message["content"] = [{"type": "text", "text": content}]
-
- if self.args.use_vllm:
- prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(prompts, images)
- else:
- prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts, images)
-
- contain_eos_token = torch.any(completion_ids == self.eos_token_id, dim=-1)
-
- # Extract vision inputs if available for VLM support
- vision_inputs = None
- if has_images and self.is_vision_model and not self.args.use_vllm:
- # For vision models with transformers generation, we need to prepare vision inputs
- # Process the images to get vision inputs that can be passed through the forward pass
- vision_inputs = {}
- kwargs = {"images": [[img] for img in images]}
- processed = self.processing_class(
- text=[""] * len(images), # Dummy text for vision processing
- return_tensors="pt",
- **kwargs,
- )
- # Handle DataParallel wrapped models
- model_device = getattr(model, "device", None)
- model_dtype = getattr(model, "dtype", None)
- if model_device is None and hasattr(model, "module"):
- model_device = model.module.device
- model_dtype = model.module.dtype
- # Move vision tensors to device and convert to model dtype
- # Need to duplicate for 2 completions per prompt
- if "pixel_values" in processed:
- vision_inputs["pixel_values"] = (
- processed["pixel_values"].to(model_device, dtype=model_dtype).repeat(2, 1, 1, 1)
- )
- if "pixel_attention_mask" in processed:
- vision_inputs["pixel_attention_mask"] = processed["pixel_attention_mask"].to(model_device).repeat(2, 1)
- if "image_sizes" in processed:
- vision_inputs["image_sizes"] = processed["image_sizes"].to(model_device).repeat(2, 1)
- if "image_grid_thw" in processed:
- vision_inputs["image_grid_thw"] = processed["image_grid_thw"].to(model_device).repeat(2, 1)
-
- logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs)
- with torch.no_grad():
- if self.ref_model is not None:
- ref_logprobs = self._forward(
- self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs
- )
- else: # peft case: we just need to disable the adapter
- with self.model.disable_adapter():
- ref_logprobs = self._forward(
- self.model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs
- )
-
- # Decode the completions, and format them if the input is conversational
- device = logprobs.device
- completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
- if is_conversational({"prompt": prompts[0]}):
- completions = [[{"role": "assistant", "content": completion}] for completion in completions]
-
- # Get the reward from reward functions, judge, or deprecated reward_model
- if self.reward_funcs is not None:
- # First create completion_ids_list for custom reward functions
- completion_ids_list = [completion_ids[i].tolist() for i in range(completion_ids.shape[0])]
-
- # Extract additional fields from inputs for reward functions
- reward_kwargs = {}
- keys = [key for key in inputs if key not in ["prompt"]]
- for key in keys:
- if isinstance(inputs[key], (list, tuple)):
- # Repeat input fields to match number of completions (2 per prompt)
- reward_kwargs[key] = inputs[key] * 2
- else:
- reward_kwargs[key] = inputs[key]
-
- # Calculate rewards using reward functions
- rewards = self._calculate_rewards_from_functions(
- prompts=2 * prompts, completions=completions, completion_ids_list=completion_ids_list, **reward_kwargs
- )
-
- # Apply missing EOS penalty if configured
- if self.args.missing_eos_penalty is not None:
- rewards[~contain_eos_token] -= self.args.missing_eos_penalty
-
- # Split rewards into chosen/rejected pairs
- first_half, second_half = rewards.split(batch_size)
- mask = first_half >= second_half
- elif self.judge is not None:
- # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
- # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
- # independent of the model's chat template, we use the raw conversation data, and apply our own chat
- # template to it.
- if is_conversational({"prompt": prompts[0]}):
- environment = jinja2.Environment()
- template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
- prompts = [template.render(messages=prompt) for prompt in prompts]
- completions = [template.render(messages=completion) for completion in completions]
-
- ranks_of_first_completion = self.judge.judge(
- prompts, list(zip(completions[:batch_size], completions[batch_size:]))
- )
-
- # convert ranks to a True/False mask:
- # when rank == 0, it means the first completion is the best
- # when rank == 1, it means the second completion is the best
- mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
-
- batch_range = torch.arange(batch_size, device=device)
- chosen_indices = batch_range + (~mask * batch_size)
- rejected_indices = batch_range + (mask * batch_size)
-
- # Build tensor so that the first half is the chosen examples and the second half the rejected examples
- cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
- cr_logprobs = logprobs[cr_indices]
- cr_ref_logprobs = ref_logprobs[cr_indices]
-
- # mask out the padding tokens
- padding_mask = ~completion_mask.bool()
- cr_padding_mask = padding_mask[cr_indices]
-
- cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
- cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
-
- # Split the chosen and rejected examples
- chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
- chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
- pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
- ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
-
- logits = pi_logratios - ref_logratios
-
- if self.args.loss_type == "sigmoid":
- losses = -F.logsigmoid(self.beta * logits)
- elif self.args.loss_type == "ipo":
- losses = (logits - 1 / (2 * self.beta)) ** 2
- else:
- raise NotImplementedError(f"invalid loss type {self.loss_type}")
-
- loss = losses.mean()
-
- # Log everything
- if self.reward_funcs is not None:
- # When using reward_funcs, we have rewards instead of scores
- scores_margin = rewards[chosen_indices] - rewards[rejected_indices]
- self.stats["objective/scores_margin"].append(
- self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
- )
- self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(rewards.mean()).mean().item())
- self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
- self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
- self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
-
- kl = logprobs - ref_logprobs
- mean_kl = kl.sum(1).mean()
- self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
- non_score_reward = (-self.beta * kl).sum(1)
- mean_non_score_reward = non_score_reward.mean()
- self.stats["objective/non_score_reward"].append(
- self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
- )
- if self.reward_funcs is not None:
- # Calculate RLHF reward by combining rewards with non_score_reward
- rlhf_reward = rewards + non_score_reward
- self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
-
- mean_entropy = -logprobs.sum(1).mean()
- self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
- chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
- gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
- self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
- rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
- gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
- self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
- margin = gathered_chosen_rewards - gathered_rejected_rewards
- self.stats["rewards/margins"].append(margin.mean().item())
- accuracy = margin > 0
- self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
- self.stats["beta"].append(self.beta)
-
- if (
- self.args.torch_empty_cache_steps is not None
- and self.state.global_step % self.args.torch_empty_cache_steps == 0
- ):
- empty_cache()
-
- kwargs = {}
-
- # For LOMO optimizers you need to explicitly use the learning rate
- if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
- kwargs["learning_rate"] = self._get_learning_rate()
-
- if self.args.n_gpu > 1:
- loss = loss.mean() # mean() to average on multi-gpu parallel training
-
- self.accelerator.backward(loss, **kwargs)
-
- return loss.detach() / self.args.gradient_accumulation_steps
-
- # Same as Trainer._maybe_log_save_evaluate but log our metrics
- def _maybe_log_save_evaluate(
- self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
- ):
- if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
- logs: dict[str, float] = {}
-
- # all_gather + mean() to get average loss over all processes
- tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
-
- # reset tr_loss to zero
- tr_loss -= tr_loss
-
- logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
- if grad_norm is not None:
- logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
- if learning_rate is not None:
- logs["learning_rate"] = learning_rate
- else:
- logs["learning_rate"] = self._get_learning_rate()
-
- # Add our metrics
- for key, val in self.stats.items():
- logs[key] = sum(val) / len(val)
- self.stats = {key: [] for key in self.stats} # reset stats
-
- self._total_loss_scalar += tr_loss_scalar
- self._globalstep_last_logged = self.state.global_step
- self.store_flos()
- self.log(logs, start_time)
-
- metrics = None
- if self.control.should_evaluate:
- metrics = self._evaluate(trial, ignore_keys_for_eval)
- is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
-
- if self.args.save_strategy == "best":
- self.control.should_save = is_new_best_metric
-
- if self.control.should_save:
- self._save_checkpoint(model, trial)
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
- """
-
- Initialize OnlineDPOTrainer.
-
- Args:
- model (`Union[str, nn.Module, PreTrainedModel]`):
- Model to be trained. Can be either:
-
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
- `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- ref_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `None`):
- The reference model to use for training. If None is specified, the reference model will be created from the
- model.
- judge ([`BasePairwiseJudge`]):
- The judge to use for pairwise comparison of model completions.
- reward_funcs (`Union[RewardFunc, list[RewardFunc]]`, *optional*):
- Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
- functions with the prompts and completions and sum the rewards. Can be either:
-
- - A single reward function: Can be a string (path to model), a [`~transformers.PreTrainedModel`], or a
- custom callable function.
- - A list of reward functions: Must all be of compatible types.
-
- Note: Only one of `judge`, or `reward_funcs` should be provided.
- args ([`OnlineDPOConfig`]):
- The online DPO config arguments to use for training.
- data_collator ([`~transformers.DataCollator`]):
- The data collator to use for training. If None is specified, the default data collator
- ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
- sequences in the batch, given a dataset of paired sequences.
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
- Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
-
- - A single processing class: Used when `reward_funcs` contains only one reward function.
- - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
-
- If set to `None`, the tokenizer for each model-based reward function is automatically loaded using
- [`~transformers.AutoTokenizer.from_pretrained`].
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
- metric values.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
-
- reward_model:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
-
-
-
- """
- def __init__(
- self,
- model,
- ref_model = None,
- reward_funcs = None,
- judge = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- reward_processing_classes = None,
- peft_config = None,
- compute_metrics = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- reward_model = None,
- reward_processing_class = None,
- **kwargs
- ):
- if args is None: args = UnslothOnlineDPOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('online_dpo_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- ref_model = ref_model,
- reward_funcs = reward_funcs,
- judge = judge,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- reward_processing_classes = reward_processing_classes,
- peft_config = peft_config,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- reward_model = reward_model,
- reward_processing_class = reward_processing_class,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothPPOTrainer.py b/unsloth_compiled_cache/UnslothPPOTrainer.py
deleted file mode 100644
index 5226980..0000000
--- a/unsloth_compiled_cache/UnslothPPOTrainer.py
+++ /dev/null
@@ -1,1658 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, BaseTrainer, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, Path, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, empty_cache, exact_div, first_true_indices, forward, gather_object, gc, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_rich_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, warnings, Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, OnlineTrainerState, Optional, PPOConfig, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, broadcast, create_reference_model, disable_dropout_in_model, exact_div, forward, get_peft_model, get_reporting_integration_callbacks, is_peft_available, math, nn, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, time, torch, warnings, Optional, PeftModel, is_peft_available, os, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothPPOConfig(PPOConfig):
- """
-
- Configuration class for the [`PPOTrainer`].
-
- This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
- values in this class may differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
- Name of this experiment.
- reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
- Path to the reward model.
- model_adapter_name (`str`, *optional*):
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
- ref_adapter_name (`str`, *optional*):
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
- num_ppo_epochs (`int`, *optional*, defaults to `4`):
- Number of epochs to train.
- whiten_rewards (`bool`, *optional*, defaults to `False`):
- Whether to whiten the rewards.
- kl_coef (`float`, *optional*, defaults to `0.05`):
- KL coefficient.
- kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
- Which estimator for KL-Divergence to use from [Approximating KL
- Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
- estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
- better estimator". Cannot be set to "k2", as it is used for logging purposes.
- cliprange (`float`, *optional*, defaults to `0.2`):
- Clip range.
- vf_coef (`float`, *optional*, defaults to `0.1`):
- Value function coefficient.
- cliprange_value (`float`, *optional*, defaults to `0.2`):
- Clip range for the value function.
- gamma (`float`, *optional*, defaults to `1.0`):
- Discount factor.
- lam (`float`, *optional*, defaults to `0.95`):
- Lambda value for GAE.
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
- capacity of a single GPU, albeit at the cost of slower generation.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
-
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- dataset_num_proc = None,
- num_mini_batches = 1,
- total_episodes = None,
- local_rollout_forward_batch_size = 64,
- num_sample_generations = 10,
- response_length = 53,
- stop_token = None,
- stop_token_id = None,
- temperature = 0.7,
- missing_eos_penalty = None,
- sft_model_path = 'EleutherAI/pythia-160m',
- world_size = None,
- num_total_batches = None,
- micro_batch_size = None,
- local_batch_size = None,
- batch_size = None,
- local_mini_batch_size = None,
- mini_batch_size = None,
- exp_name = 'ppo_config',
- reward_model_path = 'EleutherAI/pythia-160m',
- model_adapter_name = None,
- ref_adapter_name = None,
- num_ppo_epochs = 4,
- whiten_rewards = False,
- kl_coef = 0.05,
- kl_estimator = 'k1',
- cliprange = 0.2,
- vf_coef = 0.1,
- cliprange_value = 0.2,
- gamma = 1.0,
- lam = 0.95,
- ds3_gather_for_generation = True,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
-
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- dataset_num_proc = dataset_num_proc,
- num_mini_batches = num_mini_batches,
- total_episodes = total_episodes,
- local_rollout_forward_batch_size = local_rollout_forward_batch_size,
- num_sample_generations = num_sample_generations,
- response_length = response_length,
- stop_token = stop_token,
- stop_token_id = stop_token_id,
- temperature = temperature,
- missing_eos_penalty = missing_eos_penalty,
- sft_model_path = sft_model_path,
- world_size = world_size,
- num_total_batches = num_total_batches,
- micro_batch_size = micro_batch_size,
- local_batch_size = local_batch_size,
- batch_size = batch_size,
- local_mini_batch_size = local_mini_batch_size,
- mini_batch_size = mini_batch_size,
- exp_name = exp_name,
- reward_model_path = reward_model_path,
- model_adapter_name = model_adapter_name,
- ref_adapter_name = ref_adapter_name,
- num_ppo_epochs = num_ppo_epochs,
- whiten_rewards = whiten_rewards,
- kl_coef = kl_coef,
- kl_estimator = kl_estimator,
- cliprange = cliprange,
- vf_coef = vf_coef,
- cliprange_value = cliprange_value,
- gamma = gamma,
- lam = lam,
- ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
-
-
-pass
-
-class _UnslothPPOTrainer(BaseTrainer):
- """"""
-
- _tag_names = ["trl", "ppo"]
- _name = "PPO"
- _paper = {
- "title": "Fine-Tuning Language Models from Human Preferences",
- "id": "1909.08593",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{mziegler2019fine-tuning,
- title = {{Fine-Tuning Language Models from Human Preferences}},
- author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
- year = 2019,
- eprint = {arXiv:1909.08593}
- }"""),
- }
-
- def __init__(
- self,
- args: PPOConfig,
- processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
- model: nn.Module,
- ref_model: Optional[nn.Module],
- reward_model: nn.Module,
- train_dataset: Dataset,
- value_model: nn.Module,
- data_collator: Optional[DataCollatorWithPadding] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- # less commonly used
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- callbacks: Optional[list[TrainerCallback]] = None,
- peft_config: Optional["PeftConfig"] = None,
- ) -> None:
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if ref_model is model:
- raise ValueError(
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
- "same as `model`, you must make a copy of it, or `None` if you use peft."
- )
-
- self.args = args
- self.processing_class = processing_class
- self.policy_model = model
-
- # Define the collator if not provided
- if data_collator is None:
- data_collator = DataCollatorWithPadding(self.processing_class)
-
- # Handle stop token settings: update policy model's generation_config to use provided stop token
- if args.stop_token and args.stop_token_id:
- raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
- elif args.stop_token:
- if args.stop_token == "eos":
- self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
- else:
- raise ValueError(
- f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
- )
- else:
- self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
-
- # Check that the kl estimator is valid
- if self.args.kl_estimator not in {"k1", "k3"}:
- raise ValueError(
- "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
- "appears to be a strictly better estimator). See "
- "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
- )
-
- # peft support
- if not is_peft_available() and peft_config is not None:
- raise ImportError(
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
- )
- elif is_peft_available() and peft_config is not None:
- # if model is a peft model and we have a peft_confg, we merge and unload it first
- if isinstance(self.policy_model, PeftModel):
- self.policy_model = self.policy_model.merge_and_unload()
-
- # get peft model with the given config
- self.policy_model = get_peft_model(self.policy_model, peft_config)
- if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(self.policy_model)
-
- self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
- self.model_adapter_name = args.model_adapter_name
- self.ref_adapter_name = args.ref_adapter_name
-
- if ref_model:
- self.ref_model = ref_model
- elif self.is_peft_model:
- self.ref_model = None
- else:
- self.ref_model = create_reference_model(self.policy_model)
-
- self.reward_model = reward_model
- self.train_dataset = train_dataset
- self.train_dataset_len = len(train_dataset)
- self.value_model = value_model
- self.data_collator = data_collator
- self.eval_dataset = eval_dataset
- self.optimizer, self.lr_scheduler = optimizers
- self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
-
- #########
- # calculate various batch sizes
- #########
- if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
- args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
- accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
- self.accelerator = accelerator
- args.world_size = accelerator.num_processes
- args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
- args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
- args.batch_size = int(args.local_batch_size * args.world_size)
- args.mini_batch_size = exact_div(
- args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
- )
- args.local_mini_batch_size = exact_div(
- args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
- )
- if args.whiten_rewards:
- assert args.local_mini_batch_size >= 8, (
- f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
- )
- # `per_rank_rollout_batch_size` is our `args.local_batch_size`
- # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
- args.num_total_batches = math.ceil(
- args.total_episodes / args.batch_size
- ) # we may train for more than `total_episodes`
- time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
- time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
- args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
- self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
- if args.num_sample_generations > 0:
- self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
- self.local_dataloader_batch_size = args.local_batch_size
-
- #########
- # setup model, optimizer, and others
- #########
- for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
- if module is not None:
- disable_dropout_in_model(module)
- self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
- self.model.config = self.policy_model.config # needed for pushing to hub
- self.create_optimizer_and_scheduler(
- num_training_steps=args.num_total_batches
- ) # note that we are calling `self.lr_scheduler.step[]` manually only at the batch level
-
- #########
- # trainer specifics
- #########
- default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
- self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
- self.callback_handler = CallbackHandler(
- self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
- )
- self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
- self.control = TrainerControl()
- self.state = OnlineTrainerState(
- is_local_process_zero=self.is_local_process_zero(),
- is_world_process_zero=self.is_world_process_zero(),
- stateful_callbacks=[
- cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
- ],
- )
- self.current_flos = 0
- self.hp_search_backend = None
- self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
- self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
- # Create distant repo and output directory if needed
- self.hub_model_id = None
- if self.args.push_to_hub:
- self.init_hf_repo()
- if self.args.should_save:
- os.makedirs(self.args.output_dir, exist_ok=True)
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- #########
- # setup dataloader
- #########
- self.dataloader = DataLoader(
- self.train_dataset,
- batch_size=self.local_dataloader_batch_size,
- shuffle=True,
- collate_fn=self.data_collator,
- drop_last=True, # needed; otherwise the last batch will be of ragged shape
- )
- # sync random states for DataLoader[shuffle=True] before `accelerator.prepare`
- # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
- torch.manual_seed(args.seed)
- self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
- torch.manual_seed(self.local_seed) # reset the local seed again
-
- self.eval_dataloader = DataLoader(
- self.eval_dataset,
- batch_size=args.per_device_eval_batch_size,
- collate_fn=self.data_collator,
- drop_last=True,
- ) # no need to shuffle eval dataset
- self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
-
- if self.is_deepspeed_enabled:
- self.reward_model = prepare_deepspeed(
- self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
- )
-
- if self.ref_model is None:
- if not self.is_peft_model:
- raise ValueError("No reference model and model is not a Peft model.")
- else:
- self.ref_model = prepare_deepspeed(
- self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
- )
- else:
- if self.ref_model is None:
- if not self.is_peft_model:
- raise ValueError("No reference model and model is not a Peft model.")
- else:
- self.ref_model = self.ref_model.to(self.accelerator.device)
- self.reward_model = self.reward_model.to(self.accelerator.device)
-
- def get_train_dataloader(self) -> DataLoader:
- return self.dataloader
-
- def get_eval_dataloader(self) -> DataLoader:
- return self.eval_dataloader
-
- @contextmanager
- def null_ref_context(self):
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
- with (
- self.accelerator.unwrap_model(self.model.policy).disable_adapter()
- if self.is_peft_model and not self.ref_adapter_name
- else nullcontext()
- ):
- if self.ref_adapter_name:
- self.model.policy.set_adapter(self.ref_adapter_name)
- yield
- if self.ref_adapter_name:
- self.model.policy.set_adapter(self.model_adapter_name or "default")
-
- def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
- backup_model = self.model
- self.model = self.model.policy # save only the policy
-
- if self.is_deepspeed_enabled:
- backup_deepspeed = self.deepspeed
- self.deepspeed = self.model
-
- super().save_model(output_dir, _internal_call)
-
- self.model = backup_model
-
- if self.is_deepspeed_enabled:
- self.deepspeed = backup_deepspeed
-
- def train(self):
- args = self.args
- accelerator = self.accelerator
- optimizer = self.optimizer
- model = self.model
- ref_policy = self.ref_model
- reward_model = self.reward_model
- processing_class = self.processing_class
- dataloader = self.dataloader
- device = accelerator.device
-
- def repeat_generator():
- while True:
- yield from dataloader
-
- iter_dataloader = iter(repeat_generator())
- generation_config = GenerationConfig(
- max_new_tokens=args.response_length,
- temperature=(args.temperature + 1e-7),
- top_k=0.0,
- top_p=1.0,
- do_sample=True,
- )
-
- accelerator.print("===training policy===")
- start_time = time.time()
- stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
- approxkl_stats = torch.zeros(stats_shape, device=device)
- pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
- pg_loss_stats = torch.zeros(stats_shape, device=device)
- vf_loss_stats = torch.zeros(stats_shape, device=device)
- vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
- entropy_stats = torch.zeros(stats_shape, device=device)
- ratio_stats = torch.zeros(stats_shape, device=device)
- model.train()
-
- # trainer state initialization
- self.state.global_step = 0
- self.state.episode = 0
- self.state.max_steps = args.num_total_batches
- self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
- # Compute absolute values for logging, eval, and save if given as ratio
- if args.logging_steps is not None:
- if args.logging_steps < 1:
- self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
- else:
- self.state.logging_steps = args.logging_steps
- if args.eval_steps is not None:
- if args.eval_steps < 1:
- self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
- else:
- self.state.eval_steps = args.eval_steps
- if args.save_steps is not None:
- if args.save_steps < 1:
- self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
- else:
- self.state.save_steps = args.save_steps
- self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
-
- # backward compatibility
- if self.is_deepspeed_enabled:
- self.deepspeed = self.model
- self.model_wrapped = self.model
-
- for update in range(1, args.num_total_batches + 1):
- self.state.episode += 1 * args.batch_size
- data = next(iter_dataloader)
- with torch.no_grad():
- queries = data["input_ids"].to(device)
- context_length = queries.shape[1]
- responses = []
- postprocessed_responses = []
- logprobs = []
- ref_logprobs = []
- scores = []
- sequence_lengths = []
- values = []
- with unwrap_model_for_generation(
- self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model:
- query_responses, logitss = batch_generation(
- unwrapped_model.policy,
- queries,
- args.local_rollout_forward_batch_size,
- processing_class.pad_token_id,
- generation_config,
- )
-
- for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
- query = queries[i : i + args.local_rollout_forward_batch_size]
- query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
- response = query_response[:, context_length:]
- logits = logitss[i : i + args.local_rollout_forward_batch_size]
- logprob = selective_log_softmax(logits, response)
- del logits
- empty_cache()
-
- if ref_policy is None:
- with self.null_ref_context():
- ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
- else:
- ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
- ref_logits = ref_output.logits[:, context_length - 1 : -1]
- ref_logits /= args.temperature + 1e-7
- ref_logprob = selective_log_softmax(ref_logits, response)
- del ref_output, ref_logits
- empty_cache()
-
- # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
- postprocessed_response = response
- if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
- postprocessed_response = truncate_response(
- self.stop_token_id, processing_class.pad_token_id, response
- )
-
- # Response Processing 2. run reward model on the truncated responses
- postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
- sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
- unwrapped_value_model = accelerator.unwrap_model(model).value_model
- full_value, _, _ = get_reward(
- unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
- )
- value = full_value[:, context_length - 1 : -1].squeeze(-1)
- _, score, _ = get_reward(
- reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
- )
-
- responses.append(response)
- postprocessed_responses.append(postprocessed_response)
- logprobs.append(logprob)
- ref_logprobs.append(ref_logprob)
- sequence_lengths.append(sequence_length)
- scores.append(score)
- values.append(value)
- responses = torch.cat(responses, 0)
- postprocessed_responses = torch.cat(postprocessed_responses, 0)
- logprobs = torch.cat(logprobs, 0)
- ref_logprobs = torch.cat(ref_logprobs, 0)
- sequence_lengths = torch.cat(sequence_lengths, 0)
- scores = torch.cat(scores, 0)
- values = torch.cat(values, 0)
- del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
- empty_cache()
- gc.collect()
-
- # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
- # Completions not passing that filter will receive a lower score.
- contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
- if self.args.missing_eos_penalty is not None:
- scores[~contain_eos_token] -= self.args.missing_eos_penalty
- # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
-
- # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
- response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
- padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
- logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
- ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
- sequence_lengths_p1 = sequence_lengths + 1
- padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
- values = torch.masked_fill(values, padding_mask_p1, 0)
-
- # 4. compute rewards
- # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
- logr = ref_logprobs - logprobs
- kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3
- non_score_reward = -args.kl_coef * kl
- rewards = non_score_reward.clone()
- actual_start = torch.arange(rewards.size(0), device=rewards.device)
- actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
- rewards[[actual_start, actual_end]] += scores
-
- # 5. whiten rewards
- if args.whiten_rewards:
- rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
- rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
-
- # 6. compute advantages and returns
- lastgaelam = 0
- advantages_reversed = []
- gen_length = responses.shape[1]
- for t in reversed(range(gen_length)):
- nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
- delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
- lastgaelam = delta + args.gamma * args.lam * lastgaelam
- advantages_reversed.append(lastgaelam)
- advantages = torch.stack(advantages_reversed[::-1], axis=1)
- returns = advantages + values
- advantages = masked_whiten(advantages, ~padding_mask)
- advantages = torch.masked_fill(advantages, padding_mask, 0)
- empty_cache()
-
- # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
- for ppo_epoch_idx in range(args.num_ppo_epochs):
- b_inds = np.random.permutation(args.local_batch_size)
- minibatch_idx = 0
- for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
- mini_batch_end = mini_batch_start + args.local_mini_batch_size
- mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
- gradient_accumulation_idx = 0
- for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
- with accelerator.accumulate(model):
- micro_batch_end = micro_batch_start + args.per_device_train_batch_size
- micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
- mb_advantage = advantages[micro_batch_inds]
- mb_responses = responses[micro_batch_inds]
- mb_query_responses = query_responses[micro_batch_inds]
- mb_logprobs = logprobs[micro_batch_inds]
- mb_return = returns[micro_batch_inds]
- mb_values = values[micro_batch_inds]
-
- output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
- logits = output.logits[:, context_length - 1 : -1]
- logits /= args.temperature + 1e-7
- new_logprobs = selective_log_softmax(logits, mb_responses)
- new_logprobs = torch.masked_fill(
- new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
- )
- vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
- vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
- vpredclipped = torch.clamp(
- vpred,
- mb_values - args.cliprange_value,
- mb_values + args.cliprange_value,
- )
- vf_losses1 = torch.square(vpred - mb_return)
- vf_losses2 = torch.square(vpredclipped - mb_return)
- vf_loss_max = torch.max(vf_losses1, vf_losses2)
- vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
- vf_clipfrac = masked_mean(
- (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
- )
- logprobs_diff = new_logprobs - mb_logprobs
- ratio = torch.exp(logprobs_diff)
- pg_losses = -mb_advantage * ratio
- pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
- pg_loss_max = torch.max(pg_losses, pg_losses2)
- pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
- loss = pg_loss + args.vf_coef * vf_loss
- accelerator.backward(loss)
- optimizer.step()
- optimizer.zero_grad()
- with torch.no_grad():
- pg_clipfrac = masked_mean(
- (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
- )
- prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype = torch.float32).to(logits.dtype)
- entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
- approxkl = 0.5 * (logprobs_diff**2).mean()
- approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
- pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
- pg_clipfrac
- )
- pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
- vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
- vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
- vf_clipfrac
- )
- entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
- ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
- gradient_accumulation_idx += 1
- minibatch_idx += 1
- # del everything and empty cache
- # fmt: off
- del (
- output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
- vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
- pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
- mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
- )
- # fmt: on
- empty_cache()
- with torch.no_grad():
- mean_kl = kl.sum(1).mean()
- mean_entropy = (-logprobs).sum(1).mean()
- mean_non_score_reward = non_score_reward.sum(1).mean()
- rlhf_reward = mean_non_score_reward + scores.mean()
- eps = int(self.state.episode / (time.time() - start_time))
- metrics = {}
- metrics["eps"] = eps
- metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
- metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
- metrics["objective/non_score_reward"] = (
- self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
- )
- metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
- metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
- metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
- metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
- metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
- metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
- metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
- metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
- metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
- metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
- metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
- metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
- metrics["episode"] = self.state.episode
- self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
- self.state.global_step += 1
- self.log(metrics)
-
- self.lr_scheduler.step()
- self.control = self.callback_handler.on_step_end(args, self.state, self.control)
- if self.control.should_save:
- self._save_checkpoint(model, trial=None)
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
- del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
- empty_cache()
- gc.collect()
-
- if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
- self.generate_completions(sampling=True)
- empty_cache()
- del (
- query_responses,
- responses,
- postprocessed_responses,
- logprobs,
- ref_logprobs,
- values,
- sequence_lengths,
- contain_eos_token,
- sequence_lengths_p1,
- response_idxs,
- padding_mask,
- padding_mask_p1,
- rewards,
- actual_start,
- actual_end,
- advantages,
- returns,
- )
- empty_cache()
-
- # HF trainer specifics
- self.control = self.callback_handler.on_train_end(args, self.state, self.control)
- if self.control.should_save:
- self._save_checkpoint(model, trial=None)
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-
- def generate_completions(self, sampling: bool = False):
- args = self.args
- processing_class = self.processing_class
- generation_config = GenerationConfig(
- max_new_tokens=self.args.response_length,
- temperature=(0.01 + 1e-7),
- top_k=0.0,
- top_p=1.0,
- do_sample=True,
- )
-
- table = defaultdict(list)
- with unwrap_model_for_generation(
- self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model:
- for batch in self.eval_dataloader:
- query = batch["input_ids"]
- with torch.no_grad():
- context_length = query.shape[1]
- query_response, _ = batch_generation(
- unwrapped_model.policy,
- query,
- query.shape[0],
- processing_class.pad_token_id,
- generation_config,
- )
- response = query_response[:, context_length:]
- postprocessed_response = response
- if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
- postprocessed_response = truncate_response(
- self.stop_token_id, processing_class.pad_token_id, response
- )
- table["query"].extend(
- gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
- )
- table["model response"].extend(
- gather_object(processing_class.batch_decode(postprocessed_response))
- )
-
- postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
- _, score, _ = get_reward(
- self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
- )
- table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
-
- if sampling:
- break
- df = pd.DataFrame(table)
-
- if self.accelerator.is_main_process:
- if is_rich_available():
- print_rich_table(df.iloc[0 : 0 + 5])
- if "wandb" in args.report_to:
- import wandb
-
- if wandb.run is not None:
- wandb.log({"completions": wandb.Table(dataframe=df)})
-
- if "comet_ml" in args.report_to:
- log_table_to_comet_experiment(
- name="completions.csv",
- table=df,
- )
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothPPOTrainer(_UnslothPPOTrainer):
- """
- Trainer for Proximal Policy Optimization (PPO).
-
- For details on PPO, see the paper: [Proximal Policy Optimization
- Algorithms](https://huggingface.co/papers/1707.06347).
-
- Args:
- args ([`PPOConfig`]):
- Training arguments.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]):
- Class to process the data.
- model (`torch.nn.Module`):
- Model to be trained. This is the policy model.
- ref_model (`torch.nn.Module`, *optional*):
- Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created.
- reward_model (`torch.nn.Module`):
- Reward model used to compute the rewards.
- train_dataset ([`~datasets.Dataset`]):
- Dataset for training.
- value_model (`torch.nn.Module`):
- Value model used to predict the value of a state.
- data_collator ([`~transformers.DataCollatorWithPadding`], *optional*):
- Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created
- using the `processing_class`.
- eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
- Dataset for evaluation.
- optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
- Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the
- optimizer and the learning rate scheduler are created using the
- [`~transformers.Trainer.create_optimizer_and_scheduler`] method.
- callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
- Callbacks to use during training.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model`
- will be wrapped with the specified PEFT adapter.
-
- """
- def __init__(
- self,
- args,
- processing_class,
- model,
- ref_model,
- reward_model,
- train_dataset,
- value_model,
- data_collator = None,
- eval_dataset = None,
- callbacks = None,
- peft_config = None,
- **kwargs
- ):
- if args is None: args = UnslothPPOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('ppo_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- args = args,
- processing_class = processing_class,
- model = model,
- ref_model = ref_model,
- reward_model = reward_model,
- train_dataset = train_dataset,
- value_model = value_model,
- data_collator = data_collator,
- eval_dataset = eval_dataset,
- callbacks = callbacks,
- peft_config = peft_config,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
diff --git a/unsloth_compiled_cache/UnslothPRMTrainer.py b/unsloth_compiled_cache/UnslothPRMTrainer.py
deleted file mode 100644
index 8232c86..0000000
--- a/unsloth_compiled_cache/UnslothPRMTrainer.py
+++ /dev/null
@@ -1,1133 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.prm_trainer import (BaseImageProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, Path, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, nn, os, textwrap, torch, warnings, BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PartialState, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, compute_accuracy, disable_dropout_in_model, features, nn, os, torch, warnings, Optional, PreTrainedModel, os, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothPRMConfig(PRMConfig):
- """
-
- Configuration class for the [`PRMTrainer`].
-
- This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the sequences (prompt + completion) used for truncation.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt used for truncation.
- max_completion_length (`int`, *optional*):
- Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model.
- step_separator (`str`, *optional*, defaults to `"\n"`):
- Separator used to separate each step of the reasoning process.
- train_on_last_step_only (`bool`, *optional*, defaults to `False`):
- Whether to train only on the last step.
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- max_length = 1024,
- max_prompt_length = 512,
- max_completion_length = None,
- disable_dropout = True,
- step_separator = '\
-',
- train_on_last_step_only = False,
- dataset_num_proc = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- max_length = max_length,
- max_prompt_length = max_prompt_length,
- max_completion_length = max_completion_length,
- disable_dropout = disable_dropout,
- step_separator = step_separator,
- train_on_last_step_only = train_on_last_step_only,
- dataset_num_proc = dataset_num_proc,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothPRMTrainer(BaseTrainer):
- """"""
-
- _tag_names = ["trl", "prm"]
- _name = "PRM"
- _paper = {
- "title": "Solving math word problems with process-and outcome-based feedback",
- "id": "2211.14275",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{uesato2022solving,
- title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
- author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
- year = 2022,
- journal = {arXiv preprint arXiv:2211.14275}
- }"""),
- }
-
- def __init__(
- self,
- model: Optional[Union[PreTrainedModel, nn.Module]] = None,
- args: Optional[PRMConfig] = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Dataset] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
- None,
- None,
- ),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional[dict] = None,
- ):
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if False:
- pass
-
- # Disable dropout in the model
- if args.disable_dropout:
- disable_dropout_in_model(model)
-
- if compute_metrics is None:
- compute_metrics = compute_accuracy
-
- if data_collator is None:
- if processing_class is None:
- raise ValueError(
- "A processing_class must be specified when using the default DataCollatorForTokenClassification"
- )
- data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
-
- if "input_ids" not in train_dataset.column_names:
- with PartialState().main_process_first():
- fn_kwargs = {
- "tokenizer": processing_class,
- "step_separator": args.step_separator,
- "max_length": args.max_length,
- "max_prompt_length": args.max_prompt_length,
- "max_completion_length": args.max_completion_length,
- "train_on_last_step_only": args.train_on_last_step_only,
- }
- train_fn_kwargs = {**fn_kwargs, "is_eval": False}
- train_dataset = train_dataset.map(
- self.tokenize_row,
- fn_kwargs=train_fn_kwargs,
- num_proc=args.dataset_num_proc,
- remove_columns=train_dataset.features,
- desc="Tokenizing train dataset",
- features=features.Features( # needed to avoid map to cast labels to bool
- {
- "labels": features.Sequence(features.Value("int64")),
- "input_ids": features.Sequence(features.Value("int64")),
- }
- ),
- )
-
- eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
- if eval_dataset is not None:
- eval_dataset = eval_dataset.map(
- self.tokenize_row,
- fn_kwargs=eval_fn_kwargs,
- num_proc=args.dataset_num_proc,
- remove_columns=eval_dataset.features,
- desc="Tokenizing eval dataset",
- features=features.Features( # needed to avoid map to cast labels to bool
- {
- "labels": features.Sequence(features.Value("int64")),
- "input_ids": features.Sequence(features.Value("int64")),
- }
- ),
- )
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- model_init=model_init,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- @staticmethod
- def tokenize_row(
- features,
- tokenizer,
- step_separator,
- max_length,
- max_prompt_length,
- max_completion_length,
- train_on_last_step_only,
- is_eval,
- ):
- r"""
- Tokenize a row of the dataset.
-
- Args:
- features (`dict[str, str]`):
- Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
- tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
- Tokenizer used to process the data.
- step_separator (`str`):
- Separator between steps in the completion.
- max_length (`int` or `None`):
- Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
- max_prompt_length (`int` or `None`):
- Maximum length of the prompt. If `None`, the prompt is not truncated.
- max_completion_length (`int` or `None`):
- Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
- train_on_last_step_only (`bool`):
- Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
- token of the completion.
- is_eval (`bool`):
- Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if
- `train_on_last_step_only` is set to `True`.
-
- Returns:
- `dict[str, list[int]]`:
- Tokenized sequences with the keys `"input_ids"`, and `"labels".
-
- Example:
- ```python
- >>> from transformers import AutoTokenizer
-
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
- >>> features = {
- ... "prompt": "Which number is larger, 9.8 or 9.11?",
- ... "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
- ... "labels": [True, False],
- ... }
- >>> PRMTrainer.tokenize_row(
- ... features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False
- ... )
- {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
- 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
- ```
- """
- # Tokenize the prompt and completions
- prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
- completions_ids = [
- tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
- ]
- if train_on_last_step_only and not is_eval:
- labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
- else:
- labels = [int(label) for label in features["labels"]]
-
- # Get the ID of the separator token and add it to the completions
- separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
- completions_ids = [completion + separator_ids for completion in completions_ids]
-
- # Create the label
- labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
-
- # Join the completions and labels steps
- completion_ids = list(chain(*completions_ids))
- labels = list(chain(*labels))
-
- if tokenizer.bos_token_id is not None:
- prompt_ids = [tokenizer.bos_token_id] + prompt_ids
-
- # Truncate prompt and completion sequences
- if max_prompt_length is not None:
- prompt_ids = prompt_ids[-max_prompt_length:]
- if max_completion_length is not None:
- completion_ids = completion_ids[:max_completion_length]
- labels = labels[:max_completion_length]
-
- input_ids = prompt_ids + completion_ids
- labels = [-100] * len(prompt_ids) + labels
-
- if max_length is not None:
- input_ids = input_ids[:max_length]
- labels = labels[:max_length]
-
- return {"input_ids": input_ids, "labels": labels}
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothPRMTrainer(_UnslothPRMTrainer):
- """
-
- Initialize PRMTrainer.
-
- Args:
- model ([`~transformers.PreTrainedModel`]):
- The model to train, preferably an `AutoModelForTokenClassification`.
- args ([`PRMConfig`]):
- The arguments to use for training.
- data_collator ([`~transformers.DataCollator`]):
- The data collator to use for training. If None is specified, the default data collator
- ([`~transformers.DataCollatorForTokenClassification`]) will be used which will pad the sequences to the
- maximum length of the sequences in the batch, given a dataset of paired sequences.
- train_dataset ([`~datasets.Dataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`]):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- model_init (`Callable[[], transformers.PreTrainedModel]`):
- The model initializer to use for training. If None is specified, the default model initializer will be
- used.
- compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
- The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
- will be used.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
- peft_config (`dict`, defaults to `None`):
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
- a PEFT model.
-
- """
- def __init__(
- self,
- model = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- model_init = None,
- compute_metrics = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- **kwargs
- ):
- if args is None: args = UnslothPRMConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('prm_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- model_init = model_init,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
diff --git a/unsloth_compiled_cache/UnslothRLOOTrainer.py b/unsloth_compiled_cache/UnslothRLOOTrainer.py
deleted file mode 100644
index 740659a..0000000
--- a/unsloth_compiled_cache/UnslothRLOOTrainer.py
+++ /dev/null
@@ -1,2828 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.rloo_trainer import (Any, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, BaseTrainer, DataLoader, Dataset, FSDP, GenerationConfig, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RLOOConfig, RLOOTrainer, RepeatSampler, RewardFunc, Sampler, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, apply_chat_template, broadcast_object_list, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, entropy_from_logits, gather, gather_object, identity, inspect, is_conversational, is_datasets_available, is_flash_attn_2_available, is_peft_model, is_rich_available, is_vllm_available, logger, logging, maybe_apply_chat_template, nanmax, nanmin, nanstd, nn, nullcontext, os, pad, partial, prepare_deepspeed, prepare_fsdp, prepare_multimodal_messages, print_prompt_completions_sample, profiling_context, profiling_decorator, seed_worker, selective_log_softmax, set_seed, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, textwrap, torch, transformers, unsplit_pixel_values_by_grid, unwrap_model_for_generation, warnings, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, Dataset, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RLOOConfig, RLOOTrainer, RewardFunc, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, identity, inspect, is_peft_model, is_vllm_available, logger, nn, os, pad, prepare_deepspeed, prepare_fsdp, set_seed, torch, transformers, warnings, FSDP, Optional, apply_chat_template, broadcast_object_list, gather, gather_object, is_flash_attn_2_available, maybe_apply_chat_template, nullcontext, os, pad, prepare_multimodal_messages, profiling_context, torch, transformers, unwrap_model_for_generation, FSDP, gather, is_peft_model, nn, nullcontext, os, profiling_decorator, Any, Union, profiling_decorator, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, torch, unsplit_pixel_values_by_grid, Optional, PreTrainedModel, logger, os, torch, FSDP, nn, os, FSDP, nn, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-def vLLMSamplingParams(**kwargs):
- from vllm import SamplingParams
-
- sampling_params = SamplingParams(**kwargs)
- sampling_params._set_kwargs = kwargs
- return sampling_params
-@dataclass
-class UnslothRLOOConfig(RLOOConfig):
- """
-
- Configuration class for the [`RLOOTrainer`].
-
- This class includes only the parameters that are specific to RLOO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- > Parameters that control the model and reference model
-
- model_init_kwargs (`str`, `dict[str, Any]`, *optional*):
- Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
- argument of the [`RLOOTrainer`] is provided as a string.
- disable_dropout (`bool`, *optional*, defaults to `False`):
- Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
- the model from generating different logprobs for the same input.
-
- > Parameters that control the data preprocessing
-
- remove_unused_columns (`bool`, *optional*, defaults to `False`):
- Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
- requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
- Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
- num_generations (`int` or `None`, *optional*, defaults to `2`):
- Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size
- * gradient_accumulation_steps) must be evenly divisible by this value.
- max_completion_length (`int` or `None`, *optional*, defaults to `256`):
- Maximum length of the generated completion.
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
- capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
- with vLLM generation.
- shuffle_dataset (`bool`, *optional*, defaults to `True`):
- Whether to shuffle the training dataset.
-
- > Parameters that control generation
-
- generation_batch_size: (`int`, *optional*):
- Batch size to use for generation. If `None`, it defaults to the effective training batch size:
- `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one
- generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`.
- steps_per_generation: (`int`, *optional*):
- Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive
- with `generation_batch_size`.
- temperature (`float`, defaults to `1.0`):
- Temperature for sampling. The higher the temperature, the more random the completions.
- top_p (`float`, *optional*, defaults to `1.0`):
- Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
- `1.0` to consider all tokens.
- top_k (`int`, *optional*):
- Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
- disabled and all tokens are considered.
- min_p (`float`, *optional*):
- Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
- value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
- repetition_penalty (`float`, *optional*, defaults to `1.0`):
- Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
- Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
- tokens.
- use_transformers_paged (`bool`, *optional*, defaults to `False`):
- Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers`
- paged implementation will be used for generation instead of the default padded implementation. This
- parameter is only effective when `use_vllm` is set to `False`.
- cache_implementation (`str`, *optional*):
- Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
- generation_kwargs (`dict[str, Any]`, *optional*):
- Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
- `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
- generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
- with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
-
- > Parameters that control generation acceleration powered by vLLM
-
- use_vllm (`bool`, *optional*, defaults to `False`):
- Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation
- instead of the default model.generate(). Requires `vllm` to be installed.
- vllm_mode (`str`, *optional*, defaults to `"server"`):
- Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or
- `"colocate"`.
-
- - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM
- server is running (start with `trl vllm-serve`).
- - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
- separate server but may cause resource contention with training.
- vllm_model_impl (`str`, *optional*, defaults to `"vllm"`):
- Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use
- the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model
- implementation.
- vllm_guided_decoding_regex (`str`, *optional*):
- Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
-
- > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
-
- vllm_server_base_url (`str`, *optional*):
- Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
- `vllm_server_port` are ignored.
- vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
- Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
- vllm_server_port (`int`, *optional*, defaults to `8000`):
- Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
- vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
- Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
- timeout, a `ConnectionError` is raised.
-
- > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
-
- vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`):
- Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
- `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
- launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
- vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
- Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
- `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
- launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
- vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
- Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step and woken
- for weight sync and generation.
-
- > Parameters that control the training
-
- beta (`float`, *optional*, defaults to `0.05`):
- KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
- speed.
- num_iterations (`int`, *optional*, defaults to `1`):
- Number of iterations per batch (denoted as μ in the algorithm).
- epsilon (`float`, *optional*, defaults to `0.2`):
- Epsilon value for clipping.
- epsilon_high (`float`, *optional*):
- Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
- specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
- reward_weights (`list[float]`, *optional*):
- Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
- weighted equally with weight `1.0`.
- normalize_advantages (`bool`, *optional*, defaults to `False`):
- Whether to normalize advantages. Normalization is done per generation batch to have mean `0.0` and standard
- deviation of `1.0`.
- reward_clip_range (`tuple[float, float]`, *optional*):
- Clip range for rewards as (min, max). If `None`, no clipping is applied.
- mask_truncated_completions (`bool`, *optional*, defaults to `False`):
- When enabled, truncated completions are excluded from the loss calculation, preventing them from being
- incorrectly penalized and introducing noise during training. According to the
- [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability.
- sync_ref_model (`bool`, *optional*, defaults to `False`):
- Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
- the `ref_model_mixup_alpha` parameter. This synchronization originates from the
- [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
- ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
- α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
- between the current policy and the previous reference policy during updates. The reference policy is
- updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
- must set `sync_ref_model=True`.
- ref_model_sync_steps (`int`, *optional*, defaults to `512`):
- τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
- frequently the current policy is synchronized with the reference policy. To use this parameter, you must
- set `sync_ref_model=True`.
-
- > Parameters that control the logging
-
- log_completions (`bool`, *optional*, defaults to `False`):
- Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
- it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
- num_completions_to_print (`int`, *optional*):
- Number of completions to print with `rich`. If `None`, all completions are logged.
- wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):
- Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all prompts
- are logged.
-
- > Deprecated parameters
-
- rloo_k:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `num_generations` instead.
-
-
-
- cliprange:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `epsilon` instead.
-
-
-
- kl_coef:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `beta` instead.
-
-
-
- exp_name:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `run_name` instead.
-
-
-
- normalize_reward:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `normalize_advantages` instead.
-
-
-
- num_ppo_epochs:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `num_iterations` instead.
-
-
-
- num_mini_batches:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `steps_per_generation` instead.
-
-
-
- total_episodes:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `max_steps` instead.
-
-
-
- response_length:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `max_completion_length` instead.
-
-
-
- token_level_kl:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. KL is now computed only at the sequence
- level.
-
-
-
- dataset_num_proc:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. This parameter was unused, you can
- safely remove it from your scripts.
-
-
-
- local_rollout_forward_batch_size:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Now it is automatically set to
- `per_device_train_batch_size` (or `per_device_eval_batch_size` during evaluation).
-
-
-
- num_sample_generations:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `logging_steps` to control
- generation logging frequency.
-
-
-
- stop_token:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0.
-
-
-
- stop_token_id:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `processing_class.eos_token_id`
- instead.
-
-
-
- missing_eos_penalty:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Replicate with a custom reward function
- checking if `eos_token_id` is in `completion_ids`.
-
-
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
-
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = False,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- model_init_kwargs = None,
- disable_dropout = False,
- max_prompt_length = 512,
- num_generations = 8,
- max_completion_length = 256,
- ds3_gather_for_generation = True,
- shuffle_dataset = True,
- generation_batch_size = None,
- steps_per_generation = None,
- temperature = 1.0,
- top_p = 1.0,
- top_k = None,
- min_p = None,
- generation_kwargs = {},
- repetition_penalty = 1.0,
- use_transformers_paged = False,
- cache_implementation = None,
- use_vllm = False,
- vllm_mode = 'colocate',
- vllm_model_impl = 'vllm',
- vllm_enable_sleep_mode = False,
- vllm_guided_decoding_regex = None,
- vllm_server_base_url = None,
- vllm_server_host = '0.0.0.0',
- vllm_server_port = 8000,
- vllm_server_timeout = 240.0,
- vllm_gpu_memory_utilization = 0.3,
- vllm_tensor_parallel_size = 1,
- beta = 0.05,
- num_iterations = 1,
- epsilon = 0.2,
- epsilon_high = None,
- reward_weights = None,
- normalize_advantages = False,
- reward_clip_range = None,
- mask_truncated_completions = False,
- sync_ref_model = False,
- ref_model_mixup_alpha = 0.6,
- ref_model_sync_steps = 512,
- log_completions = False,
- num_completions_to_print = None,
- wandb_log_unique_prompts = False,
- rloo_k = None,
- cliprange = None,
- kl_coef = None,
- exp_name = None,
- normalize_reward = None,
- num_ppo_epochs = None,
- num_mini_batches = None,
- total_episodes = None,
- response_length = None,
- token_level_kl = None,
- dataset_num_proc = None,
- local_rollout_forward_batch_size = None,
- num_sample_generations = None,
- stop_token = None,
- stop_token_id = None,
- missing_eos_penalty = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
-
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if steps_per_generation is None and generation_batch_size is None:
- ga = gradient_accumulation_steps
- world_size = int(os.environ.get('WORLD_SIZE', '1'))
- if (ga * world_size * per_device_train_batch_size) % num_generations != 0:
- print('Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
- per_device_train_batch_size = num_generations
-
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- model_init_kwargs = model_init_kwargs,
- disable_dropout = disable_dropout,
- max_prompt_length = max_prompt_length,
- num_generations = num_generations,
- max_completion_length = max_completion_length,
- ds3_gather_for_generation = ds3_gather_for_generation,
- shuffle_dataset = shuffle_dataset,
- generation_batch_size = generation_batch_size,
- steps_per_generation = steps_per_generation,
- temperature = temperature,
- top_p = top_p,
- top_k = top_k,
- min_p = min_p,
- generation_kwargs = generation_kwargs,
- repetition_penalty = repetition_penalty,
- use_transformers_paged = use_transformers_paged,
- cache_implementation = cache_implementation,
- use_vllm = use_vllm,
- vllm_mode = vllm_mode,
- vllm_model_impl = vllm_model_impl,
- vllm_enable_sleep_mode = vllm_enable_sleep_mode,
- vllm_guided_decoding_regex = vllm_guided_decoding_regex,
- vllm_server_base_url = vllm_server_base_url,
- vllm_server_host = vllm_server_host,
- vllm_server_port = vllm_server_port,
- vllm_server_timeout = vllm_server_timeout,
- vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
- vllm_tensor_parallel_size = vllm_tensor_parallel_size,
- beta = beta,
- num_iterations = num_iterations,
- epsilon = epsilon,
- epsilon_high = epsilon_high,
- reward_weights = reward_weights,
- normalize_advantages = normalize_advantages,
- reward_clip_range = reward_clip_range,
- mask_truncated_completions = mask_truncated_completions,
- sync_ref_model = sync_ref_model,
- ref_model_mixup_alpha = ref_model_mixup_alpha,
- ref_model_sync_steps = ref_model_sync_steps,
- log_completions = log_completions,
- num_completions_to_print = num_completions_to_print,
- wandb_log_unique_prompts = wandb_log_unique_prompts,
- rloo_k = rloo_k,
- cliprange = cliprange,
- kl_coef = kl_coef,
- exp_name = exp_name,
- normalize_reward = normalize_reward,
- num_ppo_epochs = num_ppo_epochs,
- num_mini_batches = num_mini_batches,
- total_episodes = total_episodes,
- response_length = response_length,
- token_level_kl = token_level_kl,
- dataset_num_proc = dataset_num_proc,
- local_rollout_forward_batch_size = local_rollout_forward_batch_size,
- num_sample_generations = num_sample_generations,
- stop_token = stop_token,
- stop_token_id = stop_token_id,
- missing_eos_penalty = missing_eos_penalty,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
-
-
-pass
-
-class _UnslothRLOOTrainer(BaseTrainer):
- """"""
-
- _tag_names = ["trl", "rloo"]
- _name = "RLOO"
- _paper = {
- "title": "Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
- "id": "2402.14740",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @inproceedings{ahmadian2024back,
- title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
- author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
- year = 2024,
- booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
- pages = {12248--12267},
- publisher = {Association for Computational Linguistics},
- editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
- }"""),
- }
-
- def __init__(
- self,
- # Note for dev: we can remove the default None when we remove the deprecated model parameter in version 0.25.0
- model: Union[str, PreTrainedModel] = None,
- reward_funcs: Union[RewardFunc, list[RewardFunc]] = None,
- args: Optional[RLOOConfig] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
- processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
- reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
- peft_config: Optional["PeftConfig"] = None,
- # Deprecated parameters
- config=None,
- reward_model=None,
- policy=None,
- ref_policy=None,
- data_collator=None,
- ):
-
- if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):
- if (getattr(args, 'use_vllm', False) == False):
- args.use_vllm = True
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- # Handle deprecated parameters
- if config is not None:
- warnings.warn(
- "Parameter 'config' is deprecated and will be removed in version 0.25.0. Please use 'args' instead. "
- "We are setting args=config"
- )
- if args is None:
- args = config
- else:
- raise ValueError("Cannot specify both 'config' (deprecated) and 'args'. Please use 'args' only.")
-
- if reward_model is not None:
- warnings.warn(
- "Parameter 'reward_model' is deprecated and will be removed in version 0.25.0. Please use "
- "'reward_funcs' instead. We are setting reward_funcs=reward_model"
- )
- if reward_funcs is None:
- reward_funcs = reward_model
- else:
- raise ValueError(
- "Cannot specify both 'reward_model' (deprecated) and 'reward_funcs'. Please use 'reward_funcs' "
- "only."
- )
- if policy is not None:
- warnings.warn(
- "Parameter 'policy' is deprecated and will be removed in version 0.25.0. Please use 'model' instead. "
- "We are setting model=policy"
- )
- if model is None:
- model = policy
- else:
- raise ValueError("Cannot specify both 'policy' (deprecated) and 'model'. Please use 'model' only.")
- if ref_policy is not None:
- warnings.warn(
- "Parameter 'ref_policy' is deprecated and will be removed in version 0.25.0. To use the initial model "
- "as the reference model, simply omit this parameter. The parameter is ignored."
- )
- if data_collator is not None:
- warnings.warn(
- "Parameter 'data_collator' is deprecated and will be removed in version 0.25.0. The RLOOTrainer does "
- "not use a data collator, so this parameter is ignored."
- )
- if "input_ids" in train_dataset.column_names:
- warnings.warn(
- "The training dataset contains a column named 'input_ids', indicating that it is pre-tokenized. "
- "Support for pre-tokenized datasets is deprecated and will be removed in version 0.25. Please provide "
- "the raw dataset (conversational or standard) with a 'prompt' column instead."
- )
-
- def decode(example, tokenizer):
- return {"prompt": tokenizer.decode(example["input_ids"])}
-
- train_dataset = train_dataset.map(decode, fn_kwargs={"tokenizer": processing_class})
- if eval_dataset is not None and "input_ids" in eval_dataset.column_names:
- warnings.warn(
- "The evaluation dataset contains a column named 'input_ids', indicating that it is pre-tokenized. "
- "Support for pre-tokenized datasets is deprecated and will be removed in version 0.25. Please provide "
- "the raw dataset (conversational or standard) with a 'prompt' column instead."
- )
-
- def decode(example, tokenizer):
- return {"prompt": tokenizer.decode(example["input_ids"])}
-
- eval_dataset = eval_dataset.map(decode, fn_kwargs={"tokenizer": processing_class})
-
- # Args
- if args is None:
- model_name = model if isinstance(model, str) else model.config._name_or_path
- model_name = model_name.split("/")[-1]
- args = RLOOConfig(f"{model_name}-RLOO")
-
- # Models
- # Trained model
- model_init_kwargs = args.model_init_kwargs or {}
- if isinstance(model, str):
- model_id = model
- dtype = model_init_kwargs.get("dtype")
- if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
- pass # dtype is already a torch.dtype or "auto" or None
- elif isinstance(dtype, str): # it's a str, but not "auto"
- dtype = getattr(torch, dtype)
- model_init_kwargs["dtype"] = dtype
- else:
- raise ValueError(
- "Invalid `dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing "
- f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
- )
- # Disable caching if gradient checkpointing is enabled [not supported]
- config = AutoConfig.from_pretrained(model_id)
- architecture = getattr(transformers, config.architectures[0])
- model = architecture.from_pretrained(model_id, **model_init_kwargs)
- else:
- model_id = model.config._name_or_path
- if args.model_init_kwargs is not None:
- logger.warning(
- "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. "
- "The `model_init_kwargs` will be ignored."
- )
-
- # Some models [SmolVLM/Idefics3] don't support `logits_to_keep` argument and error out if we pass it
- # Inspect the forward method before we wrap the model with PEFT
- self.model_kwarg_keys = (
- inspect.signature(model.forward).parameters.keys()
- if not hasattr(model, "get_base_model")
- else inspect.signature(model.get_base_model().forward).parameters.keys()
- )
-
- if False:
- pass
-
- # Processing class
- if processing_class is None:
- processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
-
- # Handle pad token for processors or tokenizers
- if isinstance(processing_class, ProcessorMixin):
- tokenizer = processing_class.tokenizer
- elif isinstance(processing_class, PreTrainedTokenizerBase):
- tokenizer = processing_class
- else:
- raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
-
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
-
- self.pad_token = tokenizer.pad_token
- self.pad_token_id = tokenizer.pad_token_id
- self.eos_token_id = tokenizer.eos_token_id
-
- # Reward functions
- if not isinstance(reward_funcs, list):
- reward_funcs = [reward_funcs]
- self.reward_func_names = []
- for i, reward_func in enumerate(reward_funcs):
- if isinstance(reward_func, str):
- reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
- reward_func, num_labels=1, **model_init_kwargs
- )
- if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models
- self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
- else:
- self.reward_func_names.append(reward_funcs[i].__name__)
- self.reward_funcs = reward_funcs
-
- # Reward weights
- if args.reward_weights is not None:
- if len(args.reward_weights) != len(reward_funcs):
- raise ValueError(
- f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
- f"functions ({len(reward_funcs)})"
- )
- self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
- else:
- self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
-
- # Reward processing class
- if reward_processing_classes is None:
- reward_processing_classes = [None] * len(reward_funcs)
- elif not isinstance(reward_processing_classes, list):
- reward_processing_classes = [reward_processing_classes]
- if len(reward_processing_classes) != len(reward_funcs):
- raise ValueError(
- f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of "
- f"reward functions ({len(reward_funcs)})."
- )
-
- for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
- if isinstance(reward_func, PreTrainedModel):
- if reward_processing_class is None:
- reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
- if reward_processing_class.pad_token_id is None:
- reward_processing_class.pad_token = reward_processing_class.eos_token
- # The reward model computes the reward for the latest non-padded token in the input sequence.
- # So it's important to set the pad token ID to the padding token ID of the processing class.
- reward_func.config.pad_token_id = reward_processing_class.pad_token_id
- reward_processing_classes[i] = reward_processing_class
-
- self.reward_processing_classes = reward_processing_classes
-
- # Training arguments
- self.max_prompt_length = args.max_prompt_length
- self.max_completion_length = args.max_completion_length
- self.num_generations = args.num_generations
- self.temperature = args.temperature
- self.top_p = args.top_p
- self.top_k = args.top_k
- self.min_p = args.min_p
- self.repetition_penalty = args.repetition_penalty
- self.use_transformers_paged = args.use_transformers_paged
- self.use_vllm = args.use_vllm
- self.vllm_mode = args.vllm_mode
- self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
- self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
- self.normalize_advantages = args.normalize_advantages
- self.mask_truncated_completions = args.mask_truncated_completions
- self.reward_clip_range = args.reward_clip_range
-
- # Datasets
- self.shuffle_dataset = args.shuffle_dataset
-
- if (
- isinstance(train_dataset, IterableDataset)
- or isinstance(eval_dataset, IterableDataset)
- or (
- isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
- )
- ):
- # See https://github.com/huggingface/trl/issues/3213
- raise NotImplementedError(
- "Iterable datasets are not yet supported in RLOOTrainer. Please use a standard dataset instead."
- )
-
- # Multi-step
- self.num_iterations = args.num_iterations
- self.epsilon_low = args.epsilon
- self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
- # Tracks the number of iterations [forward + backward passes], including those within a grad accum cycle
- self._step = 0
- # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
- # `_get_train_sampler` and `_prepare_inputs`.
- self._buffered_inputs = None
-
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
- # input tensor associated with the key "input_ids". However, in RLOO, the sampled data does not include the
- # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
- # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
- # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
- # This acts as a flag to indicate that the warning has already been issued.
- model.warnings_issued["estimate_tokens"] = True
-
- super().__init__(
- model=model,
- args=args,
- data_collator=identity, # No data collation is needed in RLOO
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- callbacks=callbacks,
- optimizers=optimizers,
- )
-
- # Reference model
- self.beta = args.beta
- if self.beta == 0.0:
- # If beta is 0.0, the reference model is not needed
- self.ref_model = None
- elif is_peft_model(model):
- # If PEFT is used, the reference model is not needed since the adapter can be disabled
- # to revert to the initial model.
- self.ref_model = None
- else:
- # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
- config = AutoConfig.from_pretrained(model_id)
- architecture = getattr(transformers, config.architectures[0])
- self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
-
- # Disable dropout in the models
- if args.disable_dropout:
- disable_dropout_in_model(model)
- if self.ref_model is not None:
- disable_dropout_in_model(self.ref_model)
-
- # Initialize the metrics
- self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
- self._total_train_tokens = 0
- self.log_completions = args.log_completions
- self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
- self.num_completions_to_print = args.num_completions_to_print
- # Keep logs sized to the generation batch to record only outputs from the latest model update.
- self._logs = {
- "images": deque(maxlen=args.generation_batch_size),
- "prompt": deque(maxlen=args.generation_batch_size),
- "completion": deque(maxlen=args.generation_batch_size),
- "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
- "advantages": deque(maxlen=args.generation_batch_size),
- }
-
- # Ensure each process receives a unique seed to prevent duplicate completions when generating with
- # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
- # it's safer to set it in all cases.
- set_seed(args.seed, device_specific=True)
-
- if self.use_vllm:
- if not is_vllm_available():
- raise ImportError(
- "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
- "`pip install trl[vllm]` to use it."
- )
-
- if self.vllm_mode == "server":
- if self.accelerator.is_main_process:
- if args.vllm_server_base_url is not None:
- base_url = args.vllm_server_base_url
- else:
- base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}"
- self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
- self.vllm_client.init_communicator(device=torch.cuda.current_device())
-
- elif self.vllm_mode == "colocate":
- if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0:
- raise ValueError(
- f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size "
- f"({self.accelerator.num_processes}) evenly."
- )
-
- if self.vllm_tensor_parallel_size > 1:
- self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
- [
- list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size))
- for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size)
- ]
- )
- os.environ["RANK"] = str(self.accelerator.process_index)
- os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index)
- os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes)
- ensure_master_addr_port()
-
- if self.max_prompt_length is not None and self.max_completion_length is not None:
- max_model_len = self.max_prompt_length + self.max_completion_length
- else:
- max_model_len = None
- self.llm = model.vllm_engine
- if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
- else:
- raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")
- self.guided_decoding_regex = args.vllm_guided_decoding_regex
-
- self._last_loaded_step = -1
- self.accelerator.wait_for_everyone()
- else:
- generation_kwargs = {
- "max_new_tokens": self.max_completion_length,
- "do_sample": True,
- "pad_token_id": tokenizer.pad_token_id,
- "bos_token_id": tokenizer.bos_token_id,
- "eos_token_id": tokenizer.eos_token_id,
- "temperature": self.temperature,
- "top_p": self.top_p,
- "top_k": self.top_k,
- "min_p": self.min_p,
- "repetition_penalty": self.repetition_penalty,
- "cache_implementation": args.cache_implementation,
- }
- if args.generation_kwargs is not None:
- generation_kwargs.update(args.generation_kwargs)
- self.generation_config = GenerationConfig(**generation_kwargs)
-
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
- # self.model_accepts_loss_kwargs to False to enable scaling.
- self.model_accepts_loss_kwargs = False
-
- # Add tags to the model
- self.model.add_model_tags(self._tag_names)
-
- if self.ref_model is not None:
- if self.is_deepspeed_enabled:
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
- elif self.is_fsdp_enabled:
- self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
- else:
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-
- if args.sync_ref_model:
- self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
-
- for i, reward_func in enumerate(self.reward_funcs):
- if isinstance(reward_func, PreTrainedModel):
- if self.is_deepspeed_enabled:
- self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
- else:
- # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
- self.reward_funcs[i] = self.accelerator.prepare_model(
- reward_func, evaluation_mode=True, device_placement=True
- )
-
- def _set_signature_columns_if_needed(self):
- # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
- # By default, this method sets `self._signature_columns` to the model's expected inputs.
- # In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't work.
- # Instead, we set them to the columns expected by the `training_step` method, hence the override.
- if self._signature_columns is None:
- self._signature_columns = ["prompt", "image", "images"]
-
- # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
- # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an
- # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions
- # once every steps_per_generation step—rather than once per accumulation step—which is significantly more
- # efficient. The only change from the original implementation is multiplying the batch size by
- # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
- # splitting internally.
- # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
- # modification. As a result, some parts of the method aren't relevant to RLOO, but we keep them to stay one line
- # apart from the super method, ensuring easier maintenance in the future.
- def get_train_dataloader(self):
- if self.train_dataset is None:
- raise ValueError("Trainer: training requires a train_dataset.")
-
- train_dataset = self.train_dataset
- data_collator = self.data_collator
- if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
- train_dataset = self._remove_unused_columns(train_dataset, description="training")
- else:
- data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
-
- dataloader_params = {
- "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
-
- if not isinstance(train_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_train_sampler()
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["worker_init_fn"] = partial(
- seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
- )
-
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
- return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
-
- def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
- # Returns a sampler that
- # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
- # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
- # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
- # in group formation.
- # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
- # _prepare_inputs to see how the generations are stored and reused.
-
- # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
- # second row shows the second sampled batch, and so on.
- #
- # | GPU 0 | GPU 1 |
- #
- # global_step step <-───> num_generations=2
- # <-───────> per_device_train_batch_size=3
- # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
- # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss
- # |
- # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss
- # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss
- #
- # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
- # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss
- # ...
- if dataset is None:
- dataset = self.train_dataset
- return RepeatSampler(
- data_source=dataset,
- mini_repeat_count=self.num_generations,
- batch_size=self.args.generation_batch_size // self.num_generations,
- repeat_count=self.num_iterations * self.args.steps_per_generation,
- shuffle=self.shuffle_dataset,
- seed=self.args.seed,
- )
-
- def _get_eval_sampler(self, eval_dataset) -> Sampler:
- # See _get_train_sampler for an explanation of the sampler.
- return RepeatSampler(
- data_source=eval_dataset,
- mini_repeat_count=self.num_generations,
- seed=self.args.seed,
- )
-
- @profiling_decorator
- def _get_per_token_logps_and_entropies(
- self,
- model,
- input_ids,
- attention_mask,
- logits_to_keep,
- batch_size=None,
- compute_entropy=False,
- pixel_values=None,
- image_grid_thw=None,
- num_images=None,
- pixel_attention_mask=None,
- image_sizes=None,
- token_type_ids=None,
- ) -> dict[str, Optional[torch.Tensor]]:
- """Compute log-probs and (optionally) entropies for each token."""
- batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
- all_logps = []
- all_entropies = []
- for start in range(0, input_ids.size(0), batch_size):
- input_ids_batch = input_ids[start : start + batch_size]
- attention_mask_batch = attention_mask[start : start + batch_size]
-
- # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
- model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
-
- if image_grid_thw is not None and pixel_values is not None:
- rows_per_image = image_grid_thw.prod(dim=-1)
- rows_per_sample = torch.split(rows_per_image, num_images)
- rows_per_sample = torch.stack([s.sum() for s in rows_per_sample])
- cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)])
- row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item()
- model_inputs["pixel_values"] = pixel_values[row_start:row_end]
- cum_imgs = torch.tensor([0] + num_images).cumsum(0)
- img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
- model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end]
- elif pixel_values is not None:
- model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
- if pixel_attention_mask is not None:
- model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
- if image_sizes is not None:
- model_inputs["image_sizes"] = image_sizes[start : start + batch_size]
- if token_type_ids is not None:
- model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
-
- # Only add logits_to_keep if the model supports it
- if "logits_to_keep" in self.model_kwarg_keys:
- # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
- model_inputs["logits_to_keep"] = logits_to_keep + 1
-
- model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings
-
- logits = model(**model_inputs).logits
- # Exclude the last value: it corresponds to the next token pred
- logits = logits[:, :-1, :] # (B, L-1, H)
- # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op.
- logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
- # Divide logits by sampling temperature.
- # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
- logits = logits / self.temperature
-
- completion_ids = input_ids_batch[:, -logits_to_keep:]
- logps = selective_log_softmax(logits, completion_ids) # compute logprobs
- all_logps.append(logps)
-
- if compute_entropy:
- with torch.no_grad():
- entropies = entropy_from_logits(logits)
- all_entropies.append(entropies)
-
- logps = torch.cat(all_logps, dim=0)
- entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
- return logps, entropies
-
- def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
- extra_prefixes = extra_prefixes or []
- prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
- for prefix in prefixes:
- name = name.replace(prefix, "")
- return name
-
- def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
- """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
- # For FSDP1, we need to recurse into children and also use summon_full_params
- if visited is None:
- visited = set()
- for child_name, child_module in module.named_children():
- child_prefix = f"{prefix}.{child_name}" if prefix else child_name
- self._sync_fsdp1_params_to_vllm(
- child_module, prefix=child_prefix, visited=visited
- ) # recurse into the child
-
- if isinstance(module, FSDP):
- with FSDP.summon_full_params(module, recurse=False, writeback=False):
- for param_name, param in module.named_parameters():
- full_name = f"{prefix}.{param_name}" if prefix else param_name
- full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])
-
- if full_name in visited:
- continue # skip FSDP subtrees already traversed
- visited.add(full_name)
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(full_name, param.data)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
- # For FSDP2, module already covers all parameters, so no need for recursion
- for name, param in module.items():
- if param.is_cpu:
- param = param.to(torch.device("cuda"))
- param = param.full_tensor()
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(name, param)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- @profiling_decorator
- def _move_model_to_vllm(self):
- # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations
- deepspeed_plugin = self.accelerator.state.deepspeed_plugin
- zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
- if zero_stage_3:
- import deepspeed
-
- gather_if_zero3 = deepspeed.zero.GatheredParameters
- else:
- gather_if_zero3 = nullcontext
-
- if is_peft_model(self.model):
- # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
- # merging adapters in a sharded manner is not supported.
- # TODO: does this work with FSDP?
- with gather_if_zero3(list(self.model.parameters())):
- self.model.merge_adapter()
-
- # Update vLLM weights while parameters are gathered
- if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
- # Update vLLM weights while parameters are gathered
- # For PEFT with FSDP we need to use the memory efficient post-order traversal
- fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
- fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
- if fsdp_version == 1:
- self._sync_fsdp1_params_to_vllm(
- self.model
- ) # use memory-efficient post-order traversal for FSDP
- elif fsdp_version == 2:
- self._sync_fsdp2_params_to_vllm(self.model)
- else:
- # DeepSpeed ZeRO-3 with PEFT
- for name, param in self.model.named_parameters():
- # When using PEFT, we need to recover the original parameter name and discard some parameters
- name = name.removeprefix("base_model.model.").replace(".base_layer", "")
- if self.model.prefix in name:
- continue
- # When module to save, remove its prefix and discard the original module
- if "original_module" in name:
- continue
- name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."])
-
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(name, param.data)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
- # Unmerge adapters while parameters are still gathered
- self.model.unmerge_adapter()
- # Parameters will automatically be repartitioned when exiting the context
- else:
- # For non-PEFT models, simply gather (if needed) and update each parameter individually.
- if self.is_fsdp_enabled:
- fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None)
- fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1
- if fsdp_version == 1:
- self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP
- elif fsdp_version == 2:
- self._sync_fsdp2_params_to_vllm(self.model)
- else:
- for name, param in self.model.named_parameters():
- name = self._fix_param_name_to_vllm(name)
- with gather_if_zero3([param]):
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.update_named_param(name, param.data)
- elif self.vllm_mode == "colocate":
-
- pass
-
- pass
-
- # Reset cache on vLLM
- if self.vllm_mode == "server" and self.accelerator.is_main_process:
- self.vllm_client.reset_prefix_cache()
- elif self.vllm_mode == "colocate":
- self.llm.reset_prefix_cache()
-
- @profiling_decorator
- def _prepare_inputs(
- self, generation_batch: dict[str, Union[torch.Tensor, Any]]
- ) -> dict[str, Union[torch.Tensor, Any]]:
- # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
- # During training:
- # - Receives the local generation batch (Per-GPU batch size × steps per generation)
- # from the modified training dataloader instead of the standard local batch
- # - Generates completions once for the entire generation batch and splits it into batches of size
- # `per_device_train_batch_size`
- # - Buffers these completions and returns the appropriate slice for the current accumulation step
- # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations)
- # During evaluation:
- # - The input is treated as a standard local batch (no accumulation, no multiple iterations)
- # - Completions are generated for each batch without buffering or reuse
- # Returns a single local batch in both cases.
-
- mode = "train" if self.model.training else "eval"
- if mode == "train":
- generate_every = self.args.steps_per_generation * self.num_iterations
- if self._step % generate_every == 0 or self._buffered_inputs is None:
- # self._buffered_inputs=None can occur when resuming from a checkpoint
- generation_batch = self._generate_and_score_completions(generation_batch)
- generation_batch = split_pixel_values_by_grid(generation_batch)
-
- try: generation_batch = shuffle_sequence_dict(generation_batch)
-
- except: pass
- generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
- self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
- inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
- self._step += 1
- else:
- # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
- # local generation batch == local eval batch
- inputs = self._generate_and_score_completions(generation_batch)
- return inputs
-
- @profiling_decorator
- def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
- device = self.accelerator.device
- rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
-
- # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations
- keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
- reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
-
- # This allows for dynamic reward shaping based on training progress.
- reward_kwargs["trainer_state"] = self.state
-
- for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
- zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
- ):
- with profiling_context(self, reward_func_name):
- if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models
- if is_conversational(inputs[0]):
- messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
- texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
- else:
- texts = [p + c for p, c in zip(prompts, completions)]
- reward_inputs = reward_processing_class(
- text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
- )
- reward_inputs = super()._prepare_inputs(reward_inputs)
- with torch.inference_mode():
- rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
- else:
- output_reward_func = reward_func(
- prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
- )
- # Convert None values to NaN
- output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
-
- rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
-
- # If all reward functions return None for a given row, issue a detailed warning
- if torch.isnan(rewards_per_func).all(dim=1).any():
- nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
- row_reward_kwargs = {
- key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state"
- }
- row_reward_kwargs["prompt"] = prompts[nan_row_idx]
- row_reward_kwargs["completion"] = completions[nan_row_idx]
- logger.warning(
- f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n"
- "Please ensure that at least one reward function returns a valid reward."
- )
-
- # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
- # completions may be distributed across processes
- rewards_per_func = gather(rewards_per_func)
- return rewards_per_func
-
- def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
- device = self.accelerator.device
-
- # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
- # [{"role": "user", "content": "What color is the sky?"}] to
- # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
- kwargs = {}
- if images is not None:
- kwargs = {"images": images}
- for prompt, image_list in zip(prompts, images):
- if isinstance(prompt, list): # i.e., when using conversational data
- prepare_multimodal_messages(prompt, num_images=len(image_list))
-
- prompts_text = [
- maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
- ]
-
- if images is not None:
- prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
- prompt_inputs = super()._prepare_inputs(prompt_inputs)
- forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
- else:
- forward_kwargs = {}
-
- # Generate completions using either vLLM or regular generation
- if self.use_vllm:
- if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
- # wake up colocated vLLM instances if needed
- torch.cuda.empty_cache() # required to avoid OOM in some cases
- self.llm.wake_up()
-
- # First, update the vLLM weights if needed
- if self.state.global_step != self._last_loaded_step:
- self._move_model_to_vllm()
- self._last_loaded_step = self.state.global_step
-
- # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
- if self.vllm_mode == "server":
- all_prompts_text = gather_object(prompts_text)
- if images is not None:
- all_images = gather_object(images)
-
- if self.accelerator.is_main_process:
- # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
- # num_generations outputs for each one. This is faster than generating outputs for each duplicate
- # prompt individually.
- ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
-
- if images is not None:
- ordered_set_of_images = all_images[:: self.num_generations]
- else:
- ordered_set_of_images = None
-
- with profiling_context(self, "vLLM.generate"):
- output = self.vllm_client.generate(
- prompts=ordered_set_of_prompts,
- images=ordered_set_of_images,
- n=self.num_generations,
- repetition_penalty=self.repetition_penalty,
- temperature=self.temperature,
- top_p=self.top_p,
- top_k=-1 if self.top_k is None else self.top_k,
- min_p=0.0 if self.min_p is None else self.min_p,
- max_tokens=self.max_completion_length,
- truncate_prompt_tokens=self.max_prompt_length,
- guided_decoding_regex=self.guided_decoding_regex,
- generation_kwargs=self.args.generation_kwargs,
- )
- payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
- else:
- payload = None
-
- # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
- obj_list = [payload]
- broadcast_object_list(obj_list, from_process=0)
- all_prompt_ids, all_completion_ids, _ = obj_list[0]
-
- # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
- all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
-
- process_slice = slice(
- self.accelerator.process_index * len(prompts),
- (self.accelerator.process_index + 1) * len(prompts),
- )
- prompt_ids = all_prompt_ids[process_slice]
- completion_ids = all_completion_ids[process_slice]
-
- # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
- elif self.vllm_mode == "colocate":
- if self.guided_decoding_regex:
- guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
- else:
- guided_decoding = None
-
- generation_kwargs = {
- "n": 1, # vLLM on each GPU generates only 1 in colocate mode
- "repetition_penalty": self.repetition_penalty,
- "temperature": self.temperature,
- "top_p": self.top_p,
- "top_k": -1 if self.top_k is None else self.top_k,
- "min_p": 0.0 if self.min_p is None else self.min_p,
- "max_tokens": self.max_completion_length,
- "truncate_prompt_tokens": self.max_prompt_length,
- "guided_decoding": guided_decoding,
- }
- if self.args.generation_kwargs is not None:
- generation_kwargs.update(self.args.generation_kwargs)
- sampling_params = SamplingParams(**grpo_update_SamplingParams(SamplingParams, generation_kwargs, getattr(self.args, 'vllm_sampling_params', None)))
-
- if self.vllm_tensor_parallel_size > 1:
- # Gather prompts from all ranks in the TP group and flatten.
- # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
- orig_size = len(prompts_text)
- gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
- torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
- all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
-
- if images is not None:
- gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
- torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
- all_images = [img for sublist in gathered_images for img in sublist]
- else:
- all_images = None
- else:
- all_prompts_text = prompts_text
- all_images = images
-
- if images is not None and all_images:
- vllm_inputs = []
- for prompt, image_list in zip(all_prompts_text, all_images):
- vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
-
- else:
- vllm_inputs = all_prompts_text
-
- with profiling_context(self, "vLLM.generate"):
- all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False, lora_request = self.model.load_lora('rloo_trainer_lora_model', load_tensors = True))
-
- all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
- all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
-
- if self.vllm_tensor_parallel_size > 1:
- # Slice completions for this rank within its TP group.
- # Each rank generates all outputs — we keep only our share.
- local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
- tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
- prompt_ids = all_prompt_ids[tp_slice]
- completion_ids = all_completion_ids[tp_slice]
- else:
- prompt_ids = all_prompt_ids
- completion_ids = all_completion_ids
-
- if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
-
- elif self.use_transformers_paged:
- # Re-process inputs for paged generation if needed
- # Note: images are already validated and preprocessed above
- paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
- previous_attn = self.model_wrapped.config._attn_implementation
-
- if is_flash_attn_2_available():
- self.model_wrapped.config._attn_implementation = "paged_attention"
- else:
- self.model_wrapped.config._attn_implementation = "sdpa_paged"
- with (
- profiling_context(self, "transformers.generate_batch"),
- unwrap_model_for_generation(
- self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model,
- torch.no_grad(),
- FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
- ):
- # Cast to the appropriate dtype based on training configuration
- if self.args.bf16:
- unwrapped_model.to(torch.bfloat16)
- elif self.args.fp16:
- unwrapped_model.to(torch.float16)
- with torch.inference_mode():
- all_outputs = unwrapped_model.generate_batch(
- paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
- )
- unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
- completion_ids = [output.generated_tokens for output in all_outputs.values()]
- prompt_ids = paged_prompt_inputs.input_ids
- # Restore the original attention implementation, training mode
- self.model_wrapped.config._attn_implementation = previous_attn
-
- else:
- # Regular generation path
- generate_inputs = self.processing_class(
- text=prompts_text,
- return_tensors="pt",
- padding=True,
- padding_side="left",
- max_length=self.max_prompt_length,
- truncation=True,
- add_special_tokens=False,
- **kwargs,
- )
- generate_inputs = super()._prepare_inputs(generate_inputs)
-
- with (
- profiling_context(self, "transformers.generate"),
- unwrap_model_for_generation(
- self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model,
- torch.no_grad(),
- FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
- ):
- prompt_completion_ids = unwrapped_model.generate(
- **generate_inputs, generation_config=self.generation_config, disable_compile=True
- )
- # Compute prompt length and extract completion ids
- prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
- prompt_length = prompt_ids.size(1)
- completion_ids = prompt_completion_ids[:, prompt_length:]
-
- # Mask everything after the first EOS token
- is_eos = completion_ids == self.eos_token_id
- eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
- eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
- sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
- completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
- prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
- completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
-
- return prompt_ids, completion_ids, forward_kwargs
-
- def _generate(self, prompts: list[str], images: Optional[list]):
- device = self.accelerator.device
- mode = "train" if self.model.training else "eval"
-
- prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images)
-
- # Get completion length per sequence, used for logging
- prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
- completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device)
- agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
- agg_completion_lengths = self.accelerator.gather(completion_lengths)
- total_prompt_tokens = agg_prompt_lengths.sum()
- total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss
-
- # Log the metrics
- if mode == "train":
- self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item()
- self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
-
- # Log completion lengths, mean, min, max
- agg_completion_lengths = self.accelerator.gather(completion_lengths)
- self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
- self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
- self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
-
- # Identify sequences that terminated with EOS and log their lengths
- eos_and_pad = [self.eos_token_id, self.pad_token_id]
- is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
- agg_is_truncated = self.accelerator.gather(is_truncated)
- self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
- term_completion_lengths = agg_completion_lengths[~agg_is_truncated]
- if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found
- term_completion_lengths = torch.zeros(1, device=device)
- self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
- self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
- self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
-
- return prompt_ids, completion_ids, forward_kwargs
-
- def _generate_and_score_completions(
- self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
- ) -> dict[str, Union[torch.Tensor, Any]]:
- device = self.accelerator.device
- mode = "train" if self.model.training else "eval"
-
- prompts = [x["prompt"] for x in inputs]
-
- if "images" in inputs[0]:
- images = [example.get("images") for example in inputs]
- elif "image" in inputs[0]:
- images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
- else:
- images = None
- # Transformers requires at least one image in the batch, otherwise it throws an error
- if images is not None and all(img_list == [] for img_list in images):
- images = None
-
- prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images)
-
- # Convert lists of token IDs to padded tensors
- prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
- prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
- prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
- prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
- completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
- completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
- completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
- completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
-
- # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
- if self.mask_truncated_completions:
- eos_and_pad = [self.eos_token_id, self.pad_token_id]
- is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
- completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
-
- # Concatenate prompt_mask with completion_mask for logit computation
- prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
- attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
- # If token_type_ids are used, extend them with zeros for the completion part
- if "token_type_ids" in forward_kwargs:
- token_type_ids = forward_kwargs["token_type_ids"]
- forward_kwargs["token_type_ids"] = torch.cat(
- [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
- )
-
- logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
- batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
- num_images = [len(img_list) for img_list in images] if images is not None else None
-
- with torch.no_grad():
- # Compute the per-token log probabilities for the current model
- old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
- self.model,
- prompt_completion_ids,
- attention_mask,
- logits_to_keep,
- batch_size,
- num_images=num_images,
- **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
- )
- old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS
-
- # Compute the per-token log probabilities for the reference model
- if self.beta != 0.0:
- if self.ref_model is not None:
- ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
- self.ref_model,
- prompt_completion_ids,
- attention_mask,
- logits_to_keep,
- batch_size=batch_size,
- num_images=num_images,
- **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
- )
- else:
- with self.accelerator.unwrap_model(self.model).disable_adapter():
- ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
- self.model,
- prompt_completion_ids,
- attention_mask,
- logits_to_keep,
- batch_size=batch_size,
- num_images=num_images,
- **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
- )
- else:
- ref_per_token_logps = None
-
- # Decode
- prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
- completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
- if is_conversational(inputs[0]):
- completions = []
- for prompt, completion in zip(prompts, completions_text):
- bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
- completions.append([{"role": "assistant", "content": bootstrap + completion}])
- else:
- completions = completions_text
-
- # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
- # important because rewards will be normalized per group, and completions are distributed. We will later slice
- # rewards_per_func to extract each process's subset.
- rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
-
- # Apply weights to each reward function's output and sum
- rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
-
- # Apply reward clipping if specified
- if self.reward_clip_range:
- rewards = rewards.clamp(min=self.reward_clip_range[0], max=self.reward_clip_range[1])
-
- # Include the KL penalty in the reward
- if self.beta != 0.0:
- per_token_kl = old_per_token_logps - ref_per_token_logps
- # Apply sequence-level KL penalty to rewards (sum KL across tokens first, then apply to each sequence)
- kl = (per_token_kl * completion_mask).sum(-1)
- kl = gather(kl) # rewards are gathered, so kl must be too
- rewards = rewards - self.beta * kl
-
- grouped_rewards = rewards.view(-1, self.num_generations)
- mean_grouped_rewards = grouped_rewards.mean(dim=1)
- std_rewards = grouped_rewards.std(dim=1)
- is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
-
- # RLOO advantages computation
- grouped_sum = grouped_rewards.sum(dim=1, keepdim=True) # (num_prompts, 1)
- baselines = (grouped_sum - grouped_rewards) / (self.num_generations - 1) # (num_prompts, num_generations)
- baselines = baselines.view(-1) # Flatten back to match rewards shape
- advantages = rewards - baselines
-
- # Normalize advantages
- if self.normalize_advantages:
- advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-4)
-
- # Slice to keep only the local part of the data
- process_slice = slice(
- self.accelerator.process_index * len(prompts),
- (self.accelerator.process_index + 1) * len(prompts),
- )
- all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
- advantages = advantages[process_slice]
-
- # Calculate and log the mean KL divergence between current and reference model
- if self.beta != 0.0:
- mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
- self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
-
- # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
- for i, reward_func_name in enumerate(self.reward_func_names):
- mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
- self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
- std_func_rewards = nanstd(rewards_per_func[:, i]).item()
- self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
- self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
- self._metrics[mode]["reward_std"].append(std_rewards.mean().item())
- self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
-
- # Log prompt and completion texts
- self._logs["prompt"].extend(gather_object(prompts_text))
- self._logs["completion"].extend(gather_object(completions_text))
- for i, name in enumerate(self.reward_func_names):
- self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
- self._logs["advantages"].extend(all_process_advantages.tolist())
-
- if images is not None:
- self._logs["images"].extend(gather_object(images))
-
- output = {
- "prompt_ids": prompt_ids,
- "prompt_mask": prompt_mask,
- "completion_ids": completion_ids,
- "completion_mask": completion_mask,
- "old_logps": old_logps,
- "advantages": advantages,
- }
- if "pixel_values" in forward_kwargs:
- output["pixel_values"] = forward_kwargs["pixel_values"]
- if "image_grid_thw" in forward_kwargs:
- output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
- if "pixel_attention_mask" in forward_kwargs:
- output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
- if "image_sizes" in forward_kwargs:
- output["image_sizes"] = forward_kwargs["image_sizes"]
- if "token_type_ids" in forward_kwargs:
- output["token_type_ids"] = forward_kwargs["token_type_ids"]
- if images is not None:
- output["num_images"] = num_images
- return output
-
- @profiling_decorator
- def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
- if return_outputs:
- raise ValueError("The RLOOTrainer does not support returning outputs")
- return self._compute_loss(model, inputs)
-
- def _compute_loss(self, model, inputs):
- # Compute the per-token log probabilities for the model
- prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
- completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
- input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
- attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
- logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
-
- # Compute the per_token_logps and the entropy at each position in the completion
- per_token_logps, entropies = self._get_per_token_logps_and_entropies(
- model,
- input_ids,
- attention_mask,
- logits_to_keep,
- compute_entropy=True,
- pixel_values=inputs.get("pixel_values"),
- image_grid_thw=inputs.get("image_grid_thw"),
- num_images=inputs.get("num_images"),
- pixel_attention_mask=inputs.get("pixel_attention_mask"),
- image_sizes=inputs.get("image_sizes"),
- token_type_ids=inputs.get("token_type_ids"),
- )
-
- logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS
- old_logps = inputs["old_logps"]
- log_ratio = logps - old_logps
-
- # Compute the loss
- advantages = inputs["advantages"]
- coef_1 = torch.exp(log_ratio)
- coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
- per_sequence_loss1 = coef_1 * advantages
- per_sequence_loss2 = coef_2 * advantages
- per_sequence_loss = -torch.min(per_sequence_loss1, per_sequence_loss2)
- loss = per_sequence_loss.mean()
-
- # Log the metrics
- mode = "train" if self.model.training else "eval"
-
- # Entropy
- mean_entropy = (entropies * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
- self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
-
- # Compute the clipped probability ratios
- is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
- is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
- is_region_clipped = is_low_clipped | is_high_clipped
- gathered_low_clip = self.accelerator.gather(is_low_clipped.float().mean())
- self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
- self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
- gathered_high_clip = self.accelerator.gather(is_high_clipped.float().mean())
- self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
- self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
- gathered_clip_ratio = self.accelerator.gather(is_region_clipped.float().mean())
- self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
- return loss
-
- def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
- inputs = self._prepare_inputs(inputs)
- with torch.no_grad():
- with self.compute_loss_context_manager():
- loss = self.compute_loss(model, inputs)
- loss = loss.mean().detach()
- return loss, None, None
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- mode = "train" if self.model.training else "eval"
- metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
-
- # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
- # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
- if mode == "eval":
- metrics = {f"eval_{key}": val for key, val in metrics.items()}
-
- logs = {**logs, **metrics}
- super().log(logs, start_time)
- self._metrics[mode].clear()
-
- if self.accelerator.is_main_process and self.log_completions:
- if is_rich_available():
- print_prompt_completions_sample(
- self._logs["prompt"],
- self._logs["completion"],
- self._logs["rewards"],
- self._logs["advantages"],
- self.state.global_step,
- self.num_completions_to_print,
- )
-
- if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
- import pandas as pd
-
- table = {
- "step": [str(self.state.global_step)] * len(self._logs["prompt"]),
- "prompt": self._logs["prompt"],
- "completion": self._logs["completion"],
- **self._logs["rewards"],
- "advantage": self._logs["advantages"],
- }
-
- if self._logs["images"]:
- table["images"] = []
- for image_list in self._logs["images"]:
- # Convert images to wandb Image objects for proper visualization
- table["images"].append([wandb.Image(image) for image in image_list])
-
- df = pd.DataFrame(table)
- if self.wandb_log_unique_prompts:
- df = df.drop_duplicates(subset=["prompt"])
- wandb.log({"completions": wandb.Table(dataframe=df)})
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothRLOOTrainer(_UnslothRLOOTrainer):
- """
-
- Trainer for the Reinforce Leave One Out (RLOO) method. This algorithm was initially proposed in the paper [Back to
- Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in
- LLMs](https://huggingface.co/papers/2402.14740).
-
- Example:
-
- ```python
- from datasets import load_dataset
- from trl import RLOOTrainer
-
- dataset = load_dataset("trl-lib/tldr", split="train")
- def reward_func(completions, **kwargs):
- # Dummy reward function that rewards completions with more unique letters.
- return [float(len(set(completion))) for completion in completions]
- trainer = RLOOTrainer(
- model="Qwen/Qwen2-0.5B-Instruct",
- reward_funcs=reward_func,
- train_dataset=dataset,
- )
-
- trainer.train()
- ```
-
- Args:
- model (`Union[str, PreTrainedModel]`):
- Model to be trained. Can be either:
-
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
- `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
- Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
- functions with the prompts and completions and sum the rewards. Can be either:
-
- - A single reward function, such as:
- - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
- keyword arguments in `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
- - A custom reward function: The function is provided with the prompts and the generated completions,
- plus any additional columns in the dataset. It should return a list of rewards. Custom reward
- functions can also return `None` when the reward is not applicable to those samples. This is useful
- for multi-task training where different reward functions apply to different types of samples. When a
- reward function returns `None` for a sample, that reward function is excluded from the reward
- calculation for that sample. For more details, see [Using a custom reward
- function](#using-a-custom-reward-function).
-
- The trainer's state is also passed to the reward function. The trainer's state is an instance of
- [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the
- reward function's signature.
- - A list of reward functions, where each item can independently be any of the above types. Mixing different
- types within the list (e.g., a string model ID and a custom reward function) is allowed.
- args ([`RLOOConfig`], *optional*):
- Configuration for this trainer. If `None`, a default configuration is used.
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
- Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
- ignored. The format of the samples can be either:
-
- - [Standard](dataset_formats#standard): Each sample contains plain text.
- - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
- and content).
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
- Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. The padding side must be set to "left". If `None`, the
- processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
- padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
- `tokenizer.eos_token` will be used as the default.
- reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
- Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
-
- - A single processing class: Used when `reward_funcs` contains only one reward function.
- - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
- If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
- `None`, the tokenizer for the model is automatically loaded using
- [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
- functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes`
- are ignored.
- callbacks (list of [`~transformers.TrainerCallback`], *optional*):
- List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
- in [here](https://huggingface.co/docs/transformers/main_classes/callback).
-
- If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
- method.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
- model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
-
- config:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `args` instead.
-
-
-
- reward_model:
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
-
-
-
- policy:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `model` instead.
-
-
-
- ref_policy:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. To use the initial model as the
- reference model, simply omit this parameter. The parameter is ignored.
-
-
-
- data_collator:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. The RLOOTrainer does not use a data
- collator, so this parameter is ignored.
-
-
-
- """
- def __init__(
- self,
- model = None,
- reward_funcs = None,
- args = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- reward_processing_classes = None,
- callbacks = None,
- peft_config = None,
- config = None,
- reward_model = None,
- policy = None,
- ref_policy = None,
- data_collator = None,
- **kwargs
- ):
- if args is None: args = UnslothRLOOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('rloo_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- reward_funcs = reward_funcs,
- args = args,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- reward_processing_classes = reward_processing_classes,
- callbacks = callbacks,
- peft_config = peft_config,
- config = config,
- reward_model = reward_model,
- policy = policy,
- ref_policy = ref_policy,
- data_collator = data_collator,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothRewardTrainer.py b/unsloth_compiled_cache/UnslothRewardTrainer.py
deleted file mode 100644
index 5fd7b70..0000000
--- a/unsloth_compiled_cache/UnslothRewardTrainer.py
+++ /dev/null
@@ -1,1351 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.reward_trainer import (Any, AutoModelForSequenceClassification, AutoTokenizer, BaseTrainer, Callable, DataCollator, DataCollatorForPreference, Dataset, EvalPrediction, IterableDataset, Optional, PartialState, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, RewardTrainer, TrainerCallback, Union, clone_chat_template, contextlib, dataclass, defaultdict, disable_dropout_in_model, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pad, re, remove_none_values, suppress_from_pretrained_warning, torch, transformers, Any, AutoModelForSequenceClassification, AutoTokenizer, Callable, DataCollator, DataCollatorForPreference, Dataset, EvalPrediction, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, TrainerCallback, Union, clone_chat_template, contextlib, defaultdict, disable_dropout_in_model, get_act_offloading_ctx_manager, logger, os, pad, re, suppress_from_pretrained_warning, torch, transformers, Optional, PreTrainedModel, logger, os, re, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothRewardConfig(RewardConfig):
- """
-
- Configuration class for the [`RewardTrainer`].
-
- This class includes only the parameters that are specific to Reward training. For a full list of training
- arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
- class may differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- > Parameters that control the model
-
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
- argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want
- to include the load balancing/auxilliary loss as a part of the final loss, remember to set
- `output_router_logits=True` in this dictionary.
- chat_template_path (`str`, *optional*):
- If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
- or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
- ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
- embedding layer is resized accordingly.
- disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model.
-
- > Parameters that control the data preprocessing
-
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
- eos_token (`str`, *optional*):
- Token used to indicate the end of a turn or sequence. If `None`, it defaults to
- `processing_class.eos_token`.
- pad_token (`str`, *optional*):
- Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
- it falls back to `processing_class.eos_token`.
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence
- exceeds this value. If `None`, no filtering is applied.
- pad_to_multiple_of (`int`, *optional*):
- If set, the sequences will be padded to a multiple of this value.
-
- > Parameters that control the training
-
- center_rewards_coefficient (`float`, *optional*):
- Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
- https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
- activation_offloading (`bool`, *optional*, defaults to `False`):
- Whether to offload the activations to the CPU.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- model_init_kwargs = None,
- chat_template_path = None,
- disable_dropout = True,
- dataset_num_proc = None,
- eos_token = None,
- pad_token = None,
- max_length = 1024,
- pad_to_multiple_of = None,
- center_rewards_coefficient = None,
- activation_offloading = False,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
- from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
- if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
- from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
- pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- model_init_kwargs = model_init_kwargs,
- chat_template_path = chat_template_path,
- disable_dropout = disable_dropout,
- dataset_num_proc = dataset_num_proc,
- eos_token = eos_token,
- pad_token = pad_token,
- max_length = max_length,
- pad_to_multiple_of = pad_to_multiple_of,
- center_rewards_coefficient = center_rewards_coefficient,
- activation_offloading = activation_offloading,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothRewardTrainer(BaseTrainer):
- """"""
-
- _tag_names = ["trl", "reward-trainer"]
- _name = "Reward"
- _template_file = "rm_model_card.md"
-
- def __init__(
- self,
- model: Union[str, PreTrainedModel],
- args: Optional[RewardConfig] = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[PreTrainedTokenizerBase] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
- optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional["PeftConfig"] = None,
- ):
- # Args
- if args is None:
- model_name = model if isinstance(model, str) else model.config._name_or_path
- model_name = model_name.split("/")[-1]
- args = RewardConfig(f"{model_name}-Reward")
-
- # Model
- model_init_kwargs = args.model_init_kwargs or {}
- if isinstance(model, str):
- model_id = model
- dtype = model_init_kwargs.get("dtype")
- if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
- pass # dtype is already a torch.dtype or "auto" or None
- elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
- model_init_kwargs["dtype"] = getattr(torch, dtype)
- else:
- raise ValueError(
- "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
- f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
- )
- with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
- model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
- else:
- model_id = model.config._name_or_path
- if args.model_init_kwargs is not None:
- logger.warning(
- "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
- "The `model_init_kwargs` will be ignored."
- )
-
- # Processing class
- if processing_class is None:
- processing_class = AutoTokenizer.from_pretrained(model_id)
-
- # Handle pad token for processors or tokenizers
- if args.eos_token is not None:
- eos_token = args.eos_token
- eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
- if eos_token_id is None:
- raise ValueError(
- f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
- f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
- "in the vocabulary before using it as an EOS token."
- )
- processing_class.eos_token_id = eos_token_id
-
- if args.chat_template_path is not None:
- if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
- with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
- processing_class.chat_template = chat_template_file.read()
- added_tokens = []
- else:
- model, processing_class, added_tokens = clone_chat_template(
- model, processing_class, args.chat_template_path
- )
- else:
- added_tokens = []
-
- # PEFT configuration and model wrapping
- if False:
- if added_tokens:
- # Ensure that the added tokens are trainable
- if peft_config.trainable_token_indices is None:
- peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
- elif "embed_tokens" not in peft_config.trainable_token_indices:
- peft_config.trainable_token_indices["embed_tokens"] = added_tokens
- else:
- peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
-
- # Ensure that the lm_head is trainable
- if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
- logger.warning(
- "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
- "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
- "tokens, leading to degraded generation quality. To fix this, add "
- "`modules_to_save=['lm_head']` to your PEFT configuration."
- )
-
- if peft_config.modules_to_save is None:
- peft_config.modules_to_save = ["lm_head"]
- else:
- peft_config.modules_to_save.append("lm_head")
-
- if False:
- pass
-
- # Disable dropout in the model
- if args.disable_dropout:
- disable_dropout_in_model(model)
-
- # Pad token [needed for SequenceClassification models]
- # If not provided, use the one from the processing class or the eos token if the processing class does not have
- # a pad token.
- pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
- pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
- if pad_token_id is None:
- raise ValueError(
- f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
- f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
- "in the vocabulary before using it as a padding token."
- )
- model.config.pad_token_id = pad_token_id
- processing_class.pad_token_id = pad_token_id
-
- # Data collator
- if data_collator is None:
- data_collator = DataCollatorForPreference(
- pad_token_id=pad_token_id,
- pad_to_multiple_of=args.pad_to_multiple_of,
- )
-
- # Dataset
- train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
- if eval_dataset is not None:
- if isinstance(eval_dataset, dict):
- eval_dataset = {
- key: self._prepare_dataset(dataset, processing_class, args, key)
- for key, dataset in eval_dataset.items()
- }
- else:
- eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
-
- # Initialize the metrics
- self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
- self._total_train_tokens = 0
-
- # Initialize the Trainer. Parent class will handle:
- # - DeepSpeed configuration [through create_accelerator_and_postprocess]
- # - FSDP setup
- # - Distributed training setup
- # - Optimizer and scheduler creation
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # During evaluation, Trainer calls compute_loss[] only if can_return_loss is True and label_names is empty.
- self.can_return_loss = True
- self.label_names = []
-
- # Initialize activation offloading context
- if self.args.activation_offloading:
- self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
- else:
- self.maybe_activation_offload_context = contextlib.nullcontext()
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
-
- def _prepare_dataset(
- self,
- dataset: Union[Dataset, IterableDataset],
- processing_class: PreTrainedTokenizerBase,
- args: RewardConfig,
- dataset_name: str,
- ) -> Union[Dataset, IterableDataset]:
- # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
- # sampled data.
- if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform`
- dataset = dataset.with_transform(remove_none_values)
-
- # If the dataset is already preprocessed (tokenized), skip the processing steps.
- column_names = list(next(iter(dataset)).keys())
- is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names
-
- # Build the kwargs for the `map` function
- map_kwargs = {}
- if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
- map_kwargs["num_proc"] = args.dataset_num_proc
-
- with PartialState().main_process_first():
- if not is_processed:
- # Add EOS token to the end of the sequences if needed
- first_example = next(iter(dataset))
- if not is_conversational(first_example):
- if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
- map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
-
- def add_eos(example, eos_token):
- if not example["chosen"].endswith(eos_token):
- example["chosen"] = example["chosen"] + eos_token
- if "rejected" in example and not example["rejected"].endswith(eos_token):
- example["rejected"] = example["rejected"] + eos_token
- return example
-
- dataset = dataset.map(
- add_eos,
- fn_kwargs={"eos_token": processing_class.eos_token},
- **map_kwargs,
- )
-
- # Tokenize the dataset
- if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
- map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
-
- def tokenize_fn(example, processing_class):
- if "prompt" in example: # explicit prompt case
- example["chosen"] = example["prompt"] + example["chosen"]
- example["rejected"] = example["prompt"] + example["rejected"]
-
- if is_conversational(example):
- chosen_input_ids = processing_class.apply_chat_template(
- example["chosen"],
- tools=example.get("tools"),
- **example.get("chat_template_kwargs", {}),
- )
- rejected_input_ids = processing_class.apply_chat_template(
- example["rejected"],
- tools=example.get("tools"),
- **example.get("chat_template_kwargs", {}),
- )
- output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids}
- else:
- output = {
- "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"],
- "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"],
- }
- return output
-
- dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
-
- # Filter samples that are longer than `max_length`
- if args.max_length is not None:
- if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
- map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens"
- dataset = dataset.filter(
- lambda example: len(example["chosen_input_ids"]) <= args.max_length
- and len(example["rejected_input_ids"]) <= args.max_length,
- **map_kwargs,
- )
-
- return dataset
-
- def _set_signature_columns_if_needed(self):
- # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
- # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
- # and "attention_mask").
- if self._signature_columns is None:
- self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"]
-
- def compute_loss(
- self,
- model: nn.Module,
- inputs: dict[str, Union[torch.Tensor, Any]],
- return_outputs: bool = False,
- num_items_in_batch: Optional[torch.Tensor] = None,
- ):
- """
- Compute training loss and additionally compute token accuracies
- """
- mode = "train" if self.model.training else "eval"
-
- # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
- inputs["use_cache"] = False
- outputs = model(**inputs)
-
- # Split the rewards into chosen and rejected
- rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2)
-
- # Calculate loss, optionally modulate with margin
- if "margin" in inputs:
- loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
- else:
- loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
-
- if self.args.center_rewards_coefficient is not None:
- loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
-
- if mode == "train":
- num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
- self._total_train_tokens += num_tokens_in_batch
- self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
-
- # Compute min, mean, max, accuracy and margin
- with torch.no_grad():
- all_rewards = self.accelerator.gather(outputs.logits)
- self._metrics[mode]["min_reward"].append(all_rewards.min().item())
- self._metrics[mode]["mean_reward"].append(all_rewards.mean().item())
- self._metrics[mode]["max_reward"].append(all_rewards.max().item())
-
- mean_accuracy = (rewards_chosen > rewards_rejected).float().mean()
- mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item()
- self._metrics[mode]["accuracy"].append(mean_accuracy)
-
- mean_margin = (rewards_chosen - rewards_rejected).mean()
- mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean()
- self._metrics[mode]["margin"].append(mean_margin.item())
-
- return (loss, outputs) if return_outputs else loss
-
- # Override training step to add activation offloading context.
- def training_step(self, *args, **kwargs):
- with self.maybe_activation_offload_context:
- return super().training_step(*args, **kwargs)
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- mode = "train" if self.model.training else "eval"
- metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
-
- # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
- # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
- if mode == "eval":
- metrics = {f"eval_{key}": val for key, val in metrics.items()}
-
- logs.update(metrics)
- super().log(logs, start_time)
- self._metrics[mode].clear()
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothRewardTrainer(_UnslothRewardTrainer):
- """
-
- Trainer for Outcome-supervised Reward Models (ORM).
-
- This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
-
- Example:
-
- ```python
- from trl import RewardTrainer
- from datasets import load_dataset
-
- dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
-
- trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset)
- trainer.train()
- ```
-
- Args:
- model (`Union[str, PreTrainedModel]`):
- Model to be trained. Can be either:
-
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
- `args.model_init_kwargs`.
- - A sequence classification [`~transformers.PreTrainedModel`] object.
- args ([`RewardConfig`], *optional*):
- Configuration for this trainer. If `None`, a default configuration is used.
- data_collator ([`~transformers.DataCollator`], *optional*):
- Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
- Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
- Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and
- explicit prompt). The format of the samples can be either:
-
- - [Standard](dataset_formats#standard): Each sample contains plain text.
- - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
- and content).
-
- The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and
- `rejected_input_ids` fields.
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
- Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*):
- Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with
- [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be
- set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the
- default.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function that will be used to compute metrics at evaluation. Must take a
- [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
- [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
- boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
- function needs to calculate and return the global summary statistics rather than accumulating the
- batch-level statistics.
- callbacks (list of [`~transformers.TrainerCallback`], *optional*):
- List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
- in [here](https://huggingface.co/docs/transformers/main_classes/callback).
-
- If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
- method.
- optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
- model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
- optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
- A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
- `args`. Incompatible with the `optimizers` argument.
-
- Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
- initializing the Trainer.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
- A function that preprocess the logits right before caching them at each evaluation step. Must take two
- tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
- by this function will be reflected in the predictions received by `compute_metrics`.
-
- Note that the labels (second parameter) will be `None` if the dataset does not have them.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
- model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
- to ensure that the reward head is properly trained.
-
- """
- def __init__(
- self,
- model,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- compute_metrics = None,
- callbacks = None,
- optimizer_cls_and_kwargs = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- **kwargs
- ):
- if args is None: args = UnslothRewardConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('reward_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothSFTTrainer.py b/unsloth_compiled_cache/UnslothSFTTrainer.py
deleted file mode 100644
index 5719e38..0000000
--- a/unsloth_compiled_cache/UnslothSFTTrainer.py
+++ /dev/null
@@ -1,1612 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.sft_trainer import (Any, AutoProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, dataclass, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pack_dataset, pad, selective_log_softmax, torch, Any, AutoProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, os, pad, torch, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_dataset, pad, Optional, PreTrainedModel, logger, os, torch, os)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothSFTConfig(SFTConfig):
- """
-
- Configuration class for the [`SFTTrainer`].
-
- This class includes only the parameters that are specific to SFT training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
- differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- > Parameters that control the model
-
- model_init_kwargs (`dict[str, Any]`, *optional*):
- Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
- argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to
- include the load balancing/auxilliary loss as a part of the final loss, remember to set
- `output_router_logits=True` in this dictionary.
- chat_template_path (`str`, *optional*):
- If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
- or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
- ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
- embedding layer is resized accordingly.
-
- > Parameters that control the data preprocessing
-
- dataset_text_field (`str`, *optional*, defaults to `"text"`):
- Name of the column that contains text data in the dataset.
- dataset_kwargs (`dict[str, Any]`, *optional*):
- Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
- `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True`
- regardless of the provided value, since preprocessing is done on the fly.
- dataset_num_proc (`int`, *optional*):
- Number of processes to use for processing the dataset.
- eos_token (`str`, *optional*):
- Token used to indicate the end of a turn or sequence. If `None`, it defaults to
- `processing_class.eos_token`.
- pad_token (`str`, *optional*):
- Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
- it falls back to `processing_class.eos_token`.
- max_length (`int` or `None`, *optional*, defaults to `1024`):
- Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
- If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
- packing (`bool`, *optional*, defaults to `False`):
- Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce
- padding. Uses `max_length` to define sequence length.
- packing_strategy (`str`, *optional*, defaults to `"bfd"`):
- Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`.
- padding_free (`bool`, *optional*, defaults to `False`):
- Whether to perform forward passes without padding by flattening all sequences in the batch into a single
- continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
- supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When
- packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this
- parameter.
- pad_to_multiple_of (`int`, *optional*):
- If set, the sequences will be padded to a multiple of this value.
- eval_packing (`bool`, *optional*):
- Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
-
- > Parameters that control the training
-
- completion_only_loss (`bool`, *optional*):
- Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed
- only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If
- `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset:
- loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full
- sequence for [language modeling](#language-modeling) datasets.
- assistant_only_loss (`bool`, *optional*, defaults to `False`):
- Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only
- on the assistant responses, which is supported only for [conversational](#conversational) datasets. If
- `False`, loss is computed on the entire sequence.
- loss_type (`str`, *optional*, defaults to `"nll"`):
- Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic
- Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)).
- activation_offloading (`bool`, *optional*, defaults to `False`):
- Whether to offload the activations to the CPU.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- model_init_kwargs = None,
- chat_template_path = None,
- dataset_text_field = 'text',
- dataset_kwargs = None,
- dataset_num_proc = None,
- eos_token = None,
- pad_token = None,
- max_length = 1024,
- packing = False,
- packing_strategy = 'bfd',
- padding_free = False,
- pad_to_multiple_of = None,
- eval_packing = None,
- completion_only_loss = None,
- assistant_only_loss = False,
- loss_type = 'nll',
- activation_offloading = False,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
- from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
- if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
- from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
- pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- model_init_kwargs = model_init_kwargs,
- chat_template_path = chat_template_path,
- dataset_text_field = dataset_text_field,
- dataset_kwargs = dataset_kwargs,
- dataset_num_proc = dataset_num_proc,
- eos_token = eos_token,
- pad_token = pad_token,
- max_length = max_length,
- packing = packing,
- packing_strategy = packing_strategy,
- padding_free = padding_free,
- pad_to_multiple_of = pad_to_multiple_of,
- eval_packing = eval_packing,
- completion_only_loss = completion_only_loss,
- assistant_only_loss = assistant_only_loss,
- loss_type = loss_type,
- activation_offloading = activation_offloading,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothSFTTrainer(BaseTrainer):
- """"""
-
- _tag_names = ["trl", "sft"]
- _name = "SFT"
-
- def __init__(
- self,
- model: Union[str, PreTrainedModel],
- args: Optional[Union[SFTConfig, TrainingArguments]] = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
- compute_loss_func: Optional[Callable] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
- optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional["PeftConfig"] = None,
- formatting_func: Optional[Callable[[dict], str]] = None,
- ):
- # Args
- if args is None:
- model_name = model if isinstance(model, str) else model.config._name_or_path
- model_name = model_name.split("/")[-1]
- args = SFTConfig(f"{model_name}-SFT")
- elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
- dict_args = args.to_dict()
- dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
- dict_args.pop("push_to_hub_token", None)
- args = SFTConfig(**dict_args)
-
- # Model
- if isinstance(model, str):
- model = create_model_from_path(model, **args.model_init_kwargs or {})
- else:
- if args.model_init_kwargs is not None:
- logger.warning(
- "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
- "The `model_init_kwargs` will be ignored."
- )
- model_id = model.config._name_or_path
-
- # Processing class
- if processing_class is None:
- processing_class = AutoProcessor.from_pretrained(model_id)
-
- # Handle pad token for processors or tokenizers
- if isinstance(processing_class, ProcessorMixin):
- tokenizer = processing_class.tokenizer
- self._is_vlm = True
- elif isinstance(processing_class, PreTrainedTokenizerBase):
- tokenizer = processing_class
- self._is_vlm = False
- else:
- raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
-
- if args.eos_token is not None:
- eos_token = args.eos_token
- eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
- if eos_token_id is None:
- raise ValueError(
- f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
- f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
- "in the vocabulary before using it as an EOS token."
- )
- tokenizer.eos_token_id = eos_token_id
-
- if args.chat_template_path is not None:
- if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
- with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
- processing_class.chat_template = chat_template_file.read()
- added_tokens = []
- else:
- model, processing_class, added_tokens = clone_chat_template(
- model, processing_class, args.chat_template_path
- )
- else:
- added_tokens = []
-
- # Catch some wrong configurations related to VLMs
- if self._is_vlm and args.packing:
- raise ValueError(
- "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig."
- )
- if self._is_vlm and args.padding_free:
- raise ValueError(
- "Padding-free training is yet not supported for vision-language models. Please set "
- "`padding_free=False` in the `SFTConfig`."
- )
- if self._is_vlm and args.assistant_only_loss:
- raise ValueError(
- "Assistant-only loss is not yet supported for vision-language models. Please set "
- "`assistant_only_loss=False` in the `SFTConfig`."
- )
-
- # PEFT configuration and model wrapping
- if False:
- if added_tokens:
- # Ensure that the added tokens are trainable
- if peft_config.trainable_token_indices is None:
- peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
- elif "embed_tokens" not in peft_config.trainable_token_indices:
- peft_config.trainable_token_indices["embed_tokens"] = added_tokens
- else:
- peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
-
- # Ensure that the lm_head is trainable
- if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
- logger.warning(
- "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
- "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
- "tokens, leading to degraded generation quality. To fix this, add "
- "`modules_to_save=['lm_head']` to your PEFT configuration."
- )
-
- if peft_config.modules_to_save is None:
- peft_config.modules_to_save = ["lm_head"]
- else:
- peft_config.modules_to_save.append("lm_head")
-
- # In Prompt Tuning a small set of trainable virtual tokens [continuous prompt embeddings] is prepended to the
- # input. We store the number of these tokens so we can account for them correctly when calculating accuracy.
- self.num_virtual_tokens = 0
-
- if False:
- pass
- if model.active_adapter in model.peft_config:
- peft_model_config = model.peft_config[model.active_adapter]
- self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0)
-
- # Data collator
- # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing
- # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask.
- self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd")
- use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
- if self.padding_free:
- if data_collator is not None:
- raise ValueError("Passing a custom data collator is not supported when using padding-free.")
- if args.packing and args.packing_strategy == "wrapped":
- logger.warning(
- "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not "
- "recommended. Please refer to the documentation to understand why this is not recommended."
- )
- if not use_flash_attention:
- logger.warning(
- "Padding-free training is enabled, but the attention implementation is not set to a supported "
- "flash attention variant. Padding-free training flattens batches into a single sequence, and only "
- "the following implementations are known to reliably support this: "
- f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to "
- "unexpected behavior. To ensure compatibility, set `attn_implementation` in the model "
- "configuration to one of these supported options or verify that your attention mechanism can "
- "handle flattened sequences."
- )
- # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format
- # is prompt-completion, and False if the dataset format is language modeling.
- dataset_sample = next(iter(train_dataset))
- if args.completion_only_loss is None:
- self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample
- else:
- self.completion_only_loss = args.completion_only_loss
-
- self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
- # Unsloth: override _is_vlm for VLM models that pass a bare tokenizer
- if not self._is_vlm and self._is_vision_dataset:
- _m = model
- if hasattr(_m, "model"): _m = _m.model
- if hasattr(getattr(_m, "config", None), "vision_config") or\
- _m.__class__.__name__.endswith("ForConditionalGeneration"):
- self._is_vlm = True
- if self._is_vision_dataset and not self._is_vlm:
- raise ValueError(
- "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
- "model does not seem to be a vision-language model. Please check your model and dataset."
- )
-
- if data_collator is None and not self._is_vision_dataset:
- # Get the pad token: if not provided, use the one from the processing class or the eos token
- # if the processing class does not have a pad token.
- pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
- pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
- if pad_token_id is None:
- raise ValueError(
- f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
- f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
- "in the vocabulary before using it as a padding token."
- )
- data_collator = DataCollatorForLanguageModeling(
- pad_token_id=pad_token_id,
- completion_only_loss=self.completion_only_loss,
- padding_free=self.padding_free,
- pad_to_multiple_of=args.pad_to_multiple_of,
- )
- elif data_collator is None and self._is_vision_dataset:
- data_collator = DataCollatorForVisionLanguageModeling(
- processor=processing_class,
- max_length=args.max_length,
- completion_only_loss=self.completion_only_loss,
- pad_to_multiple_of=args.pad_to_multiple_of,
- dataset_text_field=args.dataset_text_field,
- )
-
- if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
- logger.warning(
- "You are using packing, but the attention implementation is not set to a supported flash attention "
- "variant. Packing gathers multiple samples into a single sequence, and only the following "
- f"implementations are known to reliably support this: {', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. "
- "Using other implementations may lead to cross-contamination between samples. To avoid this, either "
- "disable packing by setting `packing=False`, or set `attn_implementation` in the model configuration "
- "to one of these supported options."
- )
- if args.assistant_only_loss and not is_conversational(dataset_sample):
- raise ValueError(
- "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only "
- "supported for conversational datasets."
- )
-
- # Dataset
- # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
- # preprocessing [e.g., image-to-pixel conversion] is too costly and done on the fly instead.
- skip_prepare_dataset = (
- args.dataset_kwargs is not None
- and args.dataset_kwargs.get("skip_prepare_dataset", False)
- or self._is_vision_dataset
- )
- if not skip_prepare_dataset:
- if self.completion_only_loss and formatting_func:
- raise ValueError(
- "A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
- "Using a formatter converts the dataset to a language modeling type, conflicting with "
- "completion-only loss. To resolve this, apply your formatting function before passing the "
- "dataset, or disable `completion_only_loss` in `SFTConfig`."
- )
- self._unsloth_model_ref = model
- train_dataset = self._prepare_dataset(
- train_dataset, processing_class, args, args.packing, formatting_func, "train"
- )
- if eval_dataset is not None:
- packing = args.packing if args.eval_packing is None else args.eval_packing
- if isinstance(eval_dataset, dict):
- eval_dataset = {
- key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
- for key, dataset in eval_dataset.items()
- }
- else:
- eval_dataset = self._prepare_dataset(
- eval_dataset, processing_class, args, packing, formatting_func, "eval"
- )
-
- # Loss function
- if args.loss_type == "nll":
- pass # use the default loss
- elif args.loss_type == "dft":
- if compute_loss_func is not None:
- raise ValueError(
- "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
- "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
- "`compute_loss_func` is not allowed."
- )
- compute_loss_func = dft_loss
- else:
- raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
-
- # Initialize the metrics
- self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
- self._total_train_tokens = 0
-
- # Initialize the Trainer. Parent class will handle:
- # - DeepSpeed configuration [through create_accelerator_and_postprocess]
- # - FSDP setup
- # - Distributed training setup
- # - Optimizer and scheduler creation
-
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- compute_loss_func=compute_loss_func,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- # Initialize activation offloading context
- if self.args.activation_offloading:
- self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
- else:
- self.maybe_activation_offload_context = contextlib.nullcontext()
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
-
- def _prepare_dataset(
- self,
- dataset: Union[Dataset, IterableDataset],
- processing_class,
- args,
- packing: bool,
- formatting_func: Optional[Callable[[dict], str]],
- dataset_name: str,
- ) -> Union[Dataset, IterableDataset]:
- # All Unsloth Zoo code licensed under LGPLv3
- try:
- if isinstance(dataset, ConstantLengthDataset): return dataset
- except:
- pass
-
- map_kwargs = {}
- use_desc = isinstance(dataset, Dataset)
- is_vlm = hasattr(processing_class, "tokenizer")
- tokenizer = processing_class
- if is_vlm: tokenizer = processing_class.tokenizer
-
- # Dynamic detection: check if model's module defines a function
- # that requires token_type_ids when is_training=True
- import sys as _sys
- _needs_token_type_ids = False
- # Split to avoid compiler substring match on masking_utils names
- _ccm = 'create_' + 'causal_mask_mapping'
- _model = getattr(self, '_unsloth_model_ref', None) or getattr(self, 'model', None)
- if _model is not None:
- for _m in (_model, getattr(_model, 'model', None)):
- if _m is None: continue
- _mod = _sys.modules.get(type(_m).__module__)
- if _mod is not None and hasattr(_mod, _ccm):
- _needs_token_type_ids = True
- break
-
- if not _needs_token_type_ids:
- # Fallback: model not yet available, check processor class MRO
- for _base in type(processing_class).__mro__:
- _base_mod = getattr(_base, '__module__', '')
- if 'transformers.models.' in _base_mod:
- _modeling_mod = _base_mod.replace('.processing_', '.modeling_')
- _mod = _sys.modules.get(_modeling_mod)
- if _mod is not None and hasattr(_mod, _ccm):
- _needs_token_type_ids = True
- break
- if _needs_token_type_ids and hasattr(args, 'remove_unused_columns'):
- args.remove_unused_columns = False
-
- # Get max length
- max_seq_length = getattr(args, "max_length", 0)
- if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
- if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
- if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
- if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
- dataset_text_field = getattr(args, "dataset_text_field", "text")
- do_truncation = max_seq_length != 0
- do_formatting_func = False
- do_tokenize = True
-
- # Get correct column names
- column_names = set(next(iter(dataset)).keys())
- used_column_names = ["input_ids"]
- if "attention_mask" in column_names:
- used_column_names.append("attention_mask")
- if _needs_token_type_ids:
- used_column_names.append("token_type_ids")
-
- # Check if already tokenized so skip
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
- if "labels" in column_names:
- # Most likely forgot data collator!
- if is_vlm and not hasattr(tokenizer, "pad"):
- # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
- raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
- self.data_collator = DataCollatorForSeq2Seq(tokenizer)
- used_column_names.append("labels")
- do_tokenize = False
- elif "input_ids" in column_names:
- # Skip dataset prep, and set data collator
- if is_vlm and not hasattr(tokenizer, "pad"):
- # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
- raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
- self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
- do_tokenize = False
- elif dataset_text_field not in column_names:
- do_formatting_func = True
- if formatting_func is None:
- raise RuntimeError("Unsloth: You must specify a `formatting_func`")
- pass
-
- if do_tokenize:
- # Check double BOS tokens
- if do_formatting_func:
- test_text = formatting_func(next(iter(dataset)))
- if not isinstance(test_text, list):
- raise ValueError(
- "Unsloth: The `formatting_func` should return a list of processed strings."
- )
- test_text = test_text[0]
- else:
- test_text = next(iter(dataset))[dataset_text_field][0]
-
- # Get chat template
- chat_template = getattr(processing_class, 'chat_template', '')
- if chat_template == '' and is_vlm:
- chat_template = getattr(tokenizer, 'chat_template', '')
- if chat_template is None:
- chat_template = ''
-
- # Get bos_token
- add_special_tokens = True
- bos_token_1 = getattr(processing_class, 'bos_token', None)
- bos_token_2 = getattr(tokenizer, 'bos_token', None)
- bos_token = bos_token_1 or bos_token_2
-
- if bos_token is not None:
- if test_text.startswith(bos_token) or bos_token in chat_template:
- add_special_tokens = False
- print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
- pass
-
- # Create tokenize function
- def _tokenize(example):
- return tokenizer(
- example[dataset_text_field] if not do_formatting_func else formatting_func(example),
- truncation = do_truncation,
- max_length = max_seq_length,
- return_token_type_ids = _needs_token_type_ids,
- add_special_tokens = add_special_tokens,
- )
- pass
-
- if not isinstance(dataset, IterableDataset):
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- else:
- dataset_num_proc = getattr(args, "dataset_num_proc", None)
- if dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2:
- dataset_num_proc = 1
- else:
- dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- map_kwargs["num_proc"] = dataset_num_proc
- else:
- map_kwargs["batch_size"] = dataset._ex_iterable.batch_size
-
- if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
- import warnings as _w
- with _w.catch_warnings():
- _w.filterwarnings("ignore", message=".*couldn't be hashed properly.*")
- dataset = dataset.map(_tokenize, batched = True, remove_columns = list(column_names), **map_kwargs)
-
- # If VLM, switch data collator since .pad is needed!
- if is_vlm and not hasattr(processing_class, "pad"):
- data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
- self.data_collator = data_collator
- pass
- pass
- if packing:
- # Try using new packing which works in TRL
- try:
- pack_dataset
- except:
- print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
- return dataset
-
- if max_seq_length == 0:
- raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
-
- if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
- dataset = pack_dataset(
- dataset.select_columns(used_column_names),
- max_seq_length,
- getattr(args, "packing_strategy", "bfd"),
- map_kwargs,
- )
- pass
- return dataset
-
- def _set_signature_columns_if_needed(self):
- # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
- # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
- # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
- # dataset. So we need to override the default signature columns to include "completion_mask" as well.
- if self._signature_columns is None:
- if self._is_vision_dataset:
- self._signature_columns = ["messages", "prompt", "completion", "images", "input_ids", "labels", "attention_mask", "seq_lengths", "completion_mask", "assistant_masks"]
- else:
- self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
-
- def compute_loss(
- self, model, inputs, return_outputs = False, num_items_in_batch = None
- ):
- outputs = super().compute_loss(
- model,
- inputs,
- return_outputs = return_outputs,
- num_items_in_batch = num_items_in_batch,
- )
- return outputs
-
- # Override training step to add activation offloading context.
- def training_step(self, *args, **kwargs):
- with self.maybe_activation_offload_context:
- return super().training_step(*args, **kwargs)
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- mode = "train" if self.model.training else "eval"
- metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
-
- # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
- # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
- if mode == "eval":
- metrics = {f"eval_{key}": val for key, val in metrics.items()}
-
- logs.update(metrics)
- super().log(logs, start_time)
- self._metrics[mode].clear()
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
-class UnslothSFTTrainer(_UnslothSFTTrainer):
- """
-
- Trainer for Supervised Fine-Tuning (SFT) method.
-
- This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
-
- Example:
-
- ```python
- from datasets import load_dataset
- from trl import SFTTrainer
-
- dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
-
- trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
- trainer.train()
- ```
-
- Args:
- model (`Union[str, PreTrainedModel]`):
- Model to be trained. Can be either:
-
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
- path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
- using `.from_pretrained` (where `` is derived from the model
- config) with the keyword arguments in `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object.
- If you're training a model with an MoE architecture and want to include the load balancing/auxilliary loss
- as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`.
- args ([`SFTConfig`], *optional*):
- Configuration for this trainer. If `None`, a default configuration is used.
- data_collator ([`~transformers.DataCollator`], *optional*):
- Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
- Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model
- and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
- Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
- [prompt-completion](#prompt-completion) type. The format of the samples can be either:
-
- - [Standard](dataset_formats#standard): Each sample contains plain text.
- - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
- and content).
-
- The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
- Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If `None`, the processing class is loaded from the model's name
- with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set.
- If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default.
- compute_loss_func (`Callable`, *optional*):
- A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
- batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss
- function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618)
- used by [`Trainer`].
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function that will be used to compute metrics at evaluation. Must take a
- [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
- [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean
- `compute_result` argument. This will be triggered after the last eval batch to signal that the function
- needs to calculate and return the global summary statistics rather than accumulating the batch-level
- statistics.
- callbacks (list of [`~transformers.TrainerCallback`], *optional*):
- List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
- in [here](https://huggingface.co/docs/transformers/main_classes/callback).
-
- If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
- method.
- optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
- model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
- optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
- A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
- `args`. Incompatible with the `optimizers` argument.
-
- Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
- initializing the Trainer.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
- A function that preprocess the logits right before caching them at each evaluation step. Must take two
- tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
- by this function will be reflected in the predictions received by `compute_metrics`.
-
- Note that the labels (second parameter) will be `None` if the dataset does not have them.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
- formatting_func (`Callable`, *optional*):
- Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly
- converts the dataset into a [language modeling](#language-modeling) type.
-
- """
- def __init__(
- self,
- model,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- compute_loss_func = None,
- compute_metrics = None,
- callbacks = None,
- optimizer_cls_and_kwargs = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- formatting_func = None,
- **kwargs
- ):
- if args is None: args = UnslothSFTConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if 'max_length' not in locals() and not hasattr(args, 'max_length'):
- pass
- else:
- if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:
- if hasattr(args, 'max_length'):
- args.max_length = args.max_seq_length
- max_length = args.max_length
- else:
- model_max_length = getattr(model, 'max_seq_length', None)
- if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
- if model_max_length is not None:
- args.max_length = model_max_length
- max_length = args.max_length
- elif hasattr(args, 'max_length') and args.max_length is not None:
- max_length = args.max_length
- # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set
- setattr(model, 'max_seq_length', max_length)
- else:
- print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')
- args.max_length = 1024
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('sft_trainer', other_metrics)
- IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
- from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
- from unsloth_zoo.training_utils import fix_zero_training_loss
- if 'tokenizer' not in locals(): tokenizer = processing_class
- fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
- fix_zero_training_loss(model, tokenizer, train_dataset)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- compute_loss_func = compute_loss_func,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,
- formatting_func = formatting_func,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
-
-
-if hasattr(logger, "addFilter"):
- import logging
- class HideLoggingMessage(logging.Filter):
- def __init__(self, text): self.text = text
- def filter(self, x): return not (self.text in x.getMessage())
- pass
- logger.addFilter(HideLoggingMessage("`use_cache=True`"))
-
diff --git a/unsloth_compiled_cache/UnslothXPOTrainer.py b/unsloth_compiled_cache/UnslothXPOTrainer.py
deleted file mode 100644
index 4da51fc..0000000
--- a/unsloth_compiled_cache/UnslothXPOTrainer.py
+++ /dev/null
@@ -1,1409 +0,0 @@
-"""
-2026.2.1
-2026.2.1
-4.57.6
-0.24.0
-__UNSLOTH_VERSIONING__
-"""
-
-# Unsloth auto generated code
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Lesser General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU Lesser General Public License
-# along with this program. If not, see .
-
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from unsloth_zoo.temporary_patches.common import torch_compile
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-import inspect
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-from transformers.training_args import ParallelMode
-from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize
-
-# Wrap trainer with padding to right and enable training mode
-# Also patches W&B since multiple runs must use wandb.finish()
-import functools
-from types import MethodType
-try:
- from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers
-except:
- def reset_unsloth_gradient_checkpointing_buffers(): pass
-def prepare_for_training_mode(f):
- @functools.wraps(f)
- def wrapper(self, *args, **kwargs):
- # Enable training mode
- _was_training = None
- # Get gradient checkpointing setting from training arguments
- use_gc = getattr(self.args, 'gradient_checkpointing', True)
- if hasattr(self, 'model') and hasattr(self.model, "training"):
- _was_training = self.model.training
- if hasattr(self, 'model') and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- output = f(self, *args, **kwargs)
- # Restore previous mode when possible
- if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
- if _was_training is False:
- self.model.for_inference()
- elif _was_training is True and hasattr(self.model, "for_training"):
- self.model.for_training(use_gradient_checkpointing=use_gc)
- # Reset gradient checkpointing buffers to free memory while staying ready for next run
- try:
- reset_unsloth_gradient_checkpointing_buffers()
- except:
- pass
- # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run
- try:
- import wandb
- wandb.finish()
- except:
- pass
- return output
- return wrapper
-pass
-
-torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_hidden_states_selective_log_softmax(
- hidden_states: torch.Tensor,
- lm_head: torch.Tensor,
- index: torch.Tensor,
- chunks: int = 4,
- logit_scale_multiply: float = 0.0,
- logit_scale_divide: float = 0.0,
- logit_softcapping: float = 0.0,
- temperature: float = 1.0,
-) -> torch.Tensor:
- # All Unsloth Zoo code licensed under AGPL3
- flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
- flat_index = index.reshape(-1)
-
- chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0)
- chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0)
-
- all_per_token_logps = []
-
- for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index):
- chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
-
- if logit_scale_multiply != 0.0:
- chunk_logits = chunk_logits * logit_scale_multiply
- if logit_scale_divide != 0.0:
- chunk_logits = chunk_logits / logit_scale_divide
- if logit_softcapping != 0.0:
- chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping)
-
- chunk_logits = chunk_logits.to(torch.float32)
-
- if temperature != 1.0:
- chunk_logits = chunk_logits / temperature
-
- selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
-
- all_per_token_logps = torch.concat(all_per_token_logps)
-
- all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1]))
- return all_per_token_logps
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
-
-def calculate_pad_tokens_in_prompt(
- input_ids: torch.Tensor,
- logits_to_keep: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens
- """
- if logits_to_keep >= input_ids.shape[1]:
- raise ValueError("logits_to_keep must be smaller than the sequence length.")
-
- prompt_section = input_ids[:, :-logits_to_keep]
-
- padding_mask = (prompt_section == pad_token_id)
-
- pad_token_counts = padding_mask.sum(dim=1)
-
- return pad_token_counts
-
-def create_completion_attention_mask(
- completion_input_ids: torch.Tensor,
- left_pad_tokens_per_prompt: torch.Tensor,
- max_left_pad: int,
- pad_token_id: int
-) -> torch.Tensor:
- """
- Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad]
-
- Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens
- and pad are pad tokens, this function would make a completion mask that would 0 out the pad
- and p tokens. so in this example [0,0,0,1,1,1,0,0,0]
- """
- batch_size, completion_len = completion_input_ids.shape
- device = completion_input_ids.device
-
- num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt
-
- indices = torch.arange(completion_len, device=device).unsqueeze(0)
- shift_mask = indices >= num_tokens_to_mask.unsqueeze(1)
-
- non_padding_mask = (completion_input_ids != pad_token_id)
-
- final_mask = shift_mask & non_padding_mask
-
- return final_mask
-
-def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor:
- """
- Moves all padding tokens in each sequence of a batch to the right.
- """
- mask = (tensor != pad_id)
- # Must do stable=True since binary mark is unordered
- sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True)
- packed_tensor = torch.gather(tensor, 1, sorted_indices)
- return packed_tensor
-
-def align_logprobs_with_mask(
- logprob_tensor: torch.Tensor,
- attention_mask: torch.Tensor,
- pad_value: float = 0.0
-) -> torch.Tensor:
- """
- Aligns a log probability tensor with a given attention mask.
- """
-
- device = logprob_tensor.device
- batch_size, logprob_seq_len = logprob_tensor.shape
- mask_seq_len = attention_mask.shape[1]
-
- padded_logprobs = torch.full(
- attention_mask.shape,
- fill_value=pad_value,
- dtype=logprob_tensor.dtype,
- device=device
- )
-
- left_pad_counts = torch.argmax(attention_mask, dim=1)
-
- cols = torch.arange(logprob_seq_len, device=device)
- dest_indices = left_pad_counts.unsqueeze(1) + cols
-
- # Create destination row indices
- # Shape: [batch_size, logprob_seq_len]
- row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices)
-
- # --- 4. Filter out-of-bounds indices and perform assignment ---
- # Create a mask to identify only the indices that are within the bounds
- # of the target tensor's sequence length.
- valid_mask = dest_indices < mask_seq_len
-
- # Use this mask to select only the valid row indices, column indices,
- # and the corresponding values from the logprob tensor.
- # This flattens the selected elements into 1D tensors.
- valid_rows = row_indices[valid_mask]
- valid_cols = dest_indices[valid_mask]
- valid_vals = logprob_tensor[valid_mask]
-
- # Place the valid values into their correct positions in the padded tensor
- # using a single, efficient advanced indexing operation.
- padded_logprobs[valid_rows, valid_cols] = valid_vals
-
- return padded_logprobs
-
-def autotune_batch_and_chunks(
- total_input_rows,
- seq_len,
- hidden_size,
- vocab_size,
- dtype_bytes=16,
- multiplier=None
-):
- if multiplier is None:
- final_m = max(4, seq_len // 4096)
- else:
- final_m = multiplier
-
- if torch.cuda.is_available():
- free_bytes, _ = torch.cuda.mem_get_info()
- limit_gb = (free_bytes / (1024**3))*.80
- elif hasattr(torch, "xpu") and torch.xpu.is_available():
- # For XPU: estimate free memory from total - reserved
- total_mem = torch.xpu.get_device_properties(0).total_memory
- reserved_mem = torch.xpu.memory_reserved()
- free_bytes = total_mem - reserved_mem
- limit_gb = (free_bytes / (1024**3)) * 0.80
- else:
- # Fallback: assume 8GB available
- limit_gb = 8.0
-
- bytes_to_gb = 1024**3
-
- b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32)
-
- hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb
-
- base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb
- logits_gb = base_logits / final_m
-
- total_mem_gb = hidden_gb + logits_gb
-
- valid_mask = total_mem_gb <= limit_gb
- valid_indices = torch.nonzero(valid_mask, as_tuple=False)
-
- if valid_indices.shape[0] == 0:
- #This means your GPU will OOM
- return 4, final_m
-
- best_idx = valid_indices[0].item()
- final_b = int(b_vals[best_idx].item())
-
- return final_b, final_m
-@dataclass
-class UnslothXPOConfig(XPOConfig):
- """
-
- Configuration class for the [`XPOTrainer`].
-
- Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
-
- Parameters:
- alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
- Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
- and the last alpha is used for the rest of the epochs.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- unsloth_logit_chunk_multiplier : Optional[int] = field(
- default = None,
- metadata = {'help': 'Multiplier for chunked logit computations.'},
- )
- unsloth_grpo_mini_batch : Optional[int] = field(
- default = None,
- metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'},
- )
- max_seq_length : Optional[int] = field(
- default = None,
- metadata = {'help': 'Maximum sequence length to truncate to.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- lr_scheduler_kwargs = None,
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = None,
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- parallelism_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = 'none',
- project = 'huggingface',
- trackio_space_id = 'trackio',
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = True,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = False,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- reward_model_path = None,
- judge = None,
- max_new_tokens = 64,
- max_length = 512,
- temperature = 0.9,
- top_p = 1.0,
- top_k = None,
- min_p = None,
- repetition_penalty = 1.0,
- generation_kwargs = {},
- use_transformers_paged = False,
- cache_implementation = None,
- missing_eos_penalty = None,
- loss_type = 'sigmoid',
- disable_dropout = True,
- use_vllm = False,
- vllm_model_impl = 'vllm',
- vllm_guided_decoding_regex = None,
- vllm_gpu_memory_utilization = 0.55,
- vllm_mode = 'colocate',
- vllm_server_base_url = None,
- vllm_server_host = '0.0.0.0',
- vllm_server_port = 8000,
- vllm_server_timeout = 240.0,
- vllm_tensor_parallel_size = 1,
- ds3_gather_for_generation = True,
- model_init_kwargs = None,
- reward_weights = None,
- dataset_num_proc = None,
- gpu_memory_utilization = None,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- unsloth_logit_chunk_multiplier = None,
- unsloth_grpo_mini_batch = None,
- max_seq_length = None,
- **kwargs,
- ):
- if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if num_train_epochs is None:
- num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- import multiprocessing as _mp
- if _mp.get_start_method() != 'fork':
- dataset_num_proc = None
- elif dataset_num_proc is None:
- import psutil
- dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64)
- memory_gb_left = psutil.virtual_memory().available / (1024**3)
- if memory_gb_left <= 2: dataset_num_proc = 1
- else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left))
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- lr_scheduler_kwargs = lr_scheduler_kwargs,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- parallelism_config = parallelism_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- project = project,
- trackio_space_id = trackio_space_id,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- reward_model_path = reward_model_path,
- judge = judge,
- max_new_tokens = max_new_tokens,
- max_length = max_length,
- temperature = temperature,
- top_p = top_p,
- top_k = top_k,
- min_p = min_p,
- repetition_penalty = repetition_penalty,
- generation_kwargs = generation_kwargs,
- use_transformers_paged = use_transformers_paged,
- cache_implementation = cache_implementation,
- missing_eos_penalty = missing_eos_penalty,
- loss_type = loss_type,
- disable_dropout = disable_dropout,
- use_vllm = use_vllm,
- vllm_model_impl = vllm_model_impl,
- vllm_guided_decoding_regex = vllm_guided_decoding_regex,
- vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
- vllm_mode = vllm_mode,
- vllm_server_base_url = vllm_server_base_url,
- vllm_server_host = vllm_server_host,
- vllm_server_port = vllm_server_port,
- vllm_server_timeout = vllm_server_timeout,
- vllm_tensor_parallel_size = vllm_tensor_parallel_size,
- ds3_gather_for_generation = ds3_gather_for_generation,
- model_init_kwargs = model_init_kwargs,
- reward_weights = reward_weights,
- dataset_num_proc = dataset_num_proc,
- gpu_memory_utilization = gpu_memory_utilization,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- if unsloth_grpo_mini_batch is not None:
- if self.generation_batch_size >= unsloth_grpo_mini_batch:
- self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
- else:
- raise ValueError(
- f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
- f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
- )
- self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
- self.max_seq_length = max_seq_length
-
-pass
-
-class _UnslothXPOTrainer(OnlineDPOTrainer):
- """"""
-
- _tag_names = ["trl", "xpo"]
- _name = "XPO"
- _paper = {
- "title": "Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
- "id": "2405.21046",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{jung2024binary,
- title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
- author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
- year = 2024,
- eprint = {arXiv:2405.21046}
- }"""),
- }
-
- def __init__(
- self,
- model: Union[PreTrainedModel, nn.Module] = None,
- ref_model: Union[PreTrainedModel, nn.Module] = None,
- reward_funcs: Optional[nn.Module] = None,
- judge: Optional[BasePairwiseJudge] = None,
- args: Optional[XPOConfig] = None,
- data_collator: Optional[Callable] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
- peft_config: Optional[dict] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- # Deprecated parameters
- reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
- ) -> None:
- super().__init__(
- model=model,
- ref_model=ref_model,
- judge=judge,
- reward_funcs=reward_funcs,
- reward_model=reward_model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- reward_processing_classes=reward_processing_classes,
- peft_config=peft_config,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- self._alpha = self.args.alpha
-
- # Overwrite the stats dictionary to include XPO specific statistics
- self.stats = {
- # Remove "non_score_reward", "rlhf_reward", "scores"
- # Add "loss/dpo", "loss/xpo"
- "loss/dpo": [],
- "loss/xpo": [],
- "objective/kl": [],
- "objective/entropy": [],
- "rewards/chosen": [],
- "rewards/rejected": [],
- "rewards/accuracies": [],
- "rewards/margins": [],
- "logps/chosen": [],
- "logps/rejected": [],
- # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
- "val/model_contain_eos_token": [],
- "val/ref_contain_eos_token": [],
- "alpha": [],
- "beta": [],
- }
- if self.reward_funcs is not None:
- if len(self.reward_funcs) != 1:
- raise ValueError("XPOTrainer only supports one reward function/model.")
- self.reward_funcs = self.reward_funcs[0]
- self.stats["objective/model_scores"] = []
- self.stats["objective/ref_scores"] = []
- self.stats["objective/scores_margin"] = []
-
- @property
- def alpha(self):
- if isinstance(self._alpha, list):
- epoch = self.state.epoch
- return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
- else:
- return self._alpha
-
- def _generate_completions(self, prompts, model):
- with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen:
- model_output = unwrapped_policy_model_for_gen.generate(
- input_ids=prompts["input_ids"],
- attention_mask=prompts["attention_mask"],
- generation_config=self.generation_config,
- )
-
- actual_model_for_ref_generation: torch.nn.Module
- if self.ref_model is None:
- unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model)
-
- if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel):
- actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model()
- else:
- actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic
- else:
- actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model)
-
- with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen:
- ref_output = final_ref_model_for_gen.generate(
- input_ids=prompts["input_ids"],
- attention_mask=prompts["attention_mask"],
- generation_config=self.generation_config,
- )
-
- return model_output, ref_output
-
- def _process_completions(self, model_output, ref_output, prompts):
- context_length = prompts["input_ids"].shape[1]
-
- # Process model completions
- model_completion_ids = model_output[:, context_length:]
- model_completion_ids, model_completion_mask = truncate_right(
- model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
- )
- model_data = {
- "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
- "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
- "raw": prompts["raw"],
- }
-
- # Process reference model completions
- ref_completion_ids = ref_output[:, context_length:]
- ref_completion_ids, ref_completion_mask = truncate_right(
- ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
- )
- ref_data = {
- "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
- "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
- "raw": prompts["raw"],
- }
-
- return model_data, ref_data
-
- def _compute_rewards(self, model_data, ref_data, context_length):
- with torch.no_grad():
- _, model_scores, _ = get_reward(
- self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length
- )
- _, ref_scores, _ = get_reward(
- self.reward_funcs, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
- )
-
- # Apply EOS penalty if needed
- if self.args.missing_eos_penalty is not None:
- model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
- ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
- model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
- ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
-
- return model_scores, ref_scores
-
- def _compute_judge(self, model_data, ref_data, context_length):
- prompts = model_data["raw"]
- model_data_completions = self.processing_class.batch_decode(
- model_data["input_ids"][:, context_length:], skip_special_tokens=True
- )
- model_data_completions = [completion.strip() for completion in model_data_completions]
-
- ref_data_completions = self.processing_class.batch_decode(
- ref_data["input_ids"][:, context_length:], skip_special_tokens=True
- )
- ref_data_completions = [completion.strip() for completion in ref_data_completions]
-
- if is_conversational({"prompt": prompts[0]}):
- model_data_completions = [
- [{"role": "assistant", "content": completion}] for completion in model_data_completions
- ]
- environment = jinja2.Environment()
- template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
- prompts = [template.render(messages=message) for message in prompts]
- model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
-
- ref_data_completions = [
- [{"role": "assistant", "content": completion}] for completion in ref_data_completions
- ]
- ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
-
- ranks_of_first_completion = self.judge.judge(
- prompts,
- list(zip(model_data_completions, ref_data_completions)),
- )
- # convert ranks to a True/False mask:
- # when rank == 0, it means the first completion is the best
- # when rank == 1, it means the second completion is the best
- return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
-
- def _compute_logprobs(self, model, model_data, ref_data, context_length):
- def compute_logprobs_for_data(m, data):
- output = m(data["input_ids"], attention_mask=data["attention_mask"])
- logits = output.logits[:, context_length - 1 : -1]
- token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
- return token_logprobs
-
- # Compute logprobs for model completions
- model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
- # Compute logprobs for model on reference completions (for XPO loss)
- model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
-
- # Compute logprobs for reference model completions
- with torch.no_grad():
- if self.ref_model is None:
- with model.disable_adapter():
- ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
- ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
- else:
- ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
- ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
-
- # Mask padding tokens
- model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
- ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
- model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
- model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
- ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
- ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
-
- return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
-
- def _compute_losses(
- self,
- model_logprobs_model_data,
- model_logprobs_ref_data,
- ref_logprobs_ref_data,
- ref_logprobs_model_data,
- chosen_mask,
- ):
- # Compute log probs
- model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
- model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
- ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
- ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
-
- chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
- chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
- chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
-
- rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
- rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
- rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
-
- # Compute logits as the difference between chosen and rejected log ratios
- logits = chosen_log_ratios - rejected_log_ratios
-
- if self.args.loss_type == "sigmoid":
- dpo_losses = -F.logsigmoid(self.beta * logits)
- elif self.args.loss_type == "ipo":
- dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
- else:
- raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
-
- # Compute XPO specific loss
- xpo_losses = self.alpha * model_logprobs_ref_data_sum
-
- # Total loss
- loss = (dpo_losses + xpo_losses).mean()
-
- return loss, dpo_losses, xpo_losses
-
- def _log_statistics(
- self,
- model_data,
- ref_data,
- model_logprobs_model_data,
- model_logprobs_ref_data,
- ref_logprobs_ref_data,
- ref_logprobs_model_data,
- chosen_mask,
- dpo_losses,
- xpo_losses,
- context_length,
- model_scores=None,
- ref_scores=None,
- ):
- # Helper function to gather and compute mean
- def gather_mean(tensor):
- return self.accelerator.gather_for_metrics(tensor).mean().item()
-
- # Log losses
- self.stats["loss/dpo"].append(gather_mean(dpo_losses))
- self.stats["loss/xpo"].append(gather_mean(xpo_losses))
-
- # Log scores
- if self.reward_funcs is not None:
- self.stats["objective/model_scores"].append(gather_mean(model_scores))
- self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
- self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
-
- # Log logprobs
- model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
- model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
- ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
- ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
-
- chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
- chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
- chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
-
- rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
- rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
- rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
-
- self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
- self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
-
- # Log rewards
- # Compute various statistics
- chosen_rewards = chosen_log_ratios * self.beta
- rejected_rewards = rejected_log_ratios * self.beta
- self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
- self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
-
- # Calculate KL divergence for model and ref data
- kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
- kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
- mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
- self.stats["objective/kl"].append(gather_mean(mean_kl))
-
- # Calculate entropy for model and ref data
- entropy_model_data = -model_logprobs_model_data.sum(1)
- entropy_ref_data = -model_logprobs_ref_data.sum(1)
- mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
- self.stats["objective/entropy"].append(gather_mean(mean_entropy))
-
- # Calculate margins
- margin = chosen_rewards - rejected_rewards
- self.stats["rewards/margins"].append(gather_mean(margin.mean()))
-
- # Calculate accuracy
- accuracy = (margin > 0).float()
- self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
-
- # Log EOS token statistics
- model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
- ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
- self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
- self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
-
- # Log alpha and beta
- self.stats["alpha"].append(self.alpha)
- self.stats["beta"].append(self.beta)
-
- def training_step(
- self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
- ) -> torch.Tensor:
- model.train()
-
- # Apply chat template and tokenize the input
- batch_size = len(next(iter(inputs.values())))
- prompts = inputs["prompt"]
- inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
- inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
- inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
- inputs = self.data_collator(inputs)
-
- # need the prompt_ only
- inputs = self._prepare_inputs(inputs)
- context_length = inputs["prompt_input_ids"].shape[1]
- prompts = {
- "input_ids": inputs["prompt_input_ids"],
- "attention_mask": inputs["prompt_attention_mask"],
- "raw": prompts,
- }
- del inputs
-
- # Sample completions from both the model and the reference model
- model_output, ref_output = self._generate_completions(prompts, model)
-
- # Process model completions
- model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
-
- # Compute rewards
- if self.reward_funcs is not None:
- model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
- chosen_mask = model_scores >= ref_scores
- else:
- model_scores, ref_scores = None, None
- chosen_mask = self._compute_judge(model_data, ref_data, context_length)
-
- # Compute logprobs
- model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
- self._compute_logprobs(model, model_data, ref_data, context_length)
- )
-
- # Compute loss
- loss, dpo_losses, xpo_losses = self._compute_losses(
- model_logprobs_model_data,
- model_logprobs_ref_data,
- ref_logprobs_ref_data,
- ref_logprobs_model_data,
- chosen_mask,
- )
-
- # Log everything
- self._log_statistics(
- model_data,
- ref_data,
- model_logprobs_model_data.detach(),
- model_logprobs_ref_data.detach(),
- ref_logprobs_ref_data,
- ref_logprobs_model_data,
- chosen_mask,
- dpo_losses.detach(),
- xpo_losses.detach(),
- context_length,
- model_scores,
- ref_scores,
- )
-
- if (
- self.args.torch_empty_cache_steps is not None
- and self.state.global_step % self.args.torch_empty_cache_steps == 0
- ):
- empty_cache()
-
- kwargs = {}
- # For LOMO optimizers you need to explicitly use the learning rate
- if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
- kwargs["learning_rate"] = self._get_learning_rate()
-
- if self.args.n_gpu > 1:
- loss = loss.mean() # mean() to average on multi-gpu parallel training
-
- self.accelerator.backward(loss, **kwargs)
-
- return loss.detach() / self.args.gradient_accumulation_steps
-class UnslothXPOTrainer(_UnslothXPOTrainer):
- """
-
- Trainer for Exploratory Preference Optimization (XPO).
-
- It is implemented as a subclass of [`OnlineDPOTrainer`].
-
- Args:
- model ([`~transformers.PreTrainedModel`]):
- The model to train, preferably an `AutoModelForCausalLM`.
- ref_model ([`PreTrainedModelWrapper`]):
- Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
- and loss. If no reference model is provided, the trainer will create a reference model with the same
- architecture as the model to be optimized.
- reward_funcs ([`~transformers.PreTrainedModel`]):
- The reward model to score completions with, preferably an
- [`~transformers.AutoModelForSequenceClassification`].
- judge ([`BasePairwiseJudge`]):
- The judge to use for pairwise comparison of model completions.
- args ([`XPOConfig`]):
- The XPO config arguments to use for training.
- data_collator ([`~transformers.DataCollator`]):
- The data collator to use for training. If None is specified, the default data collator
- ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
- sequences in the batch, given a dataset of paired sequences.
- train_dataset ([`~datasets.Dataset`]):
- The dataset to use for training.
- eval_dataset ([`~datasets.Dataset`]):
- The dataset to use for evaluation.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- peft_config (`dict`):
- The peft config to use for training.
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
- metric values.
- callbacks (`list[transformers.TrainerCallback]`):
- The callbacks to use for training.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
- The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
- The function to use to preprocess the logits before computing the metrics.
-
- reward_model:
-
-
-
- This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead.
-
-
-
- """
- def __init__(
- self,
- model = None,
- ref_model = None,
- reward_funcs = None,
- judge = None,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- reward_processing_classes = None,
- peft_config = None,
- compute_metrics = None,
- callbacks = None,
- preprocess_logits_for_metrics = None,
- reward_model = None,
- **kwargs
- ):
- if args is None: args = UnslothXPOConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1'
- if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'):
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().weight.dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- # Forced float32 training
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- # Mixed precision training
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16'
- # args.mixed_precision is a new argument which needs to be set now
- elif mixed_precision_dtype == 'bfloat16':
- # Both False since bfloat16 full finetuning doesn't do any autocasting.
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no'
- # args.mixed_precision is a new argument which needs to be set now
-
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- elif args_max_seq_length is not None and model_max_seq_length is not None:
- if args_max_seq_length > model_max_seq_length:
- print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but '
- 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')
- args.max_seq_length = model_max_seq_length
- if model is not None and hasattr(model, 'for_training'):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(
- __tokenizer.tokenizer,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(
- __tokenizer.tokenizer,
- mlm = False,
- mlm_probability = 0.0,
- pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
- )
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('xpo_trainer', other_metrics)
-
- # [TODO] Fix up DataParallel multiplying batch sizes
- # [TODO] DDP works, but DP seems to not work? [TODO]
- if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1:
- if getattr(args, "_n_gpu", 1) != 1:
- args._n_gpu = 1
- if "model" in locals() and hasattr(model, "for_training"):
- model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True))
- super().__init__(
- model = model,
- ref_model = ref_model,
- reward_funcs = reward_funcs,
- judge = judge,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- reward_processing_classes = reward_processing_classes,
- peft_config = peft_config,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- reward_model = reward_model,**kwargs)
- if "model" in locals() and hasattr(model, "for_inference"):
- model.for_inference()
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
- if hasattr(self, 'accelerator'):
- scaler = self.accelerator.scaler
- current_model = model
- while hasattr(current_model, 'model'):
- current_model.accelerator_scaler = scaler
- current_model = current_model.model
- current_model.accelerator_scaler = scaler
- pass
- if hasattr(self, 'train'):
- self.train = MethodType(prepare_for_training_mode(self.__class__.train), self)
- pass
- if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'):
- _vllm_tok = self.llm.get_tokenizer()
- _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None)
- if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None:
- _vllm_tok.chat_template = _pc.chat_template
- pass
-
-pass
diff --git a/unsloth_compiled_cache/moe_utils.py b/unsloth_compiled_cache/moe_utils.py
deleted file mode 100644
index b444c2f..0000000
--- a/unsloth_compiled_cache/moe_utils.py
+++ /dev/null
@@ -1,1251 +0,0 @@
-# Unsloth Zoo - Utilities for Unsloth
-# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published
-# by the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-import torch
-import torch.nn.functional as F
-import os
-import shutil
-from typing import Optional, Tuple
-from torch.autograd import Function
-from .utils import logger
-
-# Get compile location
-UNSLOTH_COMPILE_LOCATION = os.environ.get(
- "UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache"
-)
-
-
-def install_to_cache(source_path, destination_filename=None):
- """
- Copies a file to the unsloth_compiled_cache directory
- to ensure it is available for compiled modules.
- """
- if not os.path.exists(UNSLOTH_COMPILE_LOCATION):
- try:
- os.makedirs(UNSLOTH_COMPILE_LOCATION)
- except:
- pass
-
- current_file = os.path.abspath(source_path)
- if destination_filename is None:
- destination_filename = os.path.basename(current_file)
-
- destination = os.path.abspath(os.path.join(UNSLOTH_COMPILE_LOCATION, destination_filename))
-
- # If source and dest are different, copy.
- if current_file != destination:
- try:
- shutil.copy(current_file, destination)
- except Exception:
- pass
-
-
-install_to_cache(__file__, "moe_utils.py")
-
-# ============================================================================
-# Grouped MM wrapper
-# ============================================================================
-# Simple wrapper around torch._grouped_mm that ensures contiguous inputs.
-# Native backward works correctly - no custom autograd needed.
-# ============================================================================
-
-
-def _grouped_mm_with_backward_fix(
- inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
-) -> torch.Tensor:
- """
- Grouped matmul with working backward pass.
-
- Uses native torch._grouped_mm with contiguous inputs for correct gradients.
- """
- return torch._grouped_mm(inputs, weight, offs=offsets)
-
-
-# Global flag to check if grouped GEMM is available
-_GROUPED_GEMM_AVAILABLE = None
-_TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm")
-
-# Check if GPU supports torch._grouped_mm (verified via runtime check)
-_TORCH_GROUPED_MM_SUPPORTED = None
-
-
-def _check_torch_grouped_mm_supported():
- """
- Check if torch._grouped_mm is actually supported on the current GPU.
- We check for existence and verify with a dummy call.
- A runtime probe is the only reliable check.
- """
- global _TORCH_GROUPED_MM_SUPPORTED
- if _TORCH_GROUPED_MM_SUPPORTED is not None: return _TORCH_GROUPED_MM_SUPPORTED
-
- if not _TORCH_GROUPED_MM_AVAILABLE:
- _TORCH_GROUPED_MM_SUPPORTED = False
- return False
-
- if not torch.cuda.is_available():
- _TORCH_GROUPED_MM_SUPPORTED = False
- return False
-
- try:
- # Attempt a dummy grouped_mm call to verify support.
- # This handles cases where the symbol exists but hardware is unsupported (e.g. < H100).
- # It also allows support on newer hardware or backports without code changes.
- device = torch.cuda.current_device()
- dtype = torch.float16
-
- # Minimal dummy data: 1 expert, 1 token, dim 8 (safe alignment)
- x = torch.ones((1, 8), device=device, dtype=dtype)
- w = torch.ones((1, 8, 8), device=device, dtype=dtype)
- offs = torch.tensor([1], device=device, dtype=torch.int32)
-
- torch._grouped_mm(x, w, offs=offs)
- del x, w, offs
- _TORCH_GROUPED_MM_SUPPORTED = True
- except Exception:
- _TORCH_GROUPED_MM_SUPPORTED = False
-
- return _TORCH_GROUPED_MM_SUPPORTED
-
-
-_TRITON_ALLOCATOR_INITIALIZED = False
-_PERSISTENT_BUFFER = None
-
-
-def _init_triton_allocator():
- """
- Initialize a persistent Triton allocator to avoid memory allocation overhead per call.
- This significantly reduces GPU utilization fluctuation.
- """
- global _TRITON_ALLOCATOR_INITIALIZED, _PERSISTENT_BUFFER
- if _TRITON_ALLOCATOR_INITIALIZED: return
-
- try:
- import triton
-
- # Create a persistent buffer that grows as needed
- # This avoids allocating new memory on every kernel call
-
- def persistent_alloc_fn(size: int, alignment: int, stream):
- global _PERSISTENT_BUFFER
- # Round up size to avoid frequent reallocations
- # Round to nearest 128 bytes for alignment
- rounded_size = ((size + 128 - 1) // 128) * 128
-
- if (
- _PERSISTENT_BUFFER is None
- or _PERSISTENT_BUFFER.numel() * _PERSISTENT_BUFFER.element_size()
- < rounded_size
- ):
- # Allocate with small headroom (10%) to reduce reallocations
- # Use ByteTensor (uint8) for raw byte storage
- _PERSISTENT_BUFFER = torch.empty(
- int(rounded_size * 1.1), device="cuda", dtype=torch.uint8
- )
- _PERSISTENT_BUFFER.__hibernate__ = {"type": "ignore"}
- return _PERSISTENT_BUFFER
-
- triton.set_allocator(persistent_alloc_fn)
- triton._unsloth_allocator_set = True
- _TRITON_ALLOCATOR_INITIALIZED = True
- except Exception:
- pass
-
-
-def _check_grouped_gemm_available():
- """Check if Unsloth grouped GEMM kernels are available."""
- if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False
-
- global _GROUPED_GEMM_AVAILABLE
- if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE
-
- try:
- from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm, supports_tma
- _GROUPED_GEMM_AVAILABLE = True
- _init_triton_allocator()
- except (ImportError, ModuleNotFoundError):
- _GROUPED_GEMM_AVAILABLE = False
- return _GROUPED_GEMM_AVAILABLE
-
-
-from functools import lru_cache
-
-
-@lru_cache(maxsize=1)
-def select_moe_backend():
- """
- Selects the MoE backend based on UNSLOTH_MOE_BACKEND environment variable and availability.
- Choices: "grouped_mm", "unsloth_triton", "native_torch".
- Default if unspecified: "grouped_mm".
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- requested = os.environ.get("UNSLOTH_MOE_BACKEND")
- if requested:
- if requested == "grouped_mm" and _check_torch_grouped_mm_supported():
- return "grouped_mm"
- if requested == "unsloth_triton" and _check_grouped_gemm_available():
- return "unsloth_triton"
- if requested == "native_torch":
- return "native_torch"
- logger.info(f"Unsloth: '{requested}' backend requested but is not available. Falling back to next available.")
-
- if _check_torch_grouped_mm_supported():
- logger.info("Unsloth: Using MoE backend 'grouped_mm'")
- return "grouped_mm"
- if _check_grouped_gemm_available():
- logger.info("Unsloth: Using MoE backend 'unsloth_triton'")
- return "unsloth_triton"
- return "native_torch"
-
-
-def forward_moe_backend(
- self,
- hidden_states: torch.Tensor,
- top_k_index: torch.Tensor,
- top_k_weights: torch.Tensor,
-) -> torch.Tensor:
- """
- Dispatch MoE forward to the selected backend.
- Centralizes backend selection to keep model-specific patches minimal.
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- backend = select_moe_backend()
- if backend == "grouped_mm":
- return forward_native_grouped_mm(self, hidden_states, top_k_index, top_k_weights)
- if backend == "unsloth_triton":
- return forward_triton_grouped_gemm(self, hidden_states, top_k_index, top_k_weights)
- return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights)
-
-
-@torch.no_grad()
-def _get_routing_indices(selected_experts, num_experts):
- """
- Compute token→expert mapping for grouped GEMM.
- Uses bincount instead of histc to avoid float conversion overhead.
-
- Returns:
- token_counts_by_expert: (num_experts,) token counts per expert
- gather_indices: (total_tokens,) indices for gathering tokens in expert order
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- flat_experts = selected_experts.view(-1)
-
- # bincount is faster than histc since it doesn't require float conversion
- token_counts_by_expert = torch.bincount(flat_experts, minlength=num_experts).to(torch.int32)
-
- # argsort with stable=True preserves order within each expert
- gather_indices = flat_experts.argsort(stable=True)
-
- return token_counts_by_expert, gather_indices
-
-
-def _silu_and_mul(x):
- """Fused SiLU activation and element-wise multiply for gate/up projections."""
- gate, up = x.chunk(2, dim=-1)
- return F.silu(gate) * up
-
-
-# ============================================================================
-# Separated LoRA Helper Functions
-# ============================================================================
-
-
-def _has_lora_adapters(param) -> bool:
- """Check if parameter has active LoRA adapters (PEFT ParamWrapper)."""
- # Check if this is a PEFT LoRA wrapper
- if not hasattr(param, "lora_A") or not hasattr(param, "lora_B"):
- return False
- if hasattr(param, "disable_adapters") and param.disable_adapters:
- return False
- if hasattr(param, "merged") and param.merged:
- return False
- return len(param.lora_A) > 0
-
-
-def _extract_lora_from_wrapper(
- wrapper, adapter_name: str = "default", experts_module=None
-) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]:
- """
- Extract LoRA weights from PEFT ParamWrapper for MoE separated computation.
-
- PEFT ParamWrapper for 3D parameters creates:
- - lora_A: nn.Linear(in_dim, E*R) -> weight: (E*R, in_dim)
- - lora_B: nn.Linear(E*R, out_dim) -> weight: (out_dim, E*R)
-
- For grouped_mm: X @ first_weight @ second_weight
-
- STANDARD FORMAT (Qwen3-MoE): weights stored as (E, out_dim, in_dim) for F.linear
- gate_up_proj: (E, 2*I, H) - input X is (N, H), output is (N, 2*I)
- down_proj: (E, H, I) - input X is (N, I), output is (N, H)
-
- For gate_up with (E, 2*I, H):
- lora_A: (E*R, H), lora_B: (2*I, E*R)
- Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
- first_weight from lora_A: (E*R, H) -> (E, H, R) after view/permute
- second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I) after view/permute
-
- TRANSPOSED FORMAT (Qwen3-VL-MoE): weights stored as (E, in_dim, out_dim) for grouped_mm
- gate_up_proj: (E, H, 2*I) - input X is (N, H), output is (N, 2*I)
- down_proj: (E, I, H) - input X is (N, I), output is (N, H)
-
- For gate_up with (E, H, 2*I):
- lora_A: (E*R, H), lora_B: (2*I, E*R)
- Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I)
- first_weight from lora_A: (E*R, H) -> (E, H, R)
- second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I)
-
- Returns:
- (first_weight, second_weight, scaling, num_experts) or None
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- try:
- if not hasattr(wrapper, "lora_A") or not hasattr(wrapper, "lora_B"):
- return None
-
- if hasattr(wrapper, "disable_adapters") and wrapper.disable_adapters:
- return None
- if hasattr(wrapper, "merged") and wrapper.merged:
- return None
-
- if not wrapper.lora_A:
- return None
-
- if adapter_name not in wrapper.lora_A:
- adapter_name = list(wrapper.lora_A.keys())[0]
-
- lora_A_module = wrapper.lora_A[adapter_name]
- lora_B_module = wrapper.lora_B[adapter_name]
-
- weight_A = lora_A_module.weight # (E*R, dim1)
- weight_B = lora_B_module.weight # (dim2, E*R)
- scaling = wrapper.scaling[adapter_name]
- num_experts = getattr(wrapper, "num_experts", 1)
-
- # GET EXPERTS MODULE TO CHECK FOR REGISTERED EXTRACTOR
- if experts_module is None:
- experts_module = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None
-
- # Check for model-specific LoRA extractor attached to the experts module
- extractor_fn = getattr(experts_module, "_unsloth_lora_extractor_fn", None)
-
- if extractor_fn is not None:
- return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts)
-
- # DEFAULT BEHAVIOR (Standard Format / Non-MoE)
- if num_experts > 1:
- total_rank = weight_A.shape[0]
- rank_per_expert = total_rank // num_experts
- dim1 = weight_A.shape[1]
- dim2 = weight_B.shape[0]
-
- # STANDARD FORMAT (Qwen3-MoE / GLM4):
- # Base weights are (E, out_dim, in_dim) for F.linear.
- # LoRA weights follow PEFT: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R).
- # We need X @ (E, in_dim, R) @ (E, R, out_dim).
-
- # first_weight: (E, in_dim, R) - from lora_A
- # second_weight: (E, R, out_dim) - from lora_B
- first_weight = weight_A.view(num_experts, rank_per_expert, dim1)
- first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, dim1, R)
-
- # second_weight (B): (E, R, out_dim)
- second_weight = weight_B.view(dim2, num_experts, rank_per_expert)
- second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, dim2)
- else:
- # Non-MoE case: return weights for X @ A.T @ B.T
- first_weight = weight_A.T # (dim1, R)
- second_weight = weight_B.T # (R, dim2)
-
- return first_weight, second_weight, scaling, num_experts
- except Exception:
- return None
-
-
-def _extract_lora_weights(
- param, adapter_name: str = "default", num_experts: int = None, experts_module=None
-) -> Optional[Tuple[torch.Tensor, torch.Tensor, float]]:
- """
- Extract LoRA A and B weights from PEFT ParamWrapper.
-
- This is a compatibility wrapper around _extract_lora_from_wrapper.
- Use _extract_lora_from_wrapper directly for new code.
-
- Returns:
- (first_weight, second_weight, scaling) for (X @ first) @ second
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- # Set num_experts on param if provided, so _extract_lora_from_wrapper can use it
- if num_experts is not None and not hasattr(param, "num_experts"):
- param.num_experts = num_experts
-
- result = _extract_lora_from_wrapper(param, adapter_name, experts_module=experts_module)
- if result is None:
- return None
- # Return first 3 elements (first_weight, second_weight, scaling) without num_experts
- return result[0], result[1], result[2]
-
-
-def _get_base_weight(param):
- """Get base weight from potentially wrapped parameter or module."""
- # This Unsloth Zoo code section is licensed under AGPL3
-
- # Recursively unwrap PEFT layers
- while hasattr(param, "base_layer"):
- param = param.base_layer
-
- if hasattr(param, "get_param"):
- return param.get_param()
-
- # Handle Modules (Linear, etc.)
- if hasattr(param, "weight"):
- return param.weight
-
- return param
-
-
-def _get_lora_wrapper_for_param(experts_module, param_name):
- """
- Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj).
- Uses the explicit key stored in __dict__ if available.
- Does NOT lazily setup wrappers as that requires traversing logic not present here.
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- if hasattr(experts_module, f"{param_name}_lora_wrapper"):
- return getattr(experts_module, f"{param_name}_lora_wrapper")
-
- # Check simple attributes if it's directly wrapped
- if hasattr(experts_module, param_name):
- attr = getattr(experts_module, param_name)
- if hasattr(attr, "lora_A"): # Is a ParamWrapper
- return attr
-
- return None
-
-
-def native_moe_grouped_mm(
- inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
-) -> torch.Tensor:
- """
- Native implementation using grouped_mm with backward fix.
-
- Uses custom autograd function to avoid PyTorch's grouped_mm backward stride bug.
- """
- return _grouped_mm_with_backward_fix(inputs, weight, offsets)
-
-
-def _apply_lora_grouped_mm(
- inputs: torch.Tensor,
- lora_B: torch.Tensor,
- lora_A: torch.Tensor,
- offsets: torch.Tensor,
- scaling: float,
- grouped_mm_func=native_moe_grouped_mm,
-) -> torch.Tensor:
- """
- Apply LoRA using grouped GEMM: result = ((X @ B) @ A) * scaling
-
- Args:
- inputs: (total_tokens, in_dim)
- lora_B: (num_experts, in_dim, rank) - First projection
- lora_A: (num_experts, rank, out_dim) - Second projection
- offsets: Grouped GEMM offsets
- scaling: LoRA scaling factor
- grouped_mm_func: Function to use for grouped GEMM (default: native_moe_grouped_mm)
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- # 1. First Matmul (X @ B)
- # lora_B is (E, in_dim, R)
- # Native needs (E, in_dim, R) -> No Transpose
- lora_intermediate = grouped_mm_func(inputs, lora_B.contiguous(), offsets)
-
- # 2. Second Matmul (result @ A)
- # lora_A is (E, R, out_dim)
- # Native needs (E, R, out_dim) -> No Transpose
- lora_delta = grouped_mm_func(lora_intermediate, lora_A.contiguous(), offsets)
-
- return lora_delta * scaling
-
-
-def _should_use_separated_lora() -> bool:
- """
- Check if separated LoRA approach should be used (default: True).
- Set UNSLOTH_MOE_LORA_MERGED=1 to use merged approach instead.
- """
- return os.environ.get("UNSLOTH_MOE_LORA_MERGED", "0") != "1"
-
-
-# ============================================================================
-# Model-specific Weight Preprocessing Hooks
-# ============================================================================
-# Each model can register its own preprocessing function for weight transposition.
-# This allows the generic backend to work with different model weight layouts.
-
-_WEIGHT_PREPROCESSORS = {}
-
-
-def register_weight_preprocessor(model_type: str, preprocessor_fn):
- """
- Register a weight preprocessor for a specific model type.
-
- Args:
- model_type: Model identifier (e.g., "qwen3_moe", "qwen3_vl_moe")
- preprocessor_fn: Function(weight, proj_type, hidden_dim) -> processed_weight
- proj_type is "gate_up" or "down"
- """
- _WEIGHT_PREPROCESSORS[model_type] = preprocessor_fn
-
-
-def get_weight_preprocessor(model_type: str):
- """Get registered weight preprocessor for model type."""
- return _WEIGHT_PREPROCESSORS.get(model_type)
-
-
-def preprocess_weight(
- weight: torch.Tensor, proj_type: str, hidden_dim: int, model_type=None
-):
- """
- Preprocess weight tensor for grouped_mm compatibility.
-
- Uses model-specific preprocessor if registered, otherwise uses default logic.
-
- Args:
- weight: Weight tensor (E, dim1, dim2) or similar
- proj_type: "gate_up" or "down"
- hidden_dim: Hidden dimension for shape inference
- model_type: Optional model type to use specific preprocessor
-
- Returns:
- Weight tensor in (E, in_dim, out_dim) format for grouped_mm
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- if model_type and model_type in _WEIGHT_PREPROCESSORS:
- return _WEIGHT_PREPROCESSORS[model_type](weight, proj_type, hidden_dim)
-
- # Default preprocessing: check if transposition is needed
- if proj_type == "gate_up":
- # For gate_up, we need (E, hidden_dim, 2*intermediate)
- if weight.shape[1] == hidden_dim:
- return weight
- else:
- return weight.transpose(-2, -1)
- else: # down
- # For down, we need (E, intermediate, hidden_dim)
- if weight.shape[2] == hidden_dim:
- return weight
- else:
- return weight.transpose(-2, -1)
-
-
-# ============================================================================
-# Generic MoE Detection and ParamWrapper Patching
-# ============================================================================
-
-
-def _is_moe_experts_module(module) -> bool:
- """
- Check if module is an MoE experts layer (generic, not model-specific).
-
- Detects modules with stacked expert weights as 3D nn.Parameter:
- - gate_up_proj/down_proj pattern (Qwen3-MoE, Qwen3-VL-MoE, etc.)
- - w1/w2/w3 pattern (older MoE models)
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- import torch.nn as nn
-
- # Check for gate_up_proj pattern
- if hasattr(module, "gate_up_proj"):
- param = module.gate_up_proj
- if isinstance(param, nn.Parameter) and param.ndim == 3:
- return True
-
- # Check for w1/w2 pattern (separate gate/up projections)
- if hasattr(module, "w1") and hasattr(module, "w2"):
- w1 = module.w1
- if isinstance(w1, nn.Parameter) and w1.ndim == 3:
- return True
-
- return False
-
-
-# Aliases for compatibility with gpt_oss.py
-_get_moe_lora_weights = _extract_lora_from_wrapper
-
-
-# Store original ParamWrapper.forward for fallback
-_original_param_wrapper_forward = None
-
-
-def _patched_param_wrapper_forward(
- self, x: torch.Tensor, *args, **kwargs
-) -> torch.Tensor:
- """
- Patched ParamWrapper.forward for MoE separated LoRA.
-
- For MoE expert modules:
- - Bypasses PEFTs _activate_lora parametrization context
- - Stores LoRA data by parameter_name for forward_native_grouped_mm to use
-
- For non-MoE modules:
- - Falls back to original PEFT forward
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- # CRITICAL: Use self.base_layer for forward call (immediate parent)
- # NOT self.get_base_layer() which recursively traverses to deepest layer!
- # The wrapper chain must be preserved: down_proj -> gate_up_proj -> Qwen3MoeExperts
- immediate_base_layer = self.base_layer
-
- # For storing LoRA data, we DO need the actual experts module
- # Use get_base_layer() to find it (recursive traversal is correct here)
- experts_module = self.get_base_layer()
-
- use_separated = _should_use_separated_lora()
- param_name = getattr(self, "parameter_name", None)
-
- # Check if this is an MoE experts module that should use separated LoRA
- if (
- use_separated
- and param_name in ("gate_up_proj", "down_proj")
- and _is_moe_experts_module(experts_module)
- ):
- # MoE experts: bypass PEFT's _activate_lora, use separated computation
-
- # Check adapter state
- if self.disable_adapters:
- if self.merged:
- self.unmerge()
- return immediate_base_layer(x, *args, **kwargs)
-
- if self.merged:
- return immediate_base_layer(x, *args, **kwargs)
-
- # Ensure wrapper.num_experts is set for LoRA weight reshaping
- if not hasattr(self, "num_experts"):
- if hasattr(experts_module, "num_experts"):
- self.num_experts = experts_module.num_experts
- elif hasattr(experts_module, param_name):
- p = getattr(experts_module, param_name)
- if hasattr(p, "shape") and len(p.shape) >= 1:
- self.num_experts = p.shape[0]
-
- # Extract LoRA for this specific parameter
- lora_data = _extract_lora_from_wrapper(self)
-
- if lora_data is not None and param_name:
- # Store LoRA data on the EXPERTS MODULE (not base_layer)
- # e.g., _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj
- lora_attr = f"_unsloth_lora_{param_name}"
- setattr(experts_module, lora_attr, lora_data)
-
- try:
- # Call IMMEDIATE base_layer to preserve wrapper chain
- # (down_proj wrapper calls gate_up_proj wrapper calls Qwen3MoeExperts)
- result = immediate_base_layer(x, *args, **kwargs)
- finally:
- # Clean up
- if param_name:
- lora_attr = f"_unsloth_lora_{param_name}"
- if hasattr(experts_module, lora_attr):
- delattr(experts_module, lora_attr)
-
- return result
-
- # Non-MoE: use original PEFT forward with _activate_lora
- return _original_param_wrapper_forward(self, x, *args, **kwargs)
-
-
-def patch_param_wrapper_for_moe():
- """
- Patch PEFT's ParamWrapper.forward to use separated LoRA for MoE.
-
- This should be called after PEFT is imported.
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- global _original_param_wrapper_forward
-
- try:
- from peft.tuners.lora.layer import ParamWrapper
-
- # Store original forward
- if _original_param_wrapper_forward is None:
- _original_param_wrapper_forward = ParamWrapper.forward
-
- # Patch with our version
- ParamWrapper.forward = _patched_param_wrapper_forward
-
- return True
- except ImportError:
- return False
-
-
-def forward_native_grouped_mm(
- self,
- hidden_states: torch.Tensor,
- top_k_index: torch.Tensor,
- top_k_weights: torch.Tensor,
-) -> torch.Tensor:
- """
- Native Pytorch grouped GEMM MoE forward pass.
- Uses torch._grouped_mm which is significantly faster than loop and works without Triton dependencies.
- Requires torch._grouped_mm support (verified via runtime check).
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- # Runtime safety check - defense in depth
- if not _check_torch_grouped_mm_supported():
- major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
- raise RuntimeError(
- f"torch._grouped_mm is not supported on this device (Compute Capability {major}.{minor}). "
- f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend."
- )
-
- is_2d_input = hidden_states.dim() == 2
- if is_2d_input:
- sequence_length, hidden_dim = hidden_states.shape
- batch_size = 1
- else:
- batch_size, sequence_length, hidden_dim = hidden_states.shape
-
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- # 1. Calculate routing
- flat_top_k = top_k_index.view(-1)
- num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int()
-
- # 2. Sort indices to group tokens by expert
- sorted_indices = torch.argsort(flat_top_k, stable=True)
- token_indices = sorted_indices // top_k_index.shape[-1]
-
- # 3. Permute Input
- # We need to gather inputs. Since we may have expanded top_k, we use token_indices to map back to original input
- permuted_input = hidden_states[token_indices]
-
- # 4. Prepare Grouped MM arguments
- offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-
- # ========================================================================
- # Gate + Up projection with optional separated LoRA (DEFAULT)
- # ========================================================================
- use_separated_lora = _should_use_separated_lora()
- gate_up_lora = None
-
- # Check for injected LoRA data from patched ParamWrapper (preferred path)
- if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
- gate_up_lora = self._unsloth_lora_gate_up_proj[
- :3
- ] # (first_weight, second_weight, scaling)
- # Fallback: check parameter directly (for older wrapping patterns)
- elif (
- use_separated_lora
- and hasattr(self, "gate_up_proj")
- and _has_lora_adapters(self.gate_up_proj)
- ):
- gate_up_lora = _extract_lora_weights(
- self.gate_up_proj, num_experts=self.num_experts, experts_module=self
- )
-
- if hasattr(self, "gate_up_proj"):
- # Get base weights (raw, without LoRA)
- gate_up_base = _get_base_weight(self.gate_up_proj)
-
- # Get model type for preprocessing (if registered)
- model_type = getattr(self, "_unsloth_model_type", None)
-
- # Handle different weight shapes using preprocessor
- # torch._grouped_mm backward requires weights to be contiguous; preprocessing may return a transposed view.
- w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type)
- # Base forward: X @ W
- mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
-
- # Add separated LoRA contribution: + ((X @ first) @ second) * scaling
- # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling)
- if gate_up_lora is not None:
- first_weight, second_weight, scaling = gate_up_lora
-
- # Cast to input dtype (LoRA weights are float32, input may be bfloat16)
- # Ensure contiguous for grouped_mm alignment requirements
- first_weight = first_weight.to(permuted_input.dtype).contiguous()
- second_weight = second_weight.to(permuted_input.dtype).contiguous()
-
- # Step 1: permuted_input @ first_weight
- try:
- lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets)
- lora_out = lora_out.contiguous()
- except RuntimeError as e:
- raise e
-
- # Step 2: result @ second_weight
- # Handle unaligned O dimension or other grouped_mm failures
- try:
- if second_weight.shape[-1] % 8 != 0:
- pad_size = 8 - (second_weight.shape[-1] % 8)
- second_weight_padded = F.pad(
- second_weight, (0, pad_size)
- ).contiguous()
- lora_delta = _grouped_mm_with_backward_fix(
- lora_out, second_weight_padded, offsets
- )
- lora_delta = lora_delta[:, :-pad_size]
- else:
- lora_delta = _grouped_mm_with_backward_fix(
- lora_out, second_weight, offsets
- )
- except RuntimeError:
- # Fallback to manual loop if grouped_mm fails (e.g. stride alignment)
- lora_delta = torch.empty(
- (lora_out.shape[0], second_weight.shape[-1]),
- dtype=lora_out.dtype,
- device=lora_out.device,
- )
- cpu_offsets = offsets.cpu().tolist()
- prev_offset = 0
- for i, end in enumerate(cpu_offsets):
- if prev_offset < end:
- lora_delta[prev_offset:end] = torch.matmul(
- lora_out[prev_offset:end], second_weight[i]
- )
- prev_offset = end
-
- # Add scaled LoRA contribution
- mm1_out = mm1_out + lora_delta * scaling
-
- if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
- num_repeats = num_tokens_per_expert.to(self.gate_up_proj_bias.device)
- bias_expanded = self.gate_up_proj_bias.repeat_interleave(num_repeats, dim=0)
- mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype)
-
- if "GptOssExperts" in self.__class__.__name__:
- gate = mm1_out[..., ::2]
- up = mm1_out[..., 1::2]
- else:
- gate, up = mm1_out.chunk(2, dim=-1)
-
- elif hasattr(self, "w1") and hasattr(self, "w3"):
- # Separate w1/w3 weights (older models)
- w1_base = _get_base_weight(self.w1)
- w3_base = _get_base_weight(self.w3)
-
- w1 = w1_base.transpose(-2, -1)
- w3 = w3_base.transpose(-2, -1)
-
- gate = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
- up = _grouped_mm_with_backward_fix(permuted_input, w3, offsets)
-
- # Add LoRA for w1 and w3 separately if present
- if use_separated_lora:
- if _has_lora_adapters(self.w1):
- w1_lora = _extract_lora_weights(self.w1, experts_module=self)
- if w1_lora is not None:
- lora_A, lora_B, scaling = w1_lora
- lora_A_t = lora_A.transpose(-2, -1)
- lora_A_out = _grouped_mm_with_backward_fix(
- permuted_input, lora_A_t, offsets
- )
- lora_B_t = lora_B.transpose(-2, -1)
- lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
- gate = gate + lora_B_out * scaling
-
- if _has_lora_adapters(self.w3):
- w3_lora = _extract_lora_weights(self.w3, experts_module=self)
- if w3_lora is not None:
- lora_A, lora_B, scaling = w3_lora
- lora_A_t = lora_A.transpose(-2, -1)
- lora_A_out = _grouped_mm_with_backward_fix(
- permuted_input, lora_A_t, offsets
- )
- lora_B_t = lora_B.transpose(-2, -1)
- lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
- up = up + lora_B_out * scaling
- else:
- raise AttributeError("MoE layer must have 'gate_up_proj' or 'w1'/'w3'.")
-
- # Activation
- if "GptOssExperts" in self.__class__.__name__:
- # Custom activation from GptOss
- limit = getattr(self, "limit", 7.0)
- alpha = getattr(self, "alpha", 1.702)
-
- gate = gate.clamp(min=None, max=limit)
- up = up.clamp(min=-limit, max=limit)
- glu = gate * torch.sigmoid(gate * alpha)
- inter = (up + 1.0) * glu
- else:
- inter = F.silu(gate) * up
-
- # ========================================================================
- # Down projection with optional separated LoRA (DEFAULT)
- # ========================================================================
- down_lora = None
-
- # Check for injected LoRA data from patched ParamWrapper (preferred path)
- if getattr(self, "_unsloth_lora_down_proj", None) is not None:
- down_lora = self._unsloth_lora_down_proj[
- :3
- ] # (first_weight, second_weight, scaling)
- # Fallback: check parameter directly (for older wrapping patterns)
- elif (
- use_separated_lora
- and hasattr(self, "down_proj")
- and _has_lora_adapters(self.down_proj)
- ):
- down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts, experts_module=self)
-
- if hasattr(self, "down_proj"):
- # Get base weights
- down_base = _get_base_weight(self.down_proj)
-
- # Get model type for preprocessing (if registered)
- model_type = getattr(self, "_unsloth_model_type", None)
-
- # Handle different weight shapes using preprocessor
- w2 = preprocess_weight(down_base, "down", hidden_dim, model_type)
-
- # Base forward
- mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)
-
- # Add separated LoRA contribution if present
- # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling)
- if down_lora is not None:
- first_weight, second_weight, scaling = down_lora
-
- # Cast to input dtype (LoRA weights are float32, input may be bfloat16)
- first_weight = first_weight.to(inter.dtype).contiguous()
- second_weight = second_weight.to(inter.dtype).contiguous()
-
- # Step 1: inter @ first_weight
- lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets)
- lora_out = lora_out.contiguous()
-
- # Step 2: result @ second_weight
- try:
- lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
- except RuntimeError:
- # Fallback to manual loop
- lora_delta = torch.empty(
- (lora_out.shape[0], second_weight.shape[-1]),
- dtype=lora_out.dtype,
- device=lora_out.device,
- )
- cpu_offsets = offsets.cpu().tolist()
- prev_offset = 0
- for i, end in enumerate(cpu_offsets):
- if prev_offset < end:
- lora_delta[prev_offset:end] = torch.matmul(
- lora_out[prev_offset:end], second_weight[i]
- )
- prev_offset = end
-
- # Add scaled LoRA contribution
- mm2_out = mm2_out + lora_delta * scaling
-
- if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
- bias_expanded = self.down_proj_bias.repeat_interleave(
- num_tokens_per_expert.to(self.down_proj_bias.device), dim=0
- ).to(mm2_out.device)
- mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype)
-
- elif hasattr(self, "w2"):
- w2_base = _get_base_weight(self.w2)
- w2 = w2_base.transpose(-2, -1)
-
- # Base forward
- mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)
-
- # Add LoRA if present
- if use_separated_lora and _has_lora_adapters(self.w2):
- w2_lora = _extract_lora_weights(self.w2, experts_module=self)
- if w2_lora is not None:
- lora_A, lora_B, scaling = w2_lora
- lora_A_t = lora_A.transpose(-2, -1).contiguous()
- lora_A_out = _grouped_mm_with_backward_fix(inter, lora_A_t, offsets)
- lora_B_t = lora_B.transpose(-2, -1).contiguous()
- lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
- mm2_out = mm2_out + lora_B_out * scaling
- else:
- raise AttributeError("MoE layer must have 'down_proj' or 'w2'.")
-
- # 5. Apply Routing Weights and Scatter Add (Reduce)
- flat_weights = top_k_weights.view(-1)
- permuted_weights = flat_weights[sorted_indices]
- mm2_out = mm2_out * permuted_weights.unsqueeze(-1)
-
- final_hidden_states = torch.zeros(
- (batch_size * sequence_length, hidden_dim),
- dtype=hidden_states.dtype,
- device=hidden_states.device,
- )
-
- final_hidden_states.index_add_(0, token_indices, mm2_out.to(hidden_states.dtype))
-
- if is_2d_input:
- return final_hidden_states
-
- return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
-
-
-def forward_triton_grouped_gemm(
- self,
- hidden_states: torch.Tensor,
- top_k_index: torch.Tensor,
- top_k_weights: torch.Tensor,
-) -> torch.Tensor:
- """
- Grouped GEMM MoE forward pass using Triton kernels.
- Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin).
- """
- # This Unsloth Zoo code section is licensed under AGPL3
-
- # Import grouped GEMM interface
- from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
-
- # Import autotune cache
- from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels
-
- # Helper to check TMA support - assumes helper function or just check directly
- # In original: it was a cached closure. Here we can use _supports_tma() directly
-
- # nonlocal _MODEL_DIMS_AND_CONFIGS # We need a way to store this!
- # For now, let's attach it to self if possible, or use a global usage
- # Attaching to self is cleaner: self._unsloth_moe_configs
-
- # Create expert mask and find which experts have tokens
-
- if not hasattr(self, "_unsloth_moe_configs"):
- self._unsloth_moe_configs = None
-
- use_separated_lora = _should_use_separated_lora()
-
-
- # Handle 3D inputs (batch_size, seq_len, hidden_dim)
- is_3d = hidden_states.dim() == 3
- if is_3d:
- batch_size, seq_len, hidden_dim = hidden_states.shape
- hidden_states = hidden_states.view(-1, hidden_dim)
- num_tokens = batch_size * seq_len
- # Also flatten top_k inputs if they are 3D
- if top_k_index.dim() == 3:
- top_k_index = top_k_index.view(-1, top_k_index.shape[-1])
- if top_k_weights.dim() == 3:
- top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1])
- else:
- num_tokens, hidden_dim = hidden_states.shape
-
- top_k = top_k_index.shape[1]
-
- # Cache model dimensions and kernel configs on first call
- if self._unsloth_moe_configs is None:
- intermediate_dim = self.gate_up_proj.shape[1] // 2
-
- # Autotune first GEMM
- gemm1_configs = get_or_autotune_moe_kernels(
- num_experts=self.num_experts,
- hidden_dim=hidden_dim,
- intermediate_dim=intermediate_dim * 2,
- top_k=top_k,
- dtype=hidden_states.dtype,
- )
-
- # Autotune second GEMM
- gemm2_configs = get_or_autotune_moe_kernels(
- num_experts=self.num_experts,
- hidden_dim=intermediate_dim,
- intermediate_dim=hidden_dim, # Output dim for 2nd GEMM is hidden_dim
- top_k=top_k,
- dtype=hidden_states.dtype,
- )
-
- self._unsloth_moe_configs = (intermediate_dim, gemm1_configs, gemm2_configs)
-
- # Clear autotuning memory overhead
- torch.cuda.empty_cache()
-
- # Unpack cached configs
- intermediate_dim, gemm1_configs, gemm2_configs = self._unsloth_moe_configs
-
- # Unpack specific kernel configs
- fwd_config_1, bwd_dX_config_1, bwd_dW_config_1 = gemm1_configs
- fwd_config_2, bwd_dX_config_2, bwd_dW_config_2 = gemm2_configs
-
- # Compute routing indices for grouped GEMM
- token_counts_by_expert, gather_indices = _get_routing_indices(
- top_k_index, self.num_experts
- )
- offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32)
-
- if self.gate_up_proj.shape[-1] == hidden_dim:
- w1 = self.gate_up_proj
- else:
- w1 = self.gate_up_proj.transpose(-2, -1).contiguous()
-
- # First grouped GEMM: gate_up projection
- first_gemm_output = grouped_gemm(
- X=hidden_states,
- W=w1,
- m_sizes=token_counts_by_expert,
- topk=top_k,
- gather_indices=gather_indices,
- permute_x=True,
- permute_y=False,
- autotune=False, # We use cached configs
- kernel_config_fwd=fwd_config_1,
- kernel_config_bwd_dX=bwd_dX_config_1,
- kernel_config_bwd_dW=bwd_dW_config_1,
- is_first_gemm=True,
- )
-
- # Apply SiLU activation and multiply gate with up
- intermediate = _silu_and_mul(first_gemm_output)
-
- # Grouped GEMM 2: down projection
-
- # Grouped GEMM 2: down projection
- # Prepare LoRA data
- down_lora = None
- if getattr(self, "_unsloth_lora_down_proj", None) is not None:
- down_lora = self._unsloth_lora_down_proj[:3]
- elif (
- use_separated_lora
- and hasattr(self, "down_proj")
- and _has_lora_adapters(self.down_proj)
- ):
- down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts)
-
- if self.down_proj.shape[-1] == intermediate.shape[-1]:
- w2 = self.down_proj
- else:
- w2 = self.down_proj.transpose(-2, -1).contiguous()
-
- second_gemm_output = grouped_gemm(
- X=intermediate,
- W=w2,
- m_sizes=token_counts_by_expert,
- topk=top_k,
- gather_indices=gather_indices,
- permute_x=False,
- permute_y=True,
- autotune=False, # We use cached configs
- kernel_config_fwd=fwd_config_2,
- kernel_config_bwd_dX=bwd_dX_config_2,
- kernel_config_bwd_dW=bwd_dW_config_2,
- is_first_gemm=False,
- )
-
- # Add separated LoRA contribution for Down
- if down_lora is not None:
- first_weight, second_weight, scaling = down_lora
-
- # Intermediate is already permuted from step 1.
- # Offsets are same.
-
- first_weight = first_weight.to(intermediate.dtype)
- second_weight = second_weight.to(intermediate.dtype)
-
- lora_delta = _apply_lora_grouped_mm(
- intermediate,
- first_weight,
- second_weight,
- offsets,
- scaling,
- grouped_mm_func=native_moe_grouped_mm
- )
-
- second_gemm_output = second_gemm_output + lora_delta
-
- # Apply routing weights and sum across top_k experts
- # Output shape: (num_tokens, top_k, hidden_dim) -> (num_tokens, hidden_dim)
- # Ensure top_k_weights matches dtype (can be float32 from softmax)
- top_k_weights_casted = top_k_weights.to(hidden_states.dtype)
- final_hidden_states = (
- second_gemm_output.view(num_tokens, top_k, hidden_dim)
- * top_k_weights_casted[..., None]
- )
- final_hidden_states = final_hidden_states.sum(dim=1)
-
- if is_3d:
- final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
-
- return final_hidden_states
-
-
-@torch.compiler.disable
-def forward_native_moe_loop(
- self,
- hidden_states: torch.Tensor,
- top_k_index: torch.Tensor,
- top_k_weights: torch.Tensor,
-) -> torch.Tensor:
- """
- Loop-based MoE forward pass. Loops over experts that have tokens routed to them.
- Explicitly disabled for torch.compile to prevent graph breaks/recompilation issues with dynamic control flow.
- """
- # This Unsloth Zoo code section is licensed under AGPL3
- final_hidden_states = torch.zeros_like(hidden_states)
-
- # Create expert mask and find which experts have tokens
- with torch.no_grad():
- expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
- expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, n_tokens)
- expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
- # Only loop over experts that actually have tokens routed to them
- for expert_idx_t in expert_hit:
- expert_idx = expert_idx_t.item()
-
- # Find which tokens are routed to this expert
- top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
-
- # Gather only the tokens for this expert
- current_state = hidden_states[token_idx]
-
- # Compute gate_up projection for this expert only
- # Handle 'gate_up_proj' or 'w1'/'w3'
- if hasattr(self, "gate_up_proj"):
- gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(
- 2, dim=-1
- )
- else:
- gate = F.linear(current_state, self.w1[expert_idx])
- up = F.linear(current_state, self.w3[expert_idx])
-
- current_hidden_states = self.act_fn(gate) * up
-
- # Compute down projection for this expert only
- if hasattr(self, "down_proj"):
- current_hidden_states = F.linear(
- current_hidden_states, self.down_proj[expert_idx]
- )
- else:
- current_hidden_states = F.linear(current_hidden_states, self.w2[expert_idx])
-
- # Apply routing weights
- current_hidden_states = (
- current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
- )
-
- # Scatter back to final output
- final_hidden_states.index_add_(
- 0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
- )
-
- return final_hidden_states