diff --git a/sans titre_diarization.rttm b/sans titre_diarization.rttm deleted file mode 100644 index e69de29..0000000 diff --git a/sans titre_diarized.txt b/sans titre_diarized.txt deleted file mode 100644 index 7ccb9ae..0000000 --- a/sans titre_diarized.txt +++ /dev/null @@ -1,7386 +0,0 @@ -[00:00:00] VOIX: Tu sais, moi, je suis toujours là, toi, tu es toujours là. -[00:00:03] VOIX: Il faut que petit à petit, on connecte les différents fils d'une manière ou d'une autre. -[00:00:09] VOIX: Et j'ai envie, même si je n'ai pas fait tout le boulot que je voulais faire, je l'ai commencé, même après 14 heures. -[00:00:20] VOIX: Et je voulais juste, par rapport à T2A, par rapport à ce que tu demandes, parce que ce que tu demandes, je vais te montrer les référentiels. -[00:00:28] VOIX: Que nous, on utilise. -[00:00:30] VOIX: Je démarre, je partage. -[00:00:32] VOIX: Vas-y, vas-y. -[00:00:33] VOIX: Et je t'enregistre à l'insu de ton plein gré. -[00:00:37] VOIX: Oui, ça marche. -[00:00:39] VOIX: En fait, tu vois, ça, c'est ce que j'ai commencé tout à l'heure. -[00:00:43] VOIX: Et l'idée, c'était de dire, raisonner bien sûr sur le contrôle T2A, parce que, petit à petit, -[00:00:51] VOIX: les différentes fonctionnalités qu'on va utiliser, elles vont être les mêmes pour la suite. -[00:01:00] VOIX: Donc, c'est bien de réfléchir sur tout ça. -[00:01:02] VOIX: Et en fait, j'ai découpé. -[00:01:05] VOIX: En fait, normalement, c'est avant le contrôle, pendant le contrôle et il y a après le contrôle. -[00:01:11] VOIX: Mais tu vois, on n'en est pas là l'objectif. -[00:01:13] VOIX: Ce n'est pas d'être exhaustif. -[00:01:14] VOIX: C'est juste comprendre le principe. -[00:01:17] VOIX: Et avant le contrôle, tu vois, on commence par l'établissement qui reçoit la notification ARS. -[00:01:23] VOIX: Ça, c'est le document, en fait, que j'ai, que UCR, c'est ça ? Ou URC ? -[00:01:29] VOIX: Non, pas du tout. -[00:01:31] VOIX: L'UCR, il est bien après. -[00:01:32] VOIX: C'est pour ça que c'est important que je t'explique cette logique-là. -[00:01:36] VOIX: Et oui, c'est pour ça que je, en fait, je sais que tu, même si tu me dis, je connais la mécanique, il y a... -[00:01:43] VOIX: Non, mais c'est les noms que je ne connais pas, donc... -[00:01:46] VOIX: C'est une étape. -[00:01:47] VOIX: Et l'UCR, il est... -[00:01:48] VOIX: Il est bien après. -[00:01:50] VOIX: Oui, d'accord, OK. -[00:01:51] VOIX: Il est pendant, il est bien après, en fait. -[00:01:54] VOIX: C'est ça qu'il faut... -[00:01:56] VOIX: Oui, mais c'est... -[00:01:58] VOIX: Là, j'ai très bien compris que c'était après, en fait, mais je ne sais pas comment les appeler. -[00:02:02] VOIX: En fait, moi, j'ai un problème, en fait, avec la sémantique, en fait, on va dire. -[00:02:07] VOIX: C'est... -[00:02:07] VOIX: Mais ce n'est pas grave. -[00:02:08] VOIX: C'est pour ça que je suis en train d'écrire un document, de manière... -[00:02:13] VOIX: Je vais le partager avec Guy aussi, de manière à ce que, petit à petit, tout le monde... -[00:02:18] VOIX: Parce que même Khalid, il ne connaît pas ce parcours, au fait. -[00:02:21] VOIX: C'est un parcours du contrôle. -[00:02:24] VOIX: La notification... -[00:02:24] VOIX: Bon, je vais te montrer l'exemple de l'ARS, mais l'objectif, ce n'est pas de rentrer dans le détail. -[00:02:28] VOIX: C'est juste que l'ARS, elle va écrire à l'établissement pour leur dire, bonjour, vous êtes santé, sa séance du novembre, machin, ce programme, votre établissement est inclus dans ce programme. -[00:02:44] VOIX: Ça veut dire, vous êtes contrôlé. -[00:02:46] VOIX: Et à partir de là, elle va leur dire le champ 1, le champ 2, de la période, etc. -[00:02:54] VOIX: D'accord. -[00:02:55] VOIX: Le champ 1... -[00:02:56] VOIX: D'accord. -[00:02:58] VOIX: Oui. -[00:02:58] VOIX: Tu vois, ce que tu retrouves par la suite, quand on dit champ 1, champ 2, etc., c'est à partir de ce courrier. -[00:03:05] VOIX: Et donc, à partir de là, nous, dans EVA, on a un requêteur dans la phase initiée, d'accord, qui est dans EVA. -[00:03:13] VOIX: J'ai mis 80% parce qu'il n'est pas ici à 100%. -[00:03:18] VOIX: On est à 70-80%. -[00:03:19] VOIX: Il y a des cas qu'on n'arrive pas encore à faire. -[00:03:22] VOIX: On a immédiatement l'automatisation d'un tableau de synthèse, des dossiers potentiels. -[00:03:26] VOIX: C'est ce que j'appelle les dossiers potentiels. -[00:03:28] VOIX: Le nombre, les euros, les types. -[00:03:30] VOIX: Bon, ça, je te montrerai tout ça après. -[00:03:32] VOIX: Ce n'est pas l'objet. -[00:03:33] VOIX: La liste des dossiers. -[00:03:35] VOIX: Et l'avenir, c'est ce qu'il a travaillé Khalid. -[00:03:38] VOIX: Lecture, je ne sais pas, est-ce qu'il a fait l'IA ou d'autres méthodes. -[00:03:42] VOIX: Tu vois, moi, j'ai dit lecture par l'IA du document. -[00:03:44] VOIX: Et requêtage, l'IA, synthèse des documents. -[00:03:47] VOIX: Et ça, c'est testé, approuvé par Khalid et par l'équipe. -[00:03:51] VOIX: C'est-à-dire, à partir de ce document, il n'y a même plus besoin de passer par l'étape initiée. -[00:03:57] VOIX: Il sort la liste des dossiers potentiels qu'on peut exploiter sous un format Excel, etc. -[00:04:04] VOIX: Pour l'instant, tout ça, c'est bien sûr en rappelie. -[00:04:07] VOIX: Tout ce que je raconte ici. -[00:04:09] VOIX: Mais en tout cas, Khalid a travaillé sur ça. -[00:04:11] VOIX: Ça, entre cette étape et cette étape, il peut se passer un mois, comme une semaine, comme... -[00:04:17] VOIX: Bon, ça, c'est juste, il n'y a rien ici à faire. -[00:04:21] VOIX: C'est l'assurance maladie. -[00:04:22] VOIX: Le MRC, c'est médecin responsable du contrôle. -[00:04:25] VOIX: Le MDIM, c'est médecin DIM, qui crée des comptes. -[00:04:28] VOIX: Bon, ça, c'est le document de création, etc. -[00:04:31] VOIX: Après, il y a un fichier de dossier candidat. -[00:04:35] VOIX: On ne reploite pas dans EVA. -[00:04:39] VOIX: Point d'interrogation. -[00:04:41] VOIX: On n'a pas intégré, nous, dans le process habituel. -[00:04:44] VOIX: Mais c'est vrai que l'année dernière, certains établissements nous ont dit qu'ils commençaient à préparer dès les dossiers candidats, -[00:04:50] VOIX: voire même potentiels, voire même là. -[00:04:52] VOIX: Ils commencent à regarder les dossiers, à préparer leurs argumentaires, etc. -[00:04:58] VOIX: Préparer, c'est ça. -[00:04:59] VOIX: C'est regarder le codage, préparer les argumentaires, etc. -[00:05:04] VOIX: Mais à ce moment-là, on n'a pas les fiches OGC, parce que, de toute façon, on n'a pas de numéro OGC. -[00:05:13] VOIX: Tout ça, c'est le RSSDEF qui a les NUM OGC, le téléchargement de RSSDEF de Druid, -[00:05:19] VOIX: déclenchement de la phase préparée et de la fiche OGC. -[00:05:25] VOIX: Et ça, franchement, pour l'instant, je ne vois pas en quoi il y a. -[00:05:32] VOIX: J'ai une question à te poser. -[00:05:33] VOIX: Pourquoi tu mets une étape dossier candidat non exploité aujourd'hui ? -[00:05:37] VOIX: Et après, comment les établissements commencent à préparer des dossiers alors qu'ils n'ont pas les fiches ? -[00:05:43] VOIX: Parce qu'à partir de la deuxième étape-là, il y a des échanges de fichiers entre l'assurance maladie et l'établissement. -[00:05:51] VOIX: Et l'établissement va envoyer à l'assurance maladie le fichier PMSI. -[00:05:57] VOIX: L'assurance maladie va dire, vous allez avoir, pour le champ 1, par exemple, vous avez 200 dossiers candidats, d'accord ? -[00:06:08] VOIX: Et en fait, après, si le nombre de dossiers sur un champ pour un établissement est très élevé, ça dépend cas par cas, -[00:06:16] VOIX: ils vont faire un tirage au sort. -[00:06:18] VOIX: C'est ça qu'ils appellent les dossiers qui vont être réellement contrôlés, avec un fichier qui s'appelle RSSCOND. -[00:06:25] VOIX: Nous, pour l'instant, on ne l'utilise pas, parce qu'on attend que le contrôleur décide des dossiers qui seront réellement contrôlés. -[00:06:35] VOIX: Mais certains établissements qui disent, oui, mais moi, j'ai que 200-300 dossiers pour l'établissement, -[00:06:40] VOIX: je sais que le nombre de dossiers candidats va être le même que le dossier définitif du contrôle, -[00:06:48] VOIX: ils commencent déjà à préparer, mais ils n'ont pas de numéro OGC, parce que le numéro OGC, il est dans ce fichier RSSDEF. -[00:06:56] VOIX: Et nous, l'appli, pour la déclencher vraiment avec les fiches OGC, tout le bazar, on part du RSSDEF. -[00:07:05] VOIX: Avant le RSSDEF, on peut faire des simulations avec un requêteur, mais ça ne déclenche pas la fiche OGC. -[00:07:12] VOIX: Quand on dit fiche OGC, est-ce que ça te parle ? -[00:07:15] VOIX: Oui, oui, oui, oui, ça je l'ai compris. -[00:07:18] VOIX: Voilà. -[00:07:19] VOIX: Ok. -[00:07:20] VOIX: Ça, Jordan, il pourra te montrer ce que c'est la fiche OGC. -[00:07:24] VOIX: C'est la fiche, en fait, dans laquelle on va modifier un codage et qui intègre dans le groupeur de la TIH -[00:07:31] VOIX: et qui va dire, vous allez perdre, gagner tel ou tel euro. -[00:07:36] VOIX: Tu vois, c'est elle qui va historiser les simulations. -[00:07:40] VOIX: Comme l'objet, ce n'est pas de te montrer ça dans le détail, je ne vais pas me connecter. -[00:07:44] VOIX: Sinon, je peux très bien me connecter et te montrer un petit peu tout ça. -[00:07:49] VOIX: Non, on le verra après. -[00:07:51] VOIX: Je veux bien, en fait, que tu continues. -[00:07:54] VOIX: Là, après, en fait, tu passes sur... -[00:07:55] VOIX: Oui, vas-y, pardon. -[00:07:58] VOIX: Non, non, vas-y, pose ta question. -[00:07:59] VOIX: Non, c'est simplement, en fait, dans le dossier de contrôle, tu télécharges. -[00:08:04] VOIX: Ok, et de druide, c'est... -[00:08:07] VOIX: Et de druide, qu'est-ce que c'est, druide ? -[00:08:09] VOIX: Druide, c'est fichier pémicyde. -[00:08:12] VOIX: C'est fichier pémicyde de l'établissement. -[00:08:14] VOIX: Mais c'est déjà pas assez compliqué, vous rajoutez, en fait, des mots. -[00:08:19] VOIX: Oui, mais ça, c'est le... -[00:08:20] VOIX: Malheureusement, c'est... -[00:08:22] VOIX: C'est pour ça que c'est important que je t'explique, que je crasse, -[00:08:26] VOIX: que tu puisses un peu, quand on dit UCR, -[00:08:29] VOIX: tu sais dans quelle étape... -[00:08:31] VOIX: Oui, ça, c'est l'affaire. -[00:08:31] VOIX: Qu'est-ce qu'il reste et quelle étape, etc. -[00:08:34] VOIX: Tu vois, RSS-DEF, nous, on en parle beaucoup de druide. -[00:08:39] VOIX: Dans les démos, on leur dit, vous téléchargez votre RSS-DEF, votre druide, -[00:08:43] VOIX: et vous allez avoir immédiatement tout ce qui s'affiche dans l'étape préparée. -[00:08:49] VOIX: Et à partir de cette étape, -[00:08:52] VOIX: l'établissement peut commencer à regarder sans codage. -[00:08:56] VOIX: Mais pour regarder sans codage, -[00:08:58] VOIX: soit il se connecte à son DPI, -[00:09:00] VOIX: soit en parallèle. -[00:09:02] VOIX: Ça, c'est pas une étape après, -[00:09:03] VOIX: parce qu'elle peut se faire en parallèle de tout ça. -[00:09:06] VOIX: Elle peut se faire après ça, -[00:09:08] VOIX: c'est-à-dire extraction et exploitation des documents justificatifs -[00:09:12] VOIX: depuis le DPI, -[00:09:13] VOIX: c'est depuis le dossier patient. -[00:09:14] VOIX: D'accord ? -[00:09:15] VOIX: Aujourd'hui, nada, on n'en fait rien. -[00:09:20] VOIX: Extraction, hier, dans une démo, -[00:09:22] VOIX: il y a une nana du CHU de... -[00:09:26] VOIX: C'est pas le CHU de Nice, -[00:09:27] VOIX: c'était quel CHU ? -[00:09:28] VOIX: Je ne me rappelle plus. -[00:09:29] VOIX: Qui nous a dit qu'ils ont mis à 5, -[00:09:32] VOIX: un mois -[00:09:35] VOIX: pour extraire -[00:09:37] VOIX: et classer, -[00:09:38] VOIX: renommer, -[00:09:39] VOIX: organiser les fichiers -[00:09:41] VOIX: pour les transmettre à l'équipe de contrôle. -[00:09:43] VOIX: Je ne parle pas d'exploiter -[00:09:45] VOIX: qui est lire. -[00:09:46] VOIX: Je parle juste de sortir les documents, -[00:09:49] VOIX: les organiser, -[00:09:51] VOIX: les renommer, etc. -[00:09:52] VOIX: Parce que soit le service informatique, -[00:09:54] VOIX: il est futé, -[00:09:55] VOIX: et le DPI permet d'extraire -[00:09:57] VOIX: de manière automatique, -[00:09:58] VOIX: comme nous, -[00:09:59] VOIX: on a fait à Bayonne, -[00:10:01] VOIX: mais il fallait que je tape du poing -[00:10:03] VOIX: pour qu'il le fasse. -[00:10:04] VOIX: Soit l'établissement, -[00:10:05] VOIX: l'informatique, -[00:10:06] VOIX: ils disent à nous, -[00:10:07] VOIX: nous, on n'en sait rien. -[00:10:08] VOIX: Et donc, typiquement, -[00:10:09] VOIX: cet établissement, -[00:10:10] VOIX: ils ont mis 5 personnes -[00:10:11] VOIX: pendant un mois, -[00:10:12] VOIX: ce qui je trouve extrêmement aberrant. -[00:10:15] VOIX: Et là, -[00:10:16] VOIX: c'est pour ça que l'avenir, -[00:10:18] VOIX: tu vois, -[00:10:18] VOIX: en fait, -[00:10:19] VOIX: ici, c'est Eva, -[00:10:21] VOIX: et demain, -[00:10:22] VOIX: c'est avenir. -[00:10:24] VOIX: C'est pour ça que j'ai séparé. -[00:10:25] VOIX: Et tu vas voir -[00:10:26] VOIX: que toutes ces fonctionnalités, -[00:10:27] VOIX: je te dirai après -[00:10:28] VOIX: comment on peut les travailler après. -[00:10:32] VOIX: Parce que si tu as un agent -[00:10:34] VOIX: d'IAVision -[00:10:35] VOIX: qui connaît les DPI -[00:10:36] VOIX: et que tu lui dis -[00:10:38] VOIX: OK, je branche, -[00:10:40] VOIX: tu sais, -[00:10:40] VOIX: on est toujours dans la logique -[00:10:41] VOIX: est-ce qu'on arrivera un jour -[00:10:42] VOIX: à brancher -[00:10:45] VOIX: l'IAVision -[00:10:45] VOIX: pour extraire -[00:10:47] VOIX: des documents, -[00:10:48] VOIX: des informations, -[00:10:49] VOIX: etc. -[00:10:50] VOIX: Ben ça, -[00:10:51] VOIX: tu vois, -[00:10:51] VOIX: j'ai mis un point -[00:10:52] VOIX: d'interrogation -[00:10:52] VOIX: parce que c'est pas -[00:10:54] VOIX: d'actualité, -[00:10:55] VOIX: mais en tout cas, -[00:10:56] VOIX: il faut savoir -[00:10:57] VOIX: que ça peut être un point -[00:11:00] VOIX: pour les établissements -[00:11:03] VOIX: parce que exploiter les documents -[00:11:06] VOIX: c'est une chose -[00:11:07] VOIX: mais les extraire -[00:11:08] VOIX: c'est important -[00:11:09] VOIX: dans un contrôle -[00:11:10] VOIX: parce qu'il faut -[00:11:12] VOIX: les transmettre -[00:11:13] VOIX: via Blue File -[00:11:14] VOIX: à l'équipe du contrôle. -[00:11:15] VOIX: Tu sais pourquoi -[00:11:18] VOIX: Christophe, -[00:11:19] VOIX: ah non, -[00:11:20] VOIX: t'étais pas là, -[00:11:21] VOIX: quand Jordan lui a créé -[00:11:23] VOIX: des répertoires vides -[00:11:24] VOIX: au GC1, -[00:11:25] VOIX: numéro de dossier, -[00:11:26] VOIX: machin, -[00:11:27] VOIX: ben justement, -[00:11:28] VOIX: il aide -[00:11:28] VOIX: à classer, -[00:11:30] VOIX: à renommer -[00:11:31] VOIX: ces fichiers -[00:11:32] VOIX: pour qu'il les envoie -[00:11:33] VOIX: au médecin contrôleur. -[00:11:35] VOIX: Et l'exploitation, -[00:11:36] VOIX: c'est ce que tu vas me montrer, -[00:11:37] VOIX: c'est l'objet -[00:11:38] VOIX: de ce qu'on va voir. -[00:11:40] VOIX: Projet IA Dominique -[00:11:41] VOIX: en cours, -[00:11:42] VOIX: la synthèse -[00:11:43] VOIX: du dossier patient, -[00:11:44] VOIX: le codage -[00:11:45] VOIX: PMSI groupage, -[00:11:47] VOIX: argumentaire -[00:11:47] VOIX: de chaque codage, -[00:11:48] VOIX: etc. -[00:11:49] VOIX: Et donc, -[00:11:49] VOIX: la synthèse, -[00:11:50] VOIX: c'est ta question. -[00:11:51] VOIX: Comment je différencie -[00:11:53] VOIX: l'histoire ancienne -[00:11:54] VOIX: de la réalité -[00:11:55] VOIX: de ce qui est pris en charge -[00:11:56] VOIX: pendant le séjour -[00:11:57] VOIX: qui m'intéresse ? -[00:11:59] VOIX: Quel est le motif -[00:12:00] VOIX: de l'hospitalisation ? -[00:12:02] VOIX: Quelles sont -[00:12:03] VOIX: les pathologies -[00:12:04] VOIX: prises en charge ? -[00:12:05] VOIX: Et donc, -[00:12:06] VOIX: c'est hyper important -[00:12:08] VOIX: de dire que, -[00:12:09] VOIX: avant de coder -[00:12:10] VOIX: le PMSI, -[00:12:11] VOIX: c'est ça le cerveau -[00:12:13] VOIX: dont tu dis, -[00:12:14] VOIX: ça manque. -[00:12:15] VOIX: C'est-à-dire, -[00:12:15] VOIX: avant de dire, -[00:12:16] VOIX: je code, -[00:12:17] VOIX: tous ceux qui commencent -[00:12:18] VOIX: par dire, -[00:12:18] VOIX: ah, c'est quoi -[00:12:19] VOIX: comme code, -[00:12:19] VOIX: c'est quoi comme code, -[00:12:20] VOIX: c'est des gens -[00:12:21] VOIX: qui ne vont pas forcément -[00:12:22] VOIX: bien raisonner -[00:12:23] VOIX: sur le dossier. -[00:12:24] VOIX: Ceux qui raisonnent, -[00:12:26] VOIX: c'est ceux qui vont -[00:12:27] VOIX: comprendre le dossier -[00:12:28] VOIX: et l'organiser, -[00:12:31] VOIX: de dire, -[00:12:31] VOIX: ok, -[00:12:31] VOIX: ça c'est l'histoire ancienne, -[00:12:33] VOIX: ça a commencé -[00:12:34] VOIX: telle année, -[00:12:35] VOIX: ça arrive aujourd'hui, -[00:12:37] VOIX: voici le contexte, -[00:12:38] VOIX: l'histoire du patient, -[00:12:39] VOIX: mais aujourd'hui, -[00:12:40] VOIX: il vient en hospitalisation, -[00:12:42] VOIX: pourquoi ? -[00:12:43] VOIX: Et qu'est-ce qui a été -[00:12:44] VOIX: pris en charge ? -[00:12:46] VOIX: Ça, -[00:12:47] VOIX: mine de rien, -[00:12:47] VOIX: c'est la clé du codage. -[00:12:49] VOIX: Si tu n'as pas -[00:12:50] VOIX: cette logique de synthèse, -[00:12:52] VOIX: de compréhension, -[00:12:53] VOIX: en fait, -[00:12:54] VOIX: c'est compréhension, -[00:12:55] VOIX: je vais l'appeler comme ça, -[00:12:57] VOIX: compréhension, -[00:13:00] VOIX: compréhension, -[00:13:03] VOIX: organisation, -[00:13:05] VOIX: aïe, -[00:13:06] VOIX: aïe, -[00:13:06] VOIX: aïe, -[00:13:06] VOIX: aïe, -[00:13:07] VOIX: aïe, -[00:13:09] VOIX: organization, -[00:13:14] VOIX: et synthèse du DPI, -[00:13:16] VOIX: pour moi, -[00:13:17] VOIX: franchement, -[00:13:18] VOIX: on va le voir, -[00:13:19] VOIX: parce que c'est l'exemple -[00:13:21] VOIX: que tu demandes, -[00:13:22] VOIX: mais on va travailler -[00:13:23] VOIX: sur ça. -[00:13:24] VOIX: Pour que le codage -[00:13:25] VOIX: PMC, -[00:13:26] VOIX: après, -[00:13:26] VOIX: c'est facile, -[00:13:27] VOIX: le groupage, -[00:13:28] VOIX: c'est facile, -[00:13:29] VOIX: et si on comprend bien, -[00:13:32] VOIX: on organise bien, -[00:13:32] VOIX: on synthétise bien, -[00:13:33] VOIX: l'argumentaire devient logique. -[00:13:36] VOIX: D'accord ? -[00:13:38] VOIX: Donc, -[00:13:39] VOIX: on va dire, -[00:13:40] VOIX: c'est le sujet du moment. -[00:13:44] VOIX: Et une fois que j'ai fait ça, -[00:13:47] VOIX: l'idéal, -[00:13:48] VOIX: évidemment, -[00:13:49] VOIX: c'est qu'on crée -[00:13:50] VOIX: une préparation IA, -[00:13:52] VOIX: c'est-à-dire, -[00:13:53] VOIX: on dit préparation du codage -[00:13:54] VOIX: des dossiers par IA, -[00:13:55] VOIX: aujourd'hui, -[00:13:55] VOIX: on n'a pas la possibilité -[00:13:56] VOIX: de faire ça, -[00:13:58] VOIX: mais on peut ajouter une phase -[00:14:00] VOIX: qui s'appelle -[00:14:00] VOIX: créer une phase préparée IA, -[00:14:02] VOIX: c'est un agent IA -[00:14:03] VOIX: vision ou autre, -[00:14:04] VOIX: je n'en sais rien, -[00:14:05] VOIX: pour saisir directement -[00:14:06] VOIX: les codes IA -[00:14:07] VOIX: dans la phase préparée IA. -[00:14:09] VOIX: Ce codage IA -[00:14:10] VOIX: peut même concerner -[00:14:11] VOIX: les dossiers potentiels, -[00:14:12] VOIX: les dossiers candidats, -[00:14:13] VOIX: puis les dossiers du RSSDF, -[00:14:14] VOIX: mais bon, -[00:14:14] VOIX: ça, -[00:14:14] VOIX: c'est du détail. -[00:14:17] VOIX: Et ensuite, -[00:14:18] VOIX: la phase qui existe aujourd'hui, -[00:14:20] VOIX: c'est préparer des dossiers -[00:14:21] VOIX: par l'établissement. -[00:14:23] VOIX: C'est ça, -[00:14:24] VOIX: comme j'ai dit ici, -[00:14:25] VOIX: le RSSDF, -[00:14:26] VOIX: il déclenche la phase préparée. -[00:14:28] VOIX: Nous, -[00:14:28] VOIX: ces étapes-là, -[00:14:29] VOIX: aujourd'hui, -[00:14:29] VOIX: elles n'existent pas. -[00:14:30] VOIX: On bascule vers ça. -[00:14:32] VOIX: Et la phase préparée, -[00:14:33] VOIX: aujourd'hui, -[00:14:34] VOIX: c'est l'humain. -[00:14:34] VOIX: C'est l'humain -[00:14:35] VOIX: qui va ouvrir -[00:14:36] VOIX: d'une manière ou d'une autre, -[00:14:37] VOIX: soit directement -[00:14:38] VOIX: son logiciel des pays, -[00:14:40] VOIX: soit ses fiches PDF, -[00:14:43] VOIX: voilà, -[00:14:43] VOIX: et qui va dire, -[00:14:45] VOIX: je décide de tel codage -[00:14:46] VOIX: et qui va le saisir -[00:14:48] VOIX: dans la phase préparée -[00:14:49] VOIX: de Eva. -[00:14:49] VOIX: mais on peut imaginer -[00:14:51] VOIX: une étape POC, -[00:14:53] VOIX: parce que là aussi, -[00:14:54] VOIX: c'est un POC, -[00:14:55] VOIX: exploitation POC, -[00:14:57] VOIX: ça aussi, -[00:14:58] VOIX: il y aura qui vont le lire, -[00:15:00] VOIX: mais si on a l'étape POC -[00:15:01] VOIX: pour valider la phase préparée, -[00:15:03] VOIX: il y a, -[00:15:03] VOIX: ça va rejoindre ça. -[00:15:09] VOIX: puisque l'humain, -[00:15:10] VOIX: et d'ailleurs, -[00:15:11] VOIX: il y a, -[00:15:12] VOIX: tu vois, -[00:15:12] VOIX: Frédéric, -[00:15:13] VOIX: etc., -[00:15:14] VOIX: c'est des gens, -[00:15:15] VOIX: soit on l'intègre -[00:15:16] VOIX: d'emblée -[00:15:17] VOIX: dans Eva, -[00:15:19] VOIX: soit on le fait -[00:15:20] VOIX: à part, -[00:15:21] VOIX: et on fait valider -[00:15:23] VOIX: par les médecins, -[00:15:24] VOIX: par les établissements, -[00:15:25] VOIX: ce que l'IA propose ici, -[00:15:27] VOIX: réfléchir, -[00:15:28] VOIX: ça c'est le premier truc, -[00:15:30] VOIX: valider déjà -[00:15:30] VOIX: ou dévalider -[00:15:31] VOIX: les propositions, -[00:15:33] VOIX: les synthèses, -[00:15:33] VOIX: etc., -[00:15:34] VOIX: et puis réfléchir -[00:15:35] VOIX: à comment articuler -[00:15:36] VOIX: toutes ces nouvelles fonctionnalités -[00:15:38] VOIX: et les intégrer -[00:15:39] VOIX: dans le logiciel Eva, -[00:15:41] VOIX: et les exploiter -[00:15:43] VOIX: pour EvaNov -[00:15:44] VOIX: et pour tout ce qui va suivre, -[00:15:46] VOIX: parce que c'est, -[00:15:47] VOIX: il faut les raisonner -[00:15:50] VOIX: comme des petites fonctionnalités -[00:15:51] VOIX: qui puissent être intégrées -[00:15:56] VOIX: à d'autres logiciels -[00:15:58] VOIX: ou à d'autres projets. -[00:16:04] VOIX: Oui, -[00:16:04] VOIX: je vois l'idée en fait, -[00:16:06] VOIX: Amina, -[00:16:07] VOIX: de ce côté-là, -[00:16:07] VOIX: j'ai bien vu l'idée. -[00:16:11] VOIX: OK. -[00:16:11] VOIX: Est-ce que jusque-là, -[00:16:12] VOIX: déjà, -[00:16:13] VOIX: sur tes étapes, -[00:16:14] VOIX: c'est bon. -[00:16:15] VOIX: C'est bon. -[00:16:16] VOIX: Pour moi, -[00:16:17] VOIX: j'ai commencé. -[00:16:17] VOIX: Oui, -[00:16:18] VOIX: vas-y, -[00:16:18] VOIX: vas-y, -[00:16:18] VOIX: continue. -[00:16:19] VOIX: Non, -[00:16:19] VOIX: non, -[00:16:19] VOIX: vas-y. -[00:16:20] VOIX: L'étape pendant, -[00:16:21] VOIX: je l'ai commencé, -[00:16:24] VOIX: en fait, -[00:16:24] VOIX: c'est les documents -[00:16:26] VOIX: justificatifs du DPI, -[00:16:28] VOIX: avant même de dire, -[00:16:29] VOIX: bonjour Guy, -[00:16:31] VOIX: il vient de rentrer, -[00:16:32] VOIX: avant même de dire, -[00:16:34] VOIX: je les lis, -[00:16:35] VOIX: je les comprends, -[00:16:36] VOIX: je les organise, -[00:16:37] VOIX: etc., -[00:16:38] VOIX: il y a quand même -[00:16:38] VOIX: une petite étape -[00:16:40] VOIX: de dire, -[00:16:41] VOIX: tu vois, -[00:16:42] VOIX: directement dans ton tracker -[00:16:43] VOIX: que tu as dit, -[00:16:44] VOIX: tu vas avoir un truc -[00:16:45] VOIX: qui s'appelle observation, -[00:16:47] VOIX: un autre qui s'appelle -[00:16:48] VOIX: transmission, -[00:16:49] VOIX: un autre qui s'appelle -[00:16:50] VOIX: croc en train d'opératoire, -[00:16:52] VOIX: un autre qui s'appelle -[00:16:52] VOIX: CRH ou pas. -[00:16:54] VOIX: Et donc, -[00:16:55] VOIX: en fait, -[00:16:55] VOIX: il faut les classer, -[00:16:57] VOIX: peut-être les renommer, -[00:16:59] VOIX: avoir, -[00:16:59] VOIX: ça c'est un point d'interrogation, -[00:17:01] VOIX: les organiser, -[00:17:02] VOIX: etc., -[00:17:02] VOIX: pour les transmettre -[00:17:04] VOIX: via Blufa -[00:17:05] VOIX: et la liquide du contrôle, -[00:17:07] VOIX: plus, -[00:17:11] VOIX: identifier, -[00:17:11] VOIX: peut-être, -[00:17:12] VOIX: les documents manquants, -[00:17:14] VOIX: parce que souvent, -[00:17:16] VOIX: ça, -[00:17:17] VOIX: on le fait à la main. -[00:17:18] VOIX: aujourd'hui, -[00:17:19] VOIX: Eva Nada, -[00:17:21] VOIX: à venir, -[00:17:22] VOIX: fonctionnalité, -[00:17:23] VOIX: à réfléchir. -[00:17:24] VOIX: Question, -[00:17:24] VOIX: en fait, -[00:17:25] VOIX: parce que ça fait plusieurs fois -[00:17:26] VOIX: que je le vois, -[00:17:27] VOIX: Blufaile, -[00:17:28] VOIX: c'est quoi, -[00:17:29] VOIX: en fait, -[00:17:29] VOIX: exactement ? -[00:17:30] VOIX: Blufaile, -[00:17:31] VOIX: c'est juste un... -[00:17:32] VOIX: C'est un système -[00:17:34] VOIX: de transmission. -[00:17:35] VOIX: C'est le dépôt, -[00:17:35] VOIX: c'est le... -[00:17:36] VOIX: c'est pas... -[00:17:37] VOIX: c'est comme un S -[00:17:37] VOIX: ou comme Dropbox. -[00:17:39] VOIX: D'accord. -[00:17:40] VOIX: Sécurisé de l'assurance maladie. -[00:17:42] VOIX: C'est-à-dire, -[00:17:43] VOIX: c'est eux qui gèrent ça, -[00:17:44] VOIX: ils ouvrent, -[00:17:45] VOIX: ils ouvrent à l'établissement, -[00:17:46] VOIX: ou médecins d'île -[00:17:47] VOIX: de l'établissement, -[00:17:47] VOIX: un compte. -[00:17:48] VOIX: C'est ce truc-là, -[00:17:50] VOIX: je vais l'ouvrir, -[00:17:50] VOIX: ça va mieux te parler. -[00:17:53] VOIX: Reception, -[00:17:55] VOIX: tu vois, -[00:17:55] VOIX: c'est l'assurance maladie, -[00:17:56] VOIX: d'un mail, -[00:17:57] VOIX: Blufaile, -[00:17:57] VOIX: création du compte, -[00:17:58] VOIX: Blufaile, -[00:17:58] VOIX: etc. -[00:17:59] VOIX: Et donc, -[00:18:00] VOIX: c'est eux qui t'ouvrent, -[00:18:01] VOIX: tu valides par ton mail, -[00:18:03] VOIX: etc. -[00:18:03] VOIX: Et derrière, -[00:18:04] VOIX: c'est l'outil de communication -[00:18:06] VOIX: sécurisé, -[00:18:07] VOIX: parce que tu vas déposer -[00:18:08] VOIX: des dossiers patients, -[00:18:10] VOIX: nominatifs, -[00:18:11] VOIX: tu vas échanger -[00:18:12] VOIX: les fichiers du contrôle, -[00:18:13] VOIX: etc. -[00:18:14] VOIX: C'est comme nous, -[00:18:15] VOIX: on fait avec S, -[00:18:16] VOIX: avec notre répertoire, -[00:18:20] VOIX: tu vois, -[00:18:20] VOIX: ShareFile, -[00:18:21] VOIX: c'est un peu le ShareFile -[00:18:23] VOIX: de l'assurance maladie. -[00:18:26] VOIX: Ok. -[00:18:29] VOIX: Après, -[00:18:30] VOIX: bien sûr, -[00:18:30] VOIX: il y a d'autres étapes -[00:18:31] VOIX: que je travaillerai, -[00:18:32] VOIX: mais tu vois, -[00:18:33] VOIX: en fait, -[00:18:37] VOIX: toi, -[00:18:38] VOIX: ce que tu m'as demandé, -[00:18:39] VOIX: c'est ça, -[00:18:42] VOIX: on va se concentrer sur ça, -[00:18:44] VOIX: tu vois. -[00:18:44] VOIX: parce qu'il y a plusieurs choses -[00:18:46] VOIX: et je vais essayer -[00:18:47] VOIX: de les lister petit à petit, -[00:18:49] VOIX: les revoir ensemble, -[00:18:51] VOIX: parce que c'est important -[00:18:52] VOIX: qu'on soit tous -[00:18:52] VOIX: sur la même longueur d'onde, -[00:18:54] VOIX: parce que peut-être -[00:18:55] VOIX: que moi aussi, -[00:18:56] VOIX: j'ai oublié des choses, -[00:18:57] VOIX: etc. -[00:18:57] VOIX: pour que l'on voit -[00:19:00] VOIX: de quoi on parle -[00:19:01] VOIX: et quand je t'ai dit -[00:19:02] VOIX: la lecture -[00:19:02] VOIX: et les argumentaires, -[00:19:04] VOIX: les argumentaires -[00:19:05] VOIX: de chaque codage, -[00:19:06] VOIX: tu vas voir -[00:19:07] VOIX: que ça va dépendre -[00:19:08] VOIX: des étapes -[00:19:09] VOIX: parce que ton argumentaire -[00:19:10] VOIX: ici -[00:19:12] VOIX: de la phase avant, -[00:19:14] VOIX: il va être différent -[00:19:16] VOIX: de ton argumentaire -[00:19:17] VOIX: de la phase après. -[00:19:19] VOIX: Ben oui, -[00:19:20] VOIX: ça, -[00:19:21] VOIX: là, -[00:19:21] VOIX: il y a une logique. -[00:19:23] VOIX: Là, -[00:19:24] VOIX: pour moi, -[00:19:24] VOIX: en fait, -[00:19:24] VOIX: il y a une logique. -[00:19:26] VOIX: OK. -[00:19:27] VOIX: Bien sûr, -[00:19:28] VOIX: parce que ici, -[00:19:29] VOIX: l'établissement, -[00:19:29] VOIX: il est avec lui-même, -[00:19:31] VOIX: si tu veux. -[00:19:32] VOIX: Il prépare, -[00:19:33] VOIX: mais il n'a aucune idée -[00:19:34] VOIX: de comment -[00:19:35] VOIX: le médecin contrôleur -[00:19:36] VOIX: va réagir, -[00:19:37] VOIX: en fait. -[00:19:46] VOIX: quelqu'un qui va lui valider -[00:19:47] VOIX: ou pas son codage. -[00:19:50] VOIX: Mais après, -[00:19:51] VOIX: le retour du médecin contrôleur -[00:19:53] VOIX: ou après le retour -[00:19:54] VOIX: de l'UCR, -[00:19:55] VOIX: etc., -[00:19:55] VOIX: tu argumentes -[00:19:57] VOIX: comme un avocat. -[00:19:58] VOIX: Tu contre-argumentes, -[00:19:59] VOIX: en fait. -[00:20:00] VOIX: C'est-à-dire, -[00:20:01] VOIX: tu utilises -[00:20:02] VOIX: les refus -[00:20:03] VOIX: ou les arguments -[00:20:04] VOIX: du médecin contrôleur -[00:20:05] VOIX: pour les contre-balancer. -[00:20:08] VOIX: Et à l'étape avant, -[00:20:11] VOIX: tu ne les as pas, -[00:20:12] VOIX: les arguments -[00:20:12] VOIX: du médecin contrôleur. -[00:20:13] VOIX: Tu peux les imaginer. -[00:20:15] VOIX: Et c'est pour ça -[00:20:15] VOIX: que le travail -[00:20:17] VOIX: sur l'après, -[00:20:18] VOIX: ça permet aussi -[00:20:18] VOIX: d'imaginer -[00:20:19] VOIX: les contre-argumentaires -[00:20:21] VOIX: que va utiliser -[00:20:22] VOIX: le médecin contrôleur -[00:20:24] VOIX: et l'UCR. -[00:20:30] VOIX: Ça te va un peu -[00:20:32] VOIX: la manière -[00:20:32] VOIX: pour commencer à... -[00:20:34] VOIX: Mais le document -[00:20:35] VOIX: que tu as fait, -[00:20:36] VOIX: en fait, -[00:20:37] VOIX: quand tu penses -[00:20:38] VOIX: l'avoir fini, -[00:20:38] VOIX: c'est bien -[00:20:39] VOIX: si tu me le transmets. -[00:20:41] VOIX: Bien sûr, -[00:20:42] VOIX: c'est fait pour ça. -[00:20:42] VOIX: Voilà, -[00:20:43] VOIX: parce qu'en fait, -[00:20:43] VOIX: moi, -[00:20:43] VOIX: ça me guide -[00:20:45] VOIX: et puis avec Ralid, -[00:20:46] VOIX: comme ça, -[00:20:46] VOIX: on pourra travailler dessus -[00:20:47] VOIX: parce que si -[00:20:47] VOIX: ce n'est pas lui non plus, -[00:20:49] VOIX: voilà, -[00:20:50] VOIX: ça me guide. -[00:20:50] VOIX: Bon, -[00:20:51] VOIX: ce qu'il y a de bien, -[00:20:52] VOIX: en fait, -[00:20:52] VOIX: c'est que la mécanique -[00:20:53] VOIX: en fait que tu m'as décrite, -[00:20:55] VOIX: alors en dehors, -[00:20:55] VOIX: en fait, -[00:20:56] VOIX: de tout acronyme, -[00:20:57] VOIX: de choses comme ça, -[00:20:57] VOIX: en fait, -[00:20:58] VOIX: je l'avais... -[00:20:58] VOIX: Voilà, -[00:21:00] VOIX: je l'avais bien saisi. -[00:21:02] VOIX: Je l'ai bien saisi -[00:21:04] VOIX: et là, -[00:21:04] VOIX: maintenant, -[00:21:05] VOIX: en fait, -[00:21:05] VOIX: est-ce que j'ai le droit -[00:21:05] VOIX: de te poser des questions ? -[00:21:07] VOIX: Bien sûr, -[00:21:08] VOIX: c'est là, -[00:21:09] VOIX: parce que, -[00:21:09] VOIX: évidemment, -[00:21:10] VOIX: mais c'était aussi -[00:21:11] VOIX: pour te montrer -[00:21:12] VOIX: que moi, -[00:21:12] VOIX: j'ai compris aussi -[00:21:13] VOIX: sur quoi tu travailles, -[00:21:15] VOIX: comment tu raisons, -[00:21:16] VOIX: etc. -[00:21:17] VOIX: C'est pour ça -[00:21:18] VOIX: que l'autre jour, -[00:21:18] VOIX: j'ai voulu te montrer -[00:21:19] VOIX: un peu la mécanique. -[00:21:22] VOIX: Je ne cherchais pas, -[00:21:23] VOIX: en fait, -[00:21:23] VOIX: que le code, -[00:21:24] VOIX: en fait, -[00:21:24] VOIX: parce que je n'étais pas là, -[00:21:26] VOIX: c'était la mécanique. -[00:21:27] VOIX: Voilà. -[00:21:28] VOIX: Une fois, -[00:21:29] VOIX: en fait, -[00:21:29] VOIX: que la mécanique, -[00:21:30] VOIX: en fait, -[00:21:30] VOIX: elle est bonne. -[00:21:30] VOIX: Après, -[00:21:31] VOIX: effectivement, -[00:21:31] VOIX: en fait, -[00:21:31] VOIX: il y a tout le travail -[00:21:32] VOIX: que je suis en train de faire. -[00:21:34] VOIX: Je ne l'ai pas encore fini, -[00:21:35] VOIX: donc on verra ça par la suite, -[00:21:37] VOIX: mais j'ai des questions -[00:21:38] VOIX: justement pour le finir -[00:21:39] VOIX: parce qu'en fait, -[00:21:40] VOIX: j'ai besoin, -[00:21:41] VOIX: en fait, -[00:21:41] VOIX: d'un parti métier. -[00:21:42] VOIX: C'est ce que je disais -[00:21:43] VOIX: tout à l'heure. -[00:21:43] VOIX: Alors, -[00:21:44] VOIX: j'ai un raisonnement -[00:21:45] VOIX: qui est basé uniquement, -[00:21:47] VOIX: en fait, -[00:21:47] VOIX: sur la documentation, -[00:21:48] VOIX: en fait, -[00:21:48] VOIX: que j'ai. -[00:22:00] VOIX: Oui. -[00:22:01] VOIX: Et moi, -[00:22:01] VOIX: par exemple, -[00:22:02] VOIX: si on prend un exemple -[00:22:03] VOIX: de dossier, -[00:22:06] VOIX: on pourra le lire ensemble. -[00:22:07] VOIX: Je ne sais pas comment -[00:22:08] VOIX: tu veux procéder -[00:22:09] VOIX: et tu montrerai les règles. -[00:22:11] VOIX: Alors, -[00:22:12] VOIX: voyons, -[00:22:12] VOIX: comment on peut faire ? -[00:22:14] VOIX: Toi, -[00:22:14] VOIX: tu en as un sous la main -[00:22:16] VOIX: ou alors -[00:22:16] VOIX: je me fournis un, -[00:22:17] VOIX: moi ? -[00:22:19] VOIX: Fournis un. -[00:22:19] VOIX: Fournis un -[00:22:21] VOIX: que tu trouves -[00:22:22] VOIX: au hasard, -[00:22:23] VOIX: en fait, -[00:22:24] VOIX: parce qu'on n'est pas obligé -[00:22:24] VOIX: de travailler -[00:22:25] VOIX: sur des dossiers complexes. -[00:22:27] VOIX: Ce qui soit complexe -[00:22:28] VOIX: ou simple, -[00:22:28] VOIX: l'essentiel, -[00:22:29] VOIX: c'est que -[00:22:31] VOIX: je puisse te montrer -[00:22:32] VOIX: avec le guide méthodologique -[00:22:33] VOIX: parce que je ne sais pas -[00:22:34] VOIX: si tu l'as exploité. -[00:22:35] VOIX: Oui, -[00:22:36] VOIX: je l'ai exploité. -[00:22:37] VOIX: En fait, -[00:22:37] VOIX: je l'ai intégré -[00:22:38] VOIX: dans le système -[00:22:39] VOIX: mais ce qui me manque, -[00:22:40] VOIX: en fait, -[00:22:40] VOIX: c'est le guide méthodologique, -[00:22:43] VOIX: il est très bien, -[00:22:44] VOIX: certainement. -[00:22:44] VOIX: Sauf que le problème, -[00:22:45] VOIX: c'est que si je n'ai pas -[00:22:46] VOIX: en fait ta lecture à toi, -[00:22:48] VOIX: la lecture d'un professionnel, -[00:22:50] VOIX: d'un expert, -[00:22:51] VOIX: ça ne marche pas, -[00:22:52] VOIX: en fait, -[00:22:52] VOIX: en réalité. -[00:22:53] VOIX: Je me suis rendu compte -[00:22:53] VOIX: d'une chose, -[00:22:54] VOIX: c'est que si je suis -[00:22:55] VOIX: pile-poil l'air, -[00:22:56] VOIX: ça ne marche pas. -[00:22:57] VOIX: Alors, -[00:22:57] VOIX: attends, -[00:22:57] VOIX: j'essaye de te partager -[00:22:58] VOIX: mon écran. -[00:22:58] VOIX: Attends, -[00:22:59] VOIX: cherche un dossier. -[00:23:01] VOIX: Attends, -[00:23:01] VOIX: attends, -[00:23:02] VOIX: est-ce que je ne peux pas -[00:23:03] VOIX: en fait partager mon écran ? -[00:23:05] VOIX: C'est en haut à droite. -[00:23:07] VOIX: Non, -[00:23:07] VOIX: mais je l'ai trouvé -[00:23:08] VOIX: mais quand je clique dessus, -[00:23:10] VOIX: en fait, -[00:23:10] VOIX: je ne peux pas. -[00:23:11] VOIX: Est-ce que toi, -[00:23:12] VOIX: tu n'as pas un truc -[00:23:14] VOIX: en fait qui se... -[00:23:16] VOIX: Ou alors, -[00:23:16] VOIX: c'est mon ordinateur -[00:23:17] VOIX: en fait qui a un problème. -[00:23:18] VOIX: Ah, -[00:23:19] VOIX: putain, -[00:23:19] VOIX: ce n'est pas vrai. -[00:23:21] VOIX: Et puis, -[00:23:22] VOIX: en plus, -[00:23:22] VOIX: je veux le faire -[00:23:22] VOIX: à partir de mon truc. -[00:23:24] VOIX: Je vais fermer -[00:23:25] VOIX: en fait la fenêtre, -[00:23:26] VOIX: enfin, -[00:23:27] VOIX: pas celle d'en bas, -[00:23:28] VOIX: et je reviens de suite. -[00:23:30] VOIX: Ok, -[00:23:31] VOIX: il n'y a pas de sens. -[00:23:31] VOIX: Je reviens de suite. -[00:23:33] VOIX: Allez, -[00:23:34] VOIX: de suite. -[00:23:35] VOIX: Ah oui, -[00:23:35] VOIX: mais il a planté mon truc. -[00:23:38] VOIX: Ah ben, -[00:23:39] VOIX: c'est ballot -[00:23:39] VOIX: pour un informaticien. -[00:23:45] VOIX: Ah. -[00:24:22] VOIX: me revoilà. -[00:24:23] VOIX: Alors, -[00:24:23] VOIX: alors voyons si ça marche. -[00:24:24] VOIX: Oui, -[00:24:24] VOIX: ça marche. -[00:24:25] VOIX: Allez, -[00:24:26] VOIX: hop, -[00:24:26] VOIX: je te partage tout mon écran. -[00:24:30] VOIX: Voilà. -[00:24:31] VOIX: Voilà, -[00:24:31] VOIX: normalement. -[00:24:32] VOIX: Nickel. -[00:24:33] VOIX: Voilà, -[00:24:33] VOIX: là, -[00:24:33] VOIX: maintenant, -[00:24:33] VOIX: tu me vois. -[00:24:34] VOIX: Et là, -[00:24:34] VOIX: on va prendre un dossier. -[00:24:35] VOIX: Alors, -[00:24:36] VOIX: je prends un dossier au pif. -[00:24:37] VOIX: D'accord ? -[00:24:38] VOIX: D'ailleurs, -[00:24:39] VOIX: en passant, -[00:24:40] VOIX: en fait, -[00:24:40] VOIX: je te montre. -[00:24:41] VOIX: Là, -[00:24:41] VOIX: moi, -[00:24:41] VOIX: ça, -[00:24:41] VOIX: c'est le dossier complet. -[00:24:42] VOIX: Donc, -[00:24:43] VOIX: en fait, -[00:24:43] VOIX: j'ai le CRH -[00:24:45] VOIX: et le tracker. -[00:24:46] VOIX: D'accord ? -[00:24:47] VOIX: Et moi, -[00:24:47] VOIX: ce que je fais... -[00:24:48] VOIX: et parfois, -[00:24:49] VOIX: pardon, -[00:24:50] VOIX: parfois, -[00:24:50] VOIX: dans le tracker, -[00:24:52] VOIX: tu peux avoir le CRH intégré -[00:24:54] VOIX: dans le truc complet. -[00:24:56] VOIX: J'ai remarqué... -[00:24:57] VOIX: Là, -[00:24:58] VOIX: ça s'appelle CRH à part -[00:24:59] VOIX: parce qu'on l'a ajouté -[00:25:00] VOIX: a posteriori. -[00:25:01] VOIX: Il n'a pas été extrait -[00:25:04] VOIX: initialement. -[00:25:05] VOIX: Mais ça, -[00:25:05] VOIX: à la rigueur, -[00:25:06] VOIX: c'est du détail -[00:25:06] VOIX: parce qu'on n'aura pas -[00:25:07] VOIX: cette méthode partout -[00:25:08] VOIX: dans les autres établissements. -[00:25:10] VOIX: Ce n'est pas du tout ça -[00:25:11] VOIX: l'essentiel. -[00:25:12] VOIX: D'accord. -[00:25:13] VOIX: Bon, -[00:25:13] VOIX: je fais, -[00:25:13] VOIX: en fait, -[00:25:14] VOIX: une petite passe -[00:25:14] VOIX: pour que tu visualises. -[00:25:16] VOIX: Ici, -[00:25:17] VOIX: en fait, -[00:25:17] VOIX: tu vois, -[00:25:18] VOIX: le dossier, -[00:25:19] VOIX: c'est ça, -[00:25:20] VOIX: c'est d'autres documents -[00:25:21] VOIX: et l'IA, -[00:25:22] VOIX: en fait, -[00:25:23] VOIX: ne fonctionne que sur ça. -[00:25:25] VOIX: Donc, -[00:25:25] VOIX: qui sont des documents -[00:25:26] VOIX: qui sont anonymisés. -[00:25:28] VOIX: Et je te montre, -[00:25:28] VOIX: tu vois, -[00:25:29] VOIX: le PDF. -[00:25:30] VOIX: Donc, -[00:25:30] VOIX: en fait, -[00:25:31] VOIX: c'est ça. -[00:25:32] VOIX: Tu vois, -[00:25:33] VOIX: en fait, -[00:25:33] VOIX: tous les noms, -[00:25:34] VOIX: les prénoms, -[00:25:35] VOIX: les dates, -[00:25:35] VOIX: certaines dates, -[00:25:36] VOIX: certaines choses, -[00:25:37] VOIX: en fait, -[00:25:37] VOIX: sont anonymisés. -[00:25:38] VOIX: Ça, -[00:25:38] VOIX: il faut que j'aie un retour -[00:25:39] VOIX: de la part de Jordan -[00:25:40] VOIX: parce qu'en fait, -[00:25:40] VOIX: il avait des règles -[00:25:41] VOIX: à me passer. -[00:25:42] VOIX: Et donc, -[00:25:43] VOIX: si tu veux, -[00:25:43] VOIX: tout le système... -[00:25:45] VOIX: Oui. -[00:25:46] VOIX: Ce qu'il faudra faire -[00:25:47] VOIX: si tu fais ça, -[00:25:50] VOIX: c'est insérer -[00:25:53] VOIX: le numéro OGC. -[00:25:56] VOIX: D'accord. -[00:25:59] VOIX: D'accord. -[00:26:00] VOIX: Parce que si jamais, -[00:26:02] VOIX: typiquement, -[00:26:04] VOIX: tu vois, -[00:26:05] VOIX: moi, -[00:26:05] VOIX: si je veux envoyer -[00:26:06] VOIX: les documents -[00:26:07] VOIX: à l'ARS -[00:26:08] VOIX: ou à la saisie -[00:26:09] VOIX: d'un atelier, -[00:26:10] VOIX: etc., -[00:26:11] VOIX: c'est vrai que -[00:26:12] VOIX: stabiloter, -[00:26:13] VOIX: on le faisait manuellement -[00:26:14] VOIX: il y a un moment. -[00:26:15] VOIX: après, -[00:26:16] VOIX: on ajoutait -[00:26:17] VOIX: OGC temps -[00:26:18] VOIX: pour que ça soit -[00:26:19] VOIX: identifié... -[00:26:20] VOIX: Pour moi, -[00:26:20] VOIX: en fait, -[00:26:20] VOIX: ce n'est pas très compliqué -[00:26:21] VOIX: parce que j'ai le numéro, -[00:26:23] VOIX: en fait, -[00:26:23] VOIX: les documents sont numérotés -[00:26:24] VOIX: puisque ça, -[00:26:25] VOIX: c'est le dossier, -[00:26:26] VOIX: en fait, -[00:26:26] VOIX: OGC, -[00:26:27] VOIX: en fait... -[00:26:27] VOIX: Non, -[00:26:27] VOIX: c'est l'OGC 1. -[00:26:28] VOIX: Tu sais, -[00:26:29] VOIX: dans 1, -[00:26:30] VOIX: tu as deux infos. -[00:26:31] VOIX: Tu as l'OGC, -[00:26:33] VOIX: tu as le numéro dossier. -[00:26:35] VOIX: D'accord. -[00:26:36] VOIX: L'OGC, -[00:26:36] VOIX: c'est le chiffre -[00:26:38] VOIX: au départ. -[00:26:39] VOIX: 1, -[00:26:39] VOIX: 2, -[00:26:39] VOIX: 3, -[00:26:39] VOIX: 4, -[00:26:40] VOIX: 5, -[00:26:40] VOIX: 6, -[00:26:40] VOIX: 7, -[00:26:41] VOIX: 7, -[00:26:41] VOIX: 7, -[00:26:42] VOIX: 7, -[00:26:42] VOIX: 7, -[00:27:04] VOIX: D'accord. -[00:27:05] VOIX: Ça, -[00:27:05] VOIX: c'est fait, -[00:27:06] VOIX: en fait. -[00:27:06] VOIX: C'est... -[00:27:09] VOIX: Le classement, -[00:27:10] VOIX: numérotation, -[00:27:11] VOIX: ce genre de choses-là, -[00:27:12] VOIX: ça va. -[00:27:13] VOIX: Donc, -[00:27:14] VOIX: ah oui, -[00:27:15] VOIX: il fallait que j'ouvre, -[00:27:16] VOIX: en fait, -[00:27:16] VOIX: un document. -[00:27:16] VOIX: Alors, -[00:27:16] VOIX: je vais ouvrir, -[00:27:17] VOIX: en fait, -[00:27:17] VOIX: le document propre. -[00:27:19] VOIX: D'accord? -[00:27:20] VOIX: Ouvre le CRH -[00:27:21] VOIX: parce que parfois, -[00:27:21] VOIX: les CRH, -[00:27:22] VOIX: quand ils sont bien faits, -[00:27:24] VOIX: ils peuvent résumer -[00:27:26] VOIX: la situation. -[00:27:27] VOIX: Donc là, -[00:27:28] VOIX: on est sur le CRH. -[00:27:30] VOIX: Attends, -[00:27:31] VOIX: j'essaie, -[00:27:31] VOIX: en fait, -[00:27:32] VOIX: de lire le truc. -[00:27:33] VOIX: C'est le 2304, -[00:27:38] VOIX: 2753.pdf. -[00:27:39] VOIX: Je le dis -[00:27:39] VOIX: parce que -[00:27:39] VOIX: comment j'enregistre, -[00:27:40] VOIX: en fait, -[00:27:40] VOIX: comme ça, -[00:27:41] VOIX: je vais retrouver le dossier. -[00:27:43] VOIX: D'accord. -[00:27:44] VOIX: OK. -[00:27:44] VOIX: Je comprends. -[00:27:45] VOIX: Donc, -[00:27:46] VOIX: voilà. -[00:27:46] VOIX: Donc là, -[00:27:47] VOIX: par quoi, -[00:27:48] VOIX: en fait, -[00:27:48] VOIX: tu commences... -[00:27:50] VOIX: Pardon, -[00:27:51] VOIX: Dominique, -[00:27:51] VOIX: il y a Jordan qui m'appelle -[00:27:52] VOIX: parce que... -[00:27:53] VOIX: Deux minutes, -[00:27:54] VOIX: s'il te plaît. -[00:27:54] VOIX: Vas-y, -[00:27:55] VOIX: vas-y, -[00:27:55] VOIX: vas-y. -[00:27:56] VOIX: Oui, -[00:27:57] VOIX: Jordan. -[00:28:01] VOIX: Je vais dire... -[00:28:02] VOIX: Voilà. -[00:28:03] VOIX: Qu'est-ce que tu regardes -[00:28:05] VOIX: en premier ? -[00:28:06] VOIX: En premier, -[00:28:07] VOIX: je vais regarder -[00:28:08] VOIX: le nom du patient -[00:28:10] VOIX: ou de la patiente -[00:28:11] VOIX: pour vérifier que... -[00:28:14] VOIX: Ça, -[00:28:15] VOIX: je te parle -[00:28:15] VOIX: du primo-codage -[00:28:16] VOIX: habituel. -[00:28:17] VOIX: D'accord ? -[00:28:18] VOIX: Je regarde -[00:28:18] VOIX: le numéro de dossier, -[00:28:19] VOIX: ce qu'on appelle -[00:28:20] VOIX: l'identité de vigilance. -[00:28:21] VOIX: C'est-à-dire, -[00:28:22] VOIX: on s'assure -[00:28:22] VOIX: qu'on est en train -[00:28:24] VOIX: de coder -[00:28:24] VOIX: parce que parfois, -[00:28:26] VOIX: dans des dossiers, -[00:28:27] VOIX: on a des documents -[00:28:28] VOIX: qui sont des patients -[00:28:30] VOIX: erronés. -[00:28:31] VOIX: Donc, -[00:28:32] VOIX: on regarde les noms, -[00:28:34] VOIX: on regarde les dates -[00:28:34] VOIX: de naissance, -[00:28:35] VOIX: on regarde les dates -[00:28:36] VOIX: d'hospitalisation -[00:28:37] VOIX: pour vérifier -[00:28:39] VOIX: qu'on est en cohérence -[00:28:40] VOIX: entre les différents documents, -[00:28:41] VOIX: en fait. -[00:28:41] VOIX: Donc, -[00:28:42] VOIX: il y a un contrôle -[00:28:44] VOIX: un peu -[00:28:45] VOIX: de s'assurer -[00:28:45] VOIX: qu'il n'y a pas -[00:28:47] VOIX: d'erreur -[00:28:47] VOIX: sur le fichier. -[00:28:51] VOIX: De suite, -[00:28:53] VOIX: moi, -[00:28:53] VOIX: en tout cas, -[00:28:53] VOIX: je vais dire -[00:28:54] VOIX: OK, -[00:28:55] VOIX: dans le service -[00:28:56] VOIX: du 25 février -[00:28:57] VOIX: ou 3 mars, -[00:28:59] VOIX: donc, -[00:28:59] VOIX: je sais -[00:28:59] VOIX: qu'il est resté -[00:29:00] VOIX: 7 jours -[00:29:01] VOIX: et donc, -[00:29:02] VOIX: je sais -[00:29:02] VOIX: que mon enjeu, -[00:29:04] VOIX: il va être -[00:29:05] VOIX: autour du diagnostic -[00:29:06] VOIX: principal, -[00:29:07] VOIX: du diagnostic associé -[00:29:08] VOIX: et des axes. -[00:29:09] VOIX: Donc, -[00:29:09] VOIX: il faut que je regarde -[00:29:09] VOIX: le dossier -[00:29:10] VOIX: dans sa globalité. -[00:29:12] VOIX: Tu vois, -[00:29:13] VOIX: la durée de ses jours, -[00:29:14] VOIX: il me donne aussi -[00:29:14] VOIX: une indication -[00:29:17] VOIX: sur les objectifs -[00:29:18] VOIX: parce que si j'avais -[00:29:19] VOIX: un dossier de 2 jours, -[00:29:20] VOIX: je vais aller -[00:29:21] VOIX: beaucoup plus vite -[00:29:22] VOIX: et je vais raisonner -[00:29:23] VOIX: uniquement sur -[00:29:23] VOIX: pourquoi il a été -[00:29:24] VOIX: hospitalisé. -[00:29:25] VOIX: Je ne vais pas -[00:29:26] VOIX: raisonner sur -[00:29:27] VOIX: les niveaux de sévérité, -[00:29:28] VOIX: etc. -[00:29:28] VOIX: Même si ça, -[00:29:29] VOIX: je le mets -[00:29:30] VOIX: en aparté, -[00:29:31] VOIX: tu vois, -[00:29:31] VOIX: ce n'est pas -[00:29:31] VOIX: le raisonnement -[00:29:32] VOIX: qu'il faut apprendre -[00:29:35] VOIX: à l'IA, -[00:29:36] VOIX: il faut lui apprendre -[00:29:36] VOIX: à raisonner sur le dossier -[00:29:38] VOIX: dans sa globalité. -[00:29:40] VOIX: Si j'en fiche un peu -[00:29:40] VOIX: sur certains dossiers, -[00:29:42] VOIX: on sait qu'on peut -[00:29:43] VOIX: les valoriser -[00:29:44] VOIX: uniquement avec le DP. -[00:29:46] VOIX: Donc, -[00:29:47] VOIX: ça, -[00:29:47] VOIX: c'est la première ligne -[00:29:48] VOIX: que je vais regarder, -[00:29:50] VOIX: vérifier, -[00:29:50] VOIX: etc. -[00:29:52] VOIX: Je vais regarder le MH, -[00:29:56] VOIX: c'est motif d'hospitalisation. -[00:29:58] VOIX: En fait, -[00:29:59] VOIX: soit qu'il est clairement, -[00:30:02] VOIX: d'accord, -[00:30:03] VOIX: soit, -[00:30:03] VOIX: il faut le deviner, -[00:30:05] VOIX: c'est toute la difficulté. -[00:30:08] VOIX: Donc là, -[00:30:08] VOIX: j'ai une douleur. -[00:30:10] VOIX: Une douleur, -[00:30:11] VOIX: ça veut dire, -[00:30:12] VOIX: c'est un symptôme. -[00:30:13] VOIX: D'accord ? -[00:30:14] VOIX: Tu te rappelleras -[00:30:14] VOIX: de ce terme -[00:30:15] VOIX: pour que tu vois -[00:30:16] VOIX: le raisonnement après -[00:30:17] VOIX: et pourquoi, -[00:30:18] VOIX: quelles règles -[00:30:19] VOIX: on applique -[00:30:19] VOIX: dans ce cas-là. -[00:30:20] VOIX: Donc, -[00:30:21] VOIX: en fait, -[00:30:22] VOIX: j'ose imaginer -[00:30:24] VOIX: que ce patient-là, -[00:30:26] VOIX: il vient -[00:30:27] VOIX: parce qu'il a -[00:30:28] VOIX: un diagnostic inconnu, -[00:30:30] VOIX: il a mal -[00:30:31] VOIX: et l'établissement, -[00:30:34] VOIX: le médecin, -[00:30:34] VOIX: il va faire -[00:30:35] VOIX: en sorte -[00:30:36] VOIX: d'identifier -[00:30:37] VOIX: pourquoi le patient, -[00:30:38] VOIX: il a mal -[00:30:39] VOIX: et de traiter -[00:30:41] VOIX: en fonction -[00:30:42] VOIX: de la pathologie -[00:30:43] VOIX: ou de traiter -[00:30:44] VOIX: la douleur. -[00:30:45] VOIX: Ça, -[00:30:45] VOIX: c'est la démarche -[00:30:46] VOIX: habituelle. -[00:30:47] VOIX: En même temps, -[00:30:49] VOIX: tu vas avoir -[00:30:51] VOIX: antécédent, -[00:30:53] VOIX: c'est un patient -[00:30:55] VOIX: qui est jeune, -[00:30:55] VOIX: déjà, -[00:30:56] VOIX: tu vois, -[00:30:56] VOIX: ah oui, -[00:30:57] VOIX: je regarde aussi l'âge -[00:30:58] VOIX: parce que si -[00:30:59] VOIX: tu as un patient -[00:30:59] VOIX: âgé, -[00:31:01] VOIX: tu sauras -[00:31:02] VOIX: qu'il va avoir -[00:31:03] VOIX: une liste -[00:31:03] VOIX: d'antécédents, -[00:31:04] VOIX: d'histoires -[00:31:05] VOIX: qui est extrêmement -[00:31:08] VOIX: riche -[00:31:08] VOIX: alors que 80, -[00:31:10] VOIX: tu dis bon, -[00:31:10] VOIX: il est jeune, -[00:31:11] VOIX: il a mal au ventre, -[00:31:13] VOIX: ça doit être -[00:31:13] VOIX: une appendicite -[00:31:15] VOIX: ou un truc -[00:31:15] VOIX: qui va être opéré, -[00:31:17] VOIX: tu vois. -[00:31:18] VOIX: Moi, -[00:31:18] VOIX: ça, -[00:31:19] VOIX: c'est mon raisonnement -[00:31:19] VOIX: un peu médical -[00:31:20] VOIX: en disant -[00:31:21] VOIX: qu'il est jeune -[00:31:22] VOIX: et qu'il a déjà -[00:31:24] VOIX: fait une crise -[00:31:25] VOIX: de colique hépatique -[00:31:26] VOIX: avec vésiculitis, -[00:31:28] VOIX: lithiasique, -[00:31:28] VOIX: donc c'est peut-être -[00:31:29] VOIX: qu'il a une colécystite -[00:31:31] VOIX: aiguë, -[00:31:31] VOIX: colécystite, -[00:31:32] VOIX: c'est la vésicule -[00:31:33] VOIX: biliaire -[00:31:34] VOIX: qui est inflammée -[00:31:35] VOIX: d'autant plus -[00:31:35] VOIX: que là, -[00:31:37] VOIX: on est en février, -[00:31:38] VOIX: en janvier, -[00:31:39] VOIX: je regarde ça aussi, -[00:31:40] VOIX: c'est-à-dire là, -[00:31:41] VOIX: on est en février, -[00:31:42] VOIX: il est rentré -[00:31:42] VOIX: en février 2023 -[00:31:44] VOIX: et un mois avant, -[00:31:45] VOIX: même pas un mois, -[00:31:46] VOIX: il a eu une colique hépatique -[00:31:49] VOIX: avec vésicule lithiasique -[00:31:50] VOIX: sur les côtes. -[00:31:51] VOIX: Donc probablement, -[00:31:52] VOIX: le médecin, -[00:31:53] VOIX: en tout cas, -[00:31:53] VOIX: qui le prend en charge, -[00:31:54] VOIX: il va raisonner -[00:31:54] VOIX: que c'est ce qu'on appelle -[00:31:55] VOIX: une récidive, -[00:31:56] VOIX: c'est-à-dire la lithiase -[00:32:00] VOIX: qui a migré -[00:32:01] VOIX: ou voilà, -[00:32:01] VOIX: un problème en tout cas -[00:32:02] VOIX: digestif. -[00:32:03] VOIX: Il n'a pas de traitement -[00:32:04] VOIX: en cours, -[00:32:05] VOIX: il n'a pas d'allergie, -[00:32:07] VOIX: histoire de la maladie, -[00:32:08] VOIX: donc fin janvier, -[00:32:10] VOIX: premier épisode -[00:32:11] VOIX: de colique hépatique -[00:32:12] VOIX: prise en charge -[00:32:13] VOIX: aux urgences -[00:32:14] VOIX: de Saint-Palais, -[00:32:16] VOIX: lithiase, -[00:32:16] VOIX: tu vois, -[00:32:17] VOIX: ils avaient déjà -[00:32:17] VOIX: diagnostiqué -[00:32:18] VOIX: des lithiases -[00:32:19] VOIX: vésiculaires, -[00:32:19] VOIX: ça veut dire -[00:32:20] VOIX: lithiase de la vésicule -[00:32:21] VOIX: biliaire -[00:32:22] VOIX: sur éco-échographie -[00:32:24] VOIX: en ville, -[00:32:25] VOIX: la patiente -[00:32:26] VOIX: devait revoir -[00:32:27] VOIX: son médecin traitant -[00:32:28] VOIX: début mars -[00:32:28] VOIX: pour évoquer -[00:32:29] VOIX: la suite -[00:32:29] VOIX: de la prise en charge. -[00:32:31] VOIX: Mais le 24 février -[00:32:32] VOIX: vers 20 heures, -[00:32:33] VOIX: récidive, -[00:32:34] VOIX: récidive, -[00:32:35] VOIX: c'est-à-dire de nouveau, -[00:32:37] VOIX: récidive -[00:32:39] VOIX: des douleurs -[00:32:40] VOIX: ou de la maladie, -[00:32:41] VOIX: etc. -[00:32:42] VOIX: En hippocondre droit, -[00:32:44] VOIX: l'hippocondre droit, -[00:32:45] VOIX: c'est la partie -[00:32:45] VOIX: haute droite -[00:32:46] VOIX: du ventre, -[00:32:47] VOIX: ne cédant pas, -[00:32:49] VOIX: arrive hyperalgique, -[00:32:51] VOIX: c'est-à-dire très douloureux -[00:32:53] VOIX: aux urgences, -[00:32:54] VOIX: pas de signe -[00:32:55] VOIX: d'irritation péritoniale. -[00:32:57] VOIX: L'irritation péritoniale, -[00:32:59] VOIX: ça peut évoquer -[00:33:00] VOIX: une péritonite, -[00:33:01] VOIX: c'est-à-dire -[00:33:01] VOIX: quand ils vont palper -[00:33:02] VOIX: le ventre, -[00:33:03] VOIX: ils vont trouver -[00:33:03] VOIX: qu'il est dur. -[00:33:04] VOIX: Donc là, -[00:33:05] VOIX: bon, -[00:33:05] VOIX: il n'y a rien -[00:33:06] VOIX: qui les alerte -[00:33:07] VOIX: d'un point de vue -[00:33:08] VOIX: examen de l'abdomen. -[00:33:10] VOIX: Hémodynamique stable, -[00:33:11] VOIX: c'est-à-dire la tension, -[00:33:12] VOIX: la fréquence cardiaque, -[00:33:13] VOIX: etc., -[00:33:14] VOIX: c'est ça, -[00:33:14] VOIX: l'hémodynamique. -[00:33:15] VOIX: Donc la patiente stable. -[00:33:18] VOIX: Ils ont fait -[00:33:18] VOIX: un bilan biologique, -[00:33:20] VOIX: tout ça, -[00:33:20] VOIX: et regarde, -[00:33:23] VOIX: ça, -[00:33:23] VOIX: il te dit -[00:33:24] VOIX: fin janvier, -[00:33:25] VOIX: donc du coup, -[00:33:26] VOIX: les deux premières phrases, -[00:33:28] VOIX: c'était -[00:33:29] VOIX: aux urgences -[00:33:30] VOIX: de Saint-Palais, -[00:33:31] VOIX: il est rentré -[00:33:32] VOIX: chez lui -[00:33:32] VOIX: et il devait voir -[00:33:34] VOIX: la suite -[00:33:34] VOIX: avec son médecin traitant, -[00:33:36] VOIX: donc c'est -[00:33:36] VOIX: l'ancien, -[00:33:38] VOIX: c'est ancien. -[00:33:39] VOIX: Le 24 février, -[00:33:41] VOIX: il a été hospitalisé -[00:33:42] VOIX: au CHCB, -[00:33:43] VOIX: le 25 février, -[00:33:45] VOIX: il est passé -[00:33:46] VOIX: aux urgences -[00:33:49] VOIX: du CHCB, -[00:33:50] VOIX: je pense, -[00:33:52] VOIX: donc il est rentré -[00:33:53] VOIX: via les urgences, -[00:33:54] VOIX: le 24, -[00:33:56] VOIX: c'est ça ou pas ? -[00:33:57] VOIX: Oui, -[00:33:58] VOIX: oui, -[00:33:58] VOIX: oui. -[00:33:59] VOIX: Point d'interrogation. -[00:34:00] VOIX: Oui. -[00:34:00] VOIX: Parce que là, -[00:34:01] VOIX: la date d'entrée, -[00:34:02] VOIX: c'est le 25, -[00:34:04] VOIX: et là, -[00:34:04] VOIX: il parle du 24 à 20 heures. -[00:34:06] VOIX: Oui, -[00:34:06] VOIX: c'est ça. -[00:34:07] VOIX: En fait, -[00:34:07] VOIX: en fait, -[00:34:07] VOIX: ça, -[00:34:08] VOIX: 24, -[00:34:08] VOIX: 25, -[00:34:09] VOIX: il faut savoir -[00:34:10] VOIX: que nous, -[00:34:10] VOIX: quand on code, -[00:34:11] VOIX: on va vérifier -[00:34:12] VOIX: aussi le dossier -[00:34:13] VOIX: administratif -[00:34:13] VOIX: pour vérifier -[00:34:15] VOIX: est-ce que la date -[00:34:16] VOIX: d'entrée du séjour, -[00:34:17] VOIX: c'est le 24 février -[00:34:17] VOIX: ou c'est le 25 février. -[00:34:19] VOIX: Et est-ce que -[00:34:20] VOIX: quand il part -[00:34:20] VOIX: d'hyper-algie -[00:34:22] VOIX: aux urgences, -[00:34:22] VOIX: est-ce que c'est -[00:34:23] VOIX: les urgences -[00:34:24] VOIX: du centre hospitalier -[00:34:25] VOIX: de la Côte-Basque -[00:34:26] VOIX: ou est-ce que -[00:34:27] VOIX: ce sont les urgences -[00:34:28] VOIX: d'un autre établissement -[00:34:29] VOIX: comme Saint-Palais -[00:34:30] VOIX: qui l'a vu, -[00:34:31] VOIX: il est resté -[00:34:32] VOIX: jusqu'au 25. -[00:34:33] VOIX: Et ça, -[00:34:34] VOIX: c'est peut-être -[00:34:34] VOIX: la suite du dossier -[00:34:35] VOIX: qui va l'expliquer. -[00:34:36] VOIX: Peut-être que -[00:34:37] VOIX: je ne le saurais pas. -[00:34:39] VOIX: Tu vois, -[00:34:40] VOIX: je reviens un peu -[00:34:40] VOIX: pour te dire -[00:34:41] VOIX: comment il a analysé -[00:34:43] VOIX: les choses. -[00:34:44] VOIX: C'est parfait. -[00:34:45] VOIX: Franchement, -[00:34:45] VOIX: en fait, -[00:34:46] VOIX: c'est parfait. -[00:34:46] VOIX: Mais tu m'en avais déjà parlé. -[00:34:47] VOIX: On en a parlé, -[00:34:48] VOIX: je ne sais plus avec qui. -[00:34:50] VOIX: Effectivement, -[00:34:50] VOIX: en fait, -[00:34:50] VOIX: qu'on n'a pas -[00:34:51] VOIX: de suivi nécessairement -[00:34:52] VOIX: entre les urgences -[00:34:55] VOIX: et l'hospitalisation même. -[00:34:57] VOIX: Non. -[00:34:58] VOIX: Parce que le patient, -[00:34:59] VOIX: il peut aller aux urgences, -[00:34:59] VOIX: de Saint-Palais, -[00:35:00] VOIX: Saint-Palais, -[00:35:01] VOIX: il trouve un problème -[00:35:02] VOIX: et il l'envoie au CHTB -[00:35:04] VOIX: pour hospitalisation. -[00:35:05] VOIX: Donc, -[00:35:05] VOIX: quand il dit -[00:35:06] VOIX: arriver aux urgences, -[00:35:08] VOIX: sans préciser, -[00:35:09] VOIX: le plus probable -[00:35:10] VOIX: c'est les urgences -[00:35:11] VOIX: du CHTB, -[00:35:12] VOIX: mais la date m'alerte. -[00:35:14] VOIX: Tu vois, -[00:35:15] VOIX: si c'était le 25, -[00:35:16] VOIX: je l'aurais été -[00:35:17] VOIX: moins alertée. -[00:35:19] VOIX: Donc, -[00:35:20] VOIX: soit qu'entre le 25 et le 25, -[00:35:21] VOIX: il est resté -[00:35:22] VOIX: dans les urgences du CHTB -[00:35:23] VOIX: et il est monté -[00:35:25] VOIX: à l'étage, -[00:35:26] VOIX: là, -[00:35:26] VOIX: je ne sais pas -[00:35:27] VOIX: quel service, -[00:35:28] VOIX: il n'y a pas -[00:35:28] VOIX: de gastro-entéraux -[00:35:29] VOIX: et il est monté. -[00:35:30] VOIX: Donc, -[00:35:30] VOIX: le médecin gastro-entéraux, -[00:35:32] VOIX: il résonne par rapport -[00:35:33] VOIX: à l'heure -[00:35:34] VOIX: et à la date -[00:35:34] VOIX: d'entrée dans le service -[00:35:35] VOIX: et donc, -[00:35:36] VOIX: le patient, -[00:35:36] VOIX: il a passé -[00:35:37] VOIX: une nuit aux urgences. -[00:35:39] VOIX: C'est le plus probable -[00:35:40] VOIX: à la lecture -[00:35:41] VOIX: de... -[00:35:41] VOIX: Oui, parce qu'en fait, -[00:35:42] VOIX: oui, parce que ça dit, -[00:35:44] VOIX: arrive, -[00:35:47] VOIX: arrive hyperalgique -[00:35:48] VOIX: aux urgences, -[00:35:49] VOIX: pas de signe -[00:35:49] VOIX: d'irritation -[00:35:52] VOIX: péritoniale. -[00:35:52] VOIX: Par contre, -[00:35:53] VOIX: en fait, -[00:35:53] VOIX: il y a suffisamment -[00:35:54] VOIX: de précision, -[00:35:57] VOIX: en fait, -[00:35:57] VOIX: le bilan biologique, -[00:35:59] VOIX: en fait, -[00:35:59] VOIX: s'ils l'ont, -[00:36:00] VOIX: en fait, -[00:36:00] VOIX: ça, -[00:36:01] VOIX: c'est parce qu'en fait, -[00:36:01] VOIX: ils ont les informations, -[00:36:02] VOIX: j'imagine. -[00:36:04] VOIX: Oui, -[00:36:05] VOIX: normalement, -[00:36:05] VOIX: il doit y avoir -[00:36:06] VOIX: dans le dossier, -[00:36:07] VOIX: donc, -[00:36:08] VOIX: l'IA, -[00:36:08] VOIX: les documents, -[00:36:09] VOIX: ça, -[00:36:10] VOIX: c'est facile, -[00:36:11] VOIX: si tu lui donnes l'info, -[00:36:13] VOIX: vérifier que la patiente -[00:36:14] VOIX: est passée par les urgences -[00:36:15] VOIX: de l'établissement -[00:36:17] VOIX: et combien de temps -[00:36:18] VOIX: elle est restée, -[00:36:19] VOIX: qu'est-ce qui a été fait, -[00:36:20] VOIX: etc. -[00:36:20] VOIX: Parce que ça, -[00:36:21] VOIX: c'est la synthèse du médecin, -[00:36:22] VOIX: mais normalement, -[00:36:23] VOIX: le passage aux urgences, -[00:36:24] VOIX: il est tracé aussi -[00:36:25] VOIX: dans le dossier patient. -[00:36:26] VOIX: D'accord. -[00:36:27] VOIX: Et peut-être, -[00:36:27] VOIX: quand on va regarder -[00:36:28] VOIX: Tracker, -[00:36:29] VOIX: on aura plus d'infos, -[00:36:31] VOIX: ça nous donnera, -[00:36:32] VOIX: mais là, -[00:36:32] VOIX: je te dis, -[00:36:32] VOIX: c'est-à-dire, -[00:36:34] VOIX: si je n'ai que ce document, -[00:36:36] VOIX: voilà les questions -[00:36:37] VOIX: que je vais me poser -[00:36:37] VOIX: et que je vais chercher -[00:36:39] VOIX: ailleurs -[00:36:40] VOIX: dans le dossier patient. -[00:36:41] VOIX: Parce que rien de ce document -[00:36:42] VOIX: ne permet pas -[00:36:43] VOIX: de répondre -[00:36:43] VOIX: à toutes les questions. -[00:36:44] VOIX: Oui. -[00:36:44] VOIX: Alors là, -[00:36:45] VOIX: par contre, -[00:36:45] VOIX: dans ce document, -[00:36:47] VOIX: en fait, -[00:36:47] VOIX: il y a marqué -[00:36:47] VOIX: donc au total, -[00:36:48] VOIX: premier épisode, -[00:36:49] VOIX: donc en fait, -[00:36:50] VOIX: j'imagine qu'il y a marqué -[00:36:51] VOIX: synthèse, -[00:36:52] VOIX: en fait, -[00:36:52] VOIX: du paragraphe précédent -[00:36:53] VOIX: et c'est premier épisode -[00:36:55] VOIX: de créatite aiguë -[00:36:57] VOIX: d'origine... -[00:36:57] VOIX: Créatite aiguë -[00:36:58] VOIX: d'origine lithéatique. -[00:36:59] VOIX: D'accord, -[00:37:00] VOIX: OK. -[00:37:01] VOIX: OK. -[00:37:01] VOIX: Voilà. -[00:37:01] VOIX: Ils ont fait le bilan -[00:37:03] VOIX: donc ils ont -[00:37:04] VOIX: la lipazémie -[00:37:05] VOIX: qui est élevée -[00:37:05] VOIX: à 6000 -[00:37:06] VOIX: donc supérieure -[00:37:07] VOIX: à 3N -[00:37:08] VOIX: c'est-à-dire -[00:37:08] VOIX: supérieure -[00:37:09] VOIX: à 3 fois -[00:37:10] VOIX: la normale. -[00:37:10] VOIX: Oui. -[00:37:12] VOIX: Tu vois ? -[00:37:13] VOIX: Asat, -[00:37:14] VOIX: 8 fois -[00:37:14] VOIX: la normale. -[00:37:15] VOIX: Alad, -[00:37:16] VOIX: 9 fois -[00:37:16] VOIX: la normale. -[00:37:18] VOIX: Donc le bilan -[00:37:18] VOIX: hépatique, -[00:37:19] VOIX: pancréatique -[00:37:20] VOIX: est complètement perturbé. -[00:37:22] VOIX: Phosphatase -[00:37:23] VOIX: alcaline, -[00:37:23] VOIX: donc le gastroenterot -[00:37:25] VOIX: à partir -[00:37:26] VOIX: de ce bilan, -[00:37:27] VOIX: il dit -[00:37:28] VOIX: premier épisode -[00:37:29] VOIX: parce que la patiente -[00:37:30] VOIX: avait déjà -[00:37:31] VOIX: une lithiase -[00:37:32] VOIX: au niveau -[00:37:32] VOIX: de la vésicule -[00:37:33] VOIX: biliaire -[00:37:34] VOIX: et ces lithiases -[00:37:35] VOIX: peuvent migrer -[00:37:36] VOIX: au niveau -[00:37:36] VOIX: du pancréas -[00:37:37] VOIX: et provoquer -[00:37:38] VOIX: des pancréatites -[00:37:39] VOIX: aiguës. -[00:37:40] VOIX: Et sa pancréatite -[00:37:41] VOIX: aiguë -[00:37:42] VOIX: d'origine lithiasique, -[00:37:43] VOIX: c'est un diagnostic -[00:37:45] VOIX: qui est facile -[00:37:46] VOIX: quelque part. -[00:37:47] VOIX: Mais avant, -[00:37:48] VOIX: tu vas le trouver, -[00:37:49] VOIX: mais est-ce que c'est -[00:37:49] VOIX: des dépées ou pas ? -[00:37:50] VOIX: Je t'expliquerai -[00:37:51] VOIX: pourquoi après, -[00:37:53] VOIX: est-ce qu'on met -[00:37:53] VOIX: la douleur, -[00:37:54] VOIX: est-ce qu'on met -[00:37:54] VOIX: la pancréatite ? -[00:37:56] VOIX: Là, -[00:37:57] VOIX: moi je sais, -[00:37:57] VOIX: mais je vais aller -[00:37:58] VOIX: pousser la réflexion -[00:38:00] VOIX: jusqu'au bout. -[00:38:02] VOIX: Donc suspecté -[00:38:03] VOIX: chez cette patiente -[00:38:04] VOIX: ayant fait un épisode -[00:38:05] VOIX: de colique hépatique -[00:38:06] VOIX: le 23 janvier -[00:38:08] VOIX: avec indication -[00:38:09] VOIX: chirurgicale -[00:38:10] VOIX: retenue. -[00:38:13] VOIX: Donc titration morphine, -[00:38:15] VOIX: comme elle est -[00:38:15] VOIX: hyperalgique -[00:38:16] VOIX: et elle a très mal, -[00:38:17] VOIX: donc ils l'ont mis -[00:38:17] VOIX: sous morphine, -[00:38:19] VOIX: relais en PCA, -[00:38:21] VOIX: c'est-à-dire -[00:38:21] VOIX: ils font, -[00:38:23] VOIX: en fait, -[00:38:24] VOIX: ils font une perf -[00:38:25] VOIX: et c'est le patient -[00:38:26] VOIX: qui gère -[00:38:27] VOIX: la dose. -[00:38:28] VOIX: Oui, -[00:38:28] VOIX: qui a pris -[00:38:29] VOIX: sur le bouton. -[00:38:30] VOIX: Voilà. -[00:38:31] VOIX: Et transfert, -[00:38:32] VOIX: donc tout ça, -[00:38:33] VOIX: c'est la synthèse -[00:38:33] VOIX: des urgences -[00:38:34] VOIX: et ils la transfèrent -[00:38:35] VOIX: dans le service -[00:38:35] VOIX: de gastro-entérologie. -[00:38:37] VOIX: Ils la mettent -[00:38:37] VOIX: sous morphine, -[00:38:38] VOIX: ils la montent -[00:38:39] VOIX: vers un spécialiste. -[00:38:41] VOIX: rassurante -[00:38:41] VOIX: sur le plan -[00:38:42] VOIX: abdominal -[00:38:42] VOIX: parce qu'ici, -[00:38:43] VOIX: il n'a pas de signe -[00:38:45] VOIX: d'irritation péritoniale. -[00:38:47] VOIX: Donc, -[00:38:48] VOIX: s'il y avait -[00:38:48] VOIX: une péritonite, -[00:38:50] VOIX: c'est-à-dire -[00:38:50] VOIX: la patiente, -[00:38:51] VOIX: soit elle a une infection, -[00:38:53] VOIX: soit qu'elle a -[00:38:53] VOIX: une perforation, -[00:38:55] VOIX: soit... -[00:38:55] VOIX: Donc, -[00:38:55] VOIX: ça devient -[00:38:55] VOIX: une urgence. -[00:38:57] VOIX: Là, -[00:38:58] VOIX: ils ont un peu -[00:38:59] VOIX: le temps. -[00:38:59] VOIX: Ils vont calmer -[00:39:00] VOIX: la douleur, -[00:39:01] VOIX: ils vont compléter -[00:39:02] VOIX: le bilan, -[00:39:02] VOIX: etc. -[00:39:03] VOIX: Donc, -[00:39:04] VOIX: ce n'est pas -[00:39:04] VOIX: l'extrême urgence. -[00:39:06] VOIX: Sous PCA de morphine, -[00:39:07] VOIX: elle n'est pas douloureuse -[00:39:08] VOIX: et pas dileus, -[00:39:09] VOIX: c'est-à-dire -[00:39:10] VOIX: pas d'occlusion. -[00:39:12] VOIX: En fait, -[00:39:13] VOIX: ça se code aussi -[00:39:13] VOIX: l'ilus, -[00:39:14] VOIX: mais tu vois, -[00:39:15] VOIX: il n'y a pas -[00:39:15] VOIX: non douloureuse -[00:39:16] VOIX: et pas... -[00:39:18] VOIX: Ça, -[00:39:18] VOIX: c'est tout ça, -[00:39:19] VOIX: c'est important. -[00:39:20] VOIX: L'essai agent initialement -[00:39:22] VOIX: parce qu'il pense -[00:39:23] VOIX: l'opérer. -[00:39:23] VOIX: L'opérer, -[00:39:25] VOIX: évolution rapidement favorable, -[00:39:27] VOIX: permettant une reprise -[00:39:28] VOIX: alimentaire progressive -[00:39:29] VOIX: et arrêt de la PCA de morphine. -[00:39:32] VOIX: Le TDM, -[00:39:33] VOIX: c'est le scanner à J3, -[00:39:34] VOIX: donc il est resté hospitalisé -[00:39:36] VOIX: trois jours. -[00:39:38] VOIX: Absence... -[00:39:38] VOIX: Ah oui, -[00:39:38] VOIX: pardon. -[00:39:40] VOIX: La phrase là, -[00:39:41] VOIX: juste au-dessus, -[00:39:42] VOIX: qu'on a dit -[00:39:42] VOIX: « transfert dans le service -[00:39:44] VOIX: du gastron entero », -[00:39:45] VOIX: c'est plus en faveur -[00:39:47] VOIX: que c'était -[00:39:47] VOIX: les urgences -[00:39:48] VOIX: de l'hôpital. -[00:39:49] VOIX: Tu vois, -[00:39:49] VOIX: ils n'ont pas dit -[00:39:50] VOIX: « transfert à l'hôpital » -[00:39:51] VOIX: ou ils ont dit -[00:39:52] VOIX: « dans le service », -[00:39:53] VOIX: donc à mon avis, -[00:39:53] VOIX: chez eux. -[00:39:55] VOIX: Donc, -[00:39:56] VOIX: il faut un scanner -[00:39:57] VOIX: le troisième jour -[00:39:58] VOIX: d'hospitalisation. -[00:40:00] VOIX: Absence -[00:40:01] VOIX: de signes de gravité, -[00:40:02] VOIX: pas d'anomalie significative -[00:40:03] VOIX: de la glande pancréatique, -[00:40:05] VOIX: ni d'infiltration, -[00:40:06] VOIX: pas de coulée, -[00:40:07] VOIX: pas d'argument -[00:40:08] VOIX: pour une complication -[00:40:10] VOIX: vasculaire, -[00:40:10] VOIX: pas de pseudo-anévrisme -[00:40:12] VOIX: décelé, -[00:40:13] VOIX: minime infiltration -[00:40:14] VOIX: péri-vésiculaire -[00:40:15] VOIX: évocatrice -[00:40:16] VOIX: de pop-at-signe. -[00:40:18] VOIX: Ça, -[00:40:19] VOIX: je ne sais pas -[00:40:19] VOIX: ce que c'est, -[00:40:21] VOIX: mais sans -[00:40:21] VOIX: hydrocholéciste -[00:40:22] VOIX: le jour, -[00:40:23] VOIX: score de Balthazar -[00:40:24] VOIX: à zéro. -[00:40:25] VOIX: Le score de Balthazar, -[00:40:26] VOIX: c'est pour évaluer -[00:40:28] VOIX: la gravité -[00:40:29] VOIX: de la pancréatite. -[00:40:30] VOIX: OK. -[00:40:32] VOIX: Donc, -[00:40:33] VOIX: ils organisent -[00:40:34] VOIX: une cholécistectomie -[00:40:36] VOIX: par cellioscopie -[00:40:37] VOIX: le 1er mars -[00:40:38] VOIX: par docteur, -[00:40:39] VOIX: voilà. -[00:40:41] VOIX: Colangio retrouve -[00:40:42] VOIX: une lithiase -[00:40:43] VOIX: du bacolédoc, -[00:40:44] VOIX: petite taille, -[00:40:46] VOIX: opacification initiale -[00:40:47] VOIX: du diodénum. -[00:40:48] VOIX: après plusieurs injections -[00:40:49] VOIX: de produits -[00:40:50] VOIX: de contraste, -[00:40:50] VOIX: on arrive finalement -[00:40:51] VOIX: à pousser le calcul -[00:40:53] VOIX: qui réussit -[00:40:54] VOIX: à franchir -[00:40:55] VOIX: la papille, -[00:40:55] VOIX: papille du diodénum, -[00:40:57] VOIX: aux pacifications -[00:40:58] VOIX: franches -[00:40:58] VOIX: du diodénum -[00:40:59] VOIX: ou des courses -[00:41:00] VOIX: en lithiase -[00:41:01] VOIX: résiduelle -[00:41:01] VOIX: de la VBP, -[00:41:04] VOIX: c'est la voie -[00:41:04] VOIX: biliaire principale. -[00:41:05] VOIX: Ça, -[00:41:06] VOIX: c'est des petites -[00:41:08] VOIX: abréviations -[00:41:08] VOIX: que petit à petit -[00:41:09] VOIX: on pourra, -[00:41:11] VOIX: tu vois, -[00:41:12] VOIX: à un moment donné, -[00:41:12] VOIX: tu peux sortir -[00:41:13] VOIX: un peu toute la liste -[00:41:14] VOIX: des abréviations -[00:41:15] VOIX: et on essaiera -[00:41:16] VOIX: de leur donner -[00:41:18] VOIX: du sens. -[00:41:19] VOIX: Et pas de gramme -[00:41:20] VOIX: complet, -[00:41:20] VOIX: pas de cathéter -[00:41:21] VOIX: nécessaire donc, -[00:41:23] VOIX: puisque en fait, -[00:41:24] VOIX: avec le produit -[00:41:25] VOIX: de contraste, -[00:41:26] VOIX: ils ont réussi -[00:41:26] VOIX: à pousser -[00:41:28] VOIX: la lithiase -[00:41:29] VOIX: et donc du coup, -[00:41:30] VOIX: ils ont libéré -[00:41:31] VOIX: la voie. -[00:41:31] VOIX: C'est comme si la lithiase -[00:41:33] VOIX: qui a bougé -[00:41:33] VOIX: au fait de sa place -[00:41:34] VOIX: et c'est ça -[00:41:35] VOIX: qui a provoqué -[00:41:36] VOIX: tout ça. -[00:41:37] VOIX: Donc, -[00:41:39] VOIX: fermeture, -[00:41:40] VOIX: colle, -[00:41:41] VOIX: ça c'est, -[00:41:41] VOIX: tu vois, -[00:41:42] VOIX: c'est un peu, -[00:41:42] VOIX: ça c'est un peu -[00:41:44] VOIX: un semblant -[00:41:45] VOIX: de cro, -[00:41:47] VOIX: de courrier opératoire. -[00:41:49] VOIX: Et en fait, -[00:41:50] VOIX: ça, -[00:41:51] VOIX: quand il te dit -[00:41:52] VOIX: cholecystectomie -[00:41:53] VOIX: par cellio, -[00:41:54] VOIX: cholangiographie, -[00:41:55] VOIX: etc., -[00:41:55] VOIX: tu sais, -[00:41:56] VOIX: on a touché -[00:41:57] VOIX: un petit peu -[00:41:57] VOIX: aux actes CCAM -[00:41:58] VOIX: parce qu'on parle -[00:42:00] VOIX: de l'allemand -[00:42:01] VOIX: que la tire C1-10 -[00:42:02] VOIX: des diagnostics -[00:42:02] VOIX: mais là, -[00:42:06] VOIX: on peut raisonner -[00:42:09] VOIX: d'emblée -[00:42:10] VOIX: sur les actes -[00:42:11] VOIX: si tu veux. -[00:42:11] VOIX: C'est-à-dire, -[00:42:12] VOIX: tu vérifies -[00:42:13] VOIX: que le codage -[00:42:13] VOIX: du scanner -[00:42:14] VOIX: il est fait, -[00:42:15] VOIX: que la colle -[00:42:16] VOIX: citectomie -[00:42:16] VOIX: est faite, -[00:42:17] VOIX: la collangiologie -[00:42:18] VOIX: est faite, -[00:42:19] VOIX: etc. -[00:42:19] VOIX: Tu vois, -[00:42:20] VOIX: parce que tu introduis -[00:42:21] VOIX: la logique globale -[00:42:22] VOIX: du dossier -[00:42:23] VOIX: avec les diagnostics -[00:42:24] VOIX: et les actes. -[00:42:26] VOIX: Donc ça, -[00:42:27] VOIX: ça peut être intéressant. -[00:42:29] VOIX: Donc, -[00:42:30] VOIX: il fait la colle, -[00:42:31] VOIX: OK, -[00:42:34] VOIX: alimentaire bien toléré, -[00:42:35] VOIX: bilan hépatique -[00:42:36] VOIX: à la sortie. -[00:42:37] VOIX: Donc, -[00:42:37] VOIX: tu vois le bilan hépatique, -[00:42:39] VOIX: tous les bilans -[00:42:39] VOIX: qu'ils ont fait -[00:42:40] VOIX: à l'entrée, -[00:42:40] VOIX: comment ça s'est amélioré. -[00:42:42] VOIX: Éruption, -[00:42:43] VOIX: elle fait une complication. -[00:42:47] VOIX: Éruption cutanée, -[00:42:48] VOIX: érythémateuse, -[00:42:49] VOIX: prédominant sur le tronc, -[00:42:50] VOIX: apparu le matin -[00:42:51] VOIX: de la sortie. -[00:42:53] VOIX: Nantes origineuses, -[00:42:54] VOIX: c'est-à-dire, -[00:42:54] VOIX: elle ne gratte pas -[00:42:55] VOIX: dans un contexte -[00:42:56] VOIX: de prise de contre-armale -[00:42:58] VOIX: la veille. -[00:42:58] VOIX: Donc, -[00:42:58] VOIX: elle a fait probablement -[00:43:00] VOIX: une réaction médicamenteuse. -[00:43:03] VOIX: Introduction de cétérisine -[00:43:05] VOIX: qui est un antihistaminique. -[00:43:07] VOIX: Au total, -[00:43:09] VOIX: pancréasite aiguille, -[00:43:10] VOIX: thiasique, -[00:43:10] VOIX: sensible de gravité, -[00:43:11] VOIX: polycystectomisé. -[00:43:13] VOIX: Donc, -[00:43:13] VOIX: avec ces éléments-là, -[00:43:16] VOIX: on peut déjà proposer -[00:43:19] VOIX: un codage -[00:43:20] VOIX: et je t'explique -[00:43:20] VOIX: sur quoi -[00:43:21] VOIX: je vais me baser -[00:43:23] VOIX: pour coder ce dossier. -[00:43:25] VOIX: En fait, -[00:43:28] VOIX: on peut partager mon écran. -[00:43:31] VOIX: Attends, -[00:43:31] VOIX: je vais ouvrir le fichier -[00:43:32] VOIX: que je voulais ouvrir. -[00:43:36] VOIX: Le guide méthodologique. -[00:43:40] VOIX: Attends, -[00:43:40] VOIX: je te reprends. -[00:43:42] VOIX: C'est moi qui partage. -[00:43:44] VOIX: Vas-y, vas-y. -[00:43:44] VOIX: Je fais une pause -[00:43:45] VOIX: et on revient -[00:43:46] VOIX: à ton dossier après. -[00:43:48] VOIX: D'accord ? -[00:43:48] VOIX: Vas-y, vas-y. -[00:43:50] VOIX: OK. -[00:43:50] VOIX: Alors, -[00:43:52] VOIX: le guide méthodologique, -[00:43:54] VOIX: ce document, -[00:43:54] VOIX: je pense que je te l'avais envoyé. -[00:43:56] VOIX: Je l'ai. -[00:43:57] VOIX: Il est mis à jour -[00:43:58] VOIX: tous les ans, -[00:43:59] VOIX: tu vois, -[00:44:00] VOIX: décembre 2025, -[00:44:01] VOIX: parce que c'est celui -[00:44:01] VOIX: qui s'applique en 2026. -[00:44:03] VOIX: Ce qui nous intéresse -[00:44:05] VOIX: dans -[00:44:08] VOIX: ce fichier, -[00:44:10] VOIX: c'était -[00:44:12] VOIX: à partir -[00:44:13] VOIX: du chapitre 5. -[00:44:16] VOIX: Non, -[00:44:16] VOIX: hiérarchisation, -[00:44:17] VOIX: à partir de là. -[00:44:19] VOIX: Hiérarchisation, -[00:44:21] VOIX: et codage -[00:44:22] VOIX: des informations médicales -[00:44:23] VOIX: du résumé -[00:44:23] VOIX: d'unité médicale. -[00:44:25] VOIX: On va aller là. -[00:44:28] VOIX: Et en fait, -[00:44:29] VOIX: attends, -[00:44:31] VOIX: attends, -[00:44:33] VOIX: non, -[00:44:33] VOIX: je reviens avant -[00:44:35] VOIX: parce que ça, -[00:44:36] VOIX: ça va être -[00:44:36] VOIX: beaucoup de blabla. -[00:44:38] VOIX: Je vais te montrer. -[00:44:39] VOIX: Le sommaire, -[00:44:39] VOIX: tu vois déjà, -[00:44:40] VOIX: on va réussir -[00:44:41] VOIX: à expliquer un peu -[00:44:44] VOIX: parce que tu as -[00:44:44] VOIX: les actes -[00:44:46] VOIX: quand on dit -[00:44:46] VOIX: que les cystectomies, -[00:44:47] VOIX: c'est les actes. -[00:44:49] VOIX: Et -[00:44:51] VOIX: consigne de codage, -[00:44:52] VOIX: les situations cliniques. -[00:44:54] VOIX: Voilà. -[00:44:56] VOIX: En fait, -[00:44:57] VOIX: comme tout à l'heure, -[00:44:58] VOIX: je t'ai dit, -[00:44:59] VOIX: le patient, -[00:44:59] VOIX: il est rentré -[00:45:00] VOIX: pour un symptôme -[00:45:01] VOIX: qui est la douleur. -[00:45:03] VOIX: Et le séjour, -[00:45:05] VOIX: je vais, -[00:45:05] VOIX: soit son objectif, -[00:45:07] VOIX: ça sera faire un bilan -[00:45:08] VOIX: pour poser le diagnostic, -[00:45:12] VOIX: soit traiter -[00:45:13] VOIX: et ou traiter. -[00:45:16] VOIX: Là, -[00:45:16] VOIX: il y a eu les deux. -[00:45:17] VOIX: D'accord ? -[00:45:18] VOIX: Puisqu'ils ont posé -[00:45:21] VOIX: le diagnostic -[00:45:21] VOIX: de pancréatite -[00:45:22] VOIX: et non seulement -[00:45:23] VOIX: ils ont posé -[00:45:24] VOIX: le diagnostic, -[00:45:25] VOIX: mais ils ont mis -[00:45:26] VOIX: en place -[00:45:26] VOIX: un traitement chirurgical. -[00:45:28] VOIX: Et le traitement -[00:45:30] VOIX: chirurgical, -[00:45:31] VOIX: c'est ce qu'on appelle -[00:45:32] VOIX: le traitement unique. -[00:45:35] VOIX: Et si je clique -[00:45:36] VOIX: ici, -[00:45:38] VOIX: la règle -[00:45:38] VOIX: que je vais utiliser, -[00:45:40] VOIX: c'est celle-ci. -[00:45:41] VOIX: Le traitement unique -[00:45:42] VOIX: chirurgical. -[00:45:43] VOIX: Dans la situation -[00:45:44] VOIX: de traitement unique -[00:45:45] VOIX: chirurgical, -[00:45:46] VOIX: le DP -[00:45:47] VOIX: et en général -[00:45:48] VOIX: la maladie opérée. -[00:45:49] VOIX: Le DP, -[00:45:50] VOIX: c'est le diagnostic -[00:45:50] VOIX: principal. -[00:45:51] VOIX: Donc là, -[00:45:52] VOIX: cette règle, -[00:45:53] VOIX: il se dit, -[00:45:55] VOIX: si le patient -[00:45:57] VOIX: a été opéré, -[00:45:58] VOIX: le diagnostic -[00:45:59] VOIX: principal -[00:46:00] VOIX: et la maladie -[00:46:01] VOIX: opérée. -[00:46:02] VOIX: D'accord. -[00:46:05] VOIX: Donc, -[00:46:06] VOIX: le patient, -[00:46:07] VOIX: d'ailleurs ça, -[00:46:08] VOIX: ce type de dossier, -[00:46:09] VOIX: il nous a posé, -[00:46:10] VOIX: en fait, -[00:46:10] VOIX: il est censé -[00:46:11] VOIX: être simple, -[00:46:12] VOIX: mais il peut -[00:46:13] VOIX: poser certains problèmes. -[00:46:15] VOIX: Je vais -[00:46:16] VOIX: t'expliquer -[00:46:17] VOIX: le problème. -[00:46:19] VOIX: En fait, -[00:46:20] VOIX: l'acte chirurgical, -[00:46:21] VOIX: c'était quoi ? -[00:46:25] VOIX: Ils ont enlevé -[00:46:26] VOIX: la cholecystite. -[00:46:28] VOIX: Alors, -[00:46:28] VOIX: attends, -[00:46:28] VOIX: tu veux que je le reprenne -[00:46:29] VOIX: le truc que je te dise ? -[00:46:32] VOIX: Oui, -[00:46:32] VOIX: moi je sais, -[00:46:33] VOIX: je veux juste -[00:46:35] VOIX: vérifier que tu suis -[00:46:36] VOIX: la logique. -[00:46:37] VOIX: Non, -[00:46:37] VOIX: mais je suis la logique, -[00:46:39] VOIX: mais les termes -[00:46:40] VOIX: et tout ça, -[00:46:40] VOIX: en fait, -[00:46:41] VOIX: oui, -[00:46:41] VOIX: c'est pour enlever -[00:46:41] VOIX: la cholecystécomie. -[00:46:43] VOIX: C'est -[00:46:43] VOIX: la cholecystécomie. -[00:46:46] VOIX: La cholecystécomie. -[00:46:46] VOIX: Voilà. -[00:46:47] VOIX: tout ce qui est cholecyste, -[00:46:48] VOIX: c'est en lien -[00:46:49] VOIX: avec la vésicule biliaire. -[00:46:51] VOIX: Cholecystéctomie, -[00:46:52] VOIX: ça veut dire -[00:46:52] VOIX: qu'ils ont enlevé -[00:46:54] VOIX: la vésicule biliaire. -[00:46:55] VOIX: Et donc là, -[00:46:56] VOIX: ça va être -[00:46:57] VOIX: la cholecystéctomie, -[00:46:58] VOIX: parcellioscopie, -[00:46:59] VOIX: etc. -[00:47:01] VOIX: Et pourquoi -[00:47:01] VOIX: ils ont enlevé -[00:47:03] VOIX: la vésicule biliaire ? -[00:47:05] VOIX: En fait, -[00:47:07] VOIX: soit on va dire -[00:47:08] VOIX: qu'ils l'ont enlevé -[00:47:09] VOIX: parce qu'il y avait -[00:47:10] VOIX: une lithiase -[00:47:11] VOIX: au niveau -[00:47:12] VOIX: de la vésicule biliaire, -[00:47:14] VOIX: soit -[00:47:14] VOIX: tu te fies -[00:47:16] VOIX: à ce qu'il a écrit -[00:47:17] VOIX: le chirurgien -[00:47:18] VOIX: et tu dis -[00:47:18] VOIX: pour pancréatite. -[00:47:20] VOIX: D'accord ? -[00:47:22] VOIX: Je vais chercher -[00:47:24] VOIX: quelque chose -[00:47:25] VOIX: parce que je pense -[00:47:25] VOIX: que ce type -[00:47:26] VOIX: de dossier -[00:47:27] VOIX: nous a posé -[00:47:28] VOIX: d'énormes problèmes. -[00:47:31] VOIX: Initial pendant -[00:47:33] VOIX: le contrôle -[00:47:34] VOIX: et après -[00:47:35] VOIX: le contrôle, -[00:47:38] VOIX: rapport -[00:47:38] VOIX: du contrôle -[00:47:39] VOIX: et document -[00:47:40] VOIX: transmis -[00:47:41] VOIX: à l'UE 1, -[00:47:43] VOIX: fiche de concertation. -[00:47:45] VOIX: Je ne sais plus -[00:47:46] VOIX: comment il était codé -[00:47:47] VOIX: le dossier 1. -[00:47:49] VOIX: Parce que c'est ça -[00:47:49] VOIX: qui est intéressant. -[00:47:51] VOIX: Parce que, -[00:47:51] VOIX: en fait, -[00:47:52] VOIX: si je code -[00:47:53] VOIX: pour créatite, -[00:47:56] VOIX: mon dossier -[00:47:57] VOIX: va être mieux -[00:47:57] VOIX: valorisé -[00:47:58] VOIX: que si je code -[00:48:02] VOIX: lithiase -[00:48:02] VOIX: de la vésicule -[00:48:04] VOIX: biliaire. -[00:48:06] VOIX: On a eu -[00:48:06] VOIX: des dossiers -[00:48:07] VOIX: pour lesquels -[00:48:09] VOIX: à l'UE 1. -[00:48:10] VOIX: pourquoi je... -[00:48:11] VOIX: Pardon, -[00:48:11] VOIX: j'arrête -[00:48:11] VOIX: le partage -[00:48:12] VOIX: juste pour -[00:48:14] VOIX: ouvrir... -[00:48:14] VOIX: J'arrête pas -[00:48:15] VOIX: à ouvrir -[00:48:15] VOIX: mon répertoire, -[00:48:16] VOIX: ça m'énerve. -[00:48:19] VOIX: Ah ! -[00:48:20] VOIX: C'est pas cool. -[00:48:22] VOIX: Document -[00:48:23] VOIX: transmis -[00:48:23] VOIX: à l'avocat. -[00:48:25] VOIX: Le répertoire -[00:48:26] VOIX: évite. -[00:48:29] VOIX: Fiche médicale. -[00:48:31] VOIX: Ah oui, -[00:48:31] VOIX: ça tourne. -[00:48:32] VOIX: Encore, -[00:48:32] VOIX: c'est pour ça. -[00:48:33] VOIX: C'est long. -[00:48:37] VOIX: En fait, -[00:48:38] VOIX: ce que je ne sais pas, -[00:48:39] VOIX: mais je vais essayer -[00:48:40] VOIX: peut-être de me connecter -[00:48:41] VOIX: sur l'appli. -[00:48:42] VOIX: En fait, -[00:48:43] VOIX: avant de dire -[00:48:44] VOIX: le codage, -[00:48:45] VOIX: je t'explique déjà -[00:48:46] VOIX: la règle que j'utilise. -[00:48:47] VOIX: Le patient, -[00:48:48] VOIX: il est opéré. -[00:48:50] VOIX: La pathologie opérée, -[00:48:51] VOIX: c'est elle -[00:48:52] VOIX: qui va être -[00:48:53] VOIX: le diagnostic principal. -[00:48:54] VOIX: Déjà, -[00:48:54] VOIX: ça, -[00:48:55] VOIX: c'est la règle. -[00:48:55] VOIX: Et tu as vu -[00:48:56] VOIX: qu'il y avait -[00:48:56] VOIX: plusieurs règles -[00:48:57] VOIX: en fonction -[00:48:57] VOIX: de pourquoi -[00:48:58] VOIX: le patient -[00:48:58] VOIX: est rentré. -[00:48:59] VOIX: Et en fonction -[00:49:00] VOIX: des dossiers, -[00:49:01] VOIX: je pourrais -[00:49:01] VOIX: t'expliquer -[00:49:02] VOIX: la règle -[00:49:03] VOIX: et te dire, -[00:49:04] VOIX: et c'est ça -[00:49:04] VOIX: qu'il faut mettre, -[00:49:05] VOIX: c'est-à-dire, -[00:49:06] VOIX: dans les argumentaires, -[00:49:07] VOIX: on peut dire -[00:49:08] VOIX: selon la règle T3 -[00:49:09] VOIX: de traitement unique -[00:49:10] VOIX: chirurgical, -[00:49:11] VOIX: le patient opéré, -[00:49:13] VOIX: c'est la phrase -[00:49:13] VOIX: que je t'ai montrée -[00:49:14] VOIX: dans le guide, -[00:49:15] VOIX: et je dis, -[00:49:16] VOIX: c'est comme ça -[00:49:17] VOIX: qu'on écrit -[00:49:17] VOIX: nos argumentaires, -[00:49:19] VOIX: et dire, -[00:49:19] VOIX: dans le CROS -[00:49:20] VOIX: ou dans le CRH, -[00:49:22] VOIX: dans le document -[00:49:23] VOIX: en tout cas écrit -[00:49:23] VOIX: par le médecin -[00:49:24] VOIX: à telle date, -[00:49:25] VOIX: etc., -[00:49:25] VOIX: c'est pour ça -[00:49:26] VOIX: que c'est bien -[00:49:26] VOIX: d'identifier -[00:49:27] VOIX: la nature -[00:49:28] VOIX: des documents. -[00:49:30] VOIX: Le médecin -[00:49:31] VOIX: a écrit -[00:49:32] VOIX: que le chirurgien -[00:49:34] VOIX: a écrit -[00:49:34] VOIX: que le patient -[00:49:35] VOIX: a été opéré -[00:49:36] VOIX: par colibistectomie -[00:49:37] VOIX: pour pancréatite. -[00:49:39] VOIX: Si je veux défendre -[00:49:40] VOIX: la pancréatite, -[00:49:42] VOIX: j'essaie de trouver -[00:49:44] VOIX: un peu -[00:49:45] VOIX: parce que -[00:49:47] VOIX: c'est important -[00:49:49] VOIX: que je puisse -[00:49:51] VOIX: te montrer -[00:49:52] VOIX: et en même temps -[00:49:54] VOIX: je vais partager -[00:49:55] VOIX: quand même mon écran -[00:49:56] VOIX: parce que -[00:49:56] VOIX: tu vas voir -[00:49:59] VOIX: et bien -[00:50:02] VOIX: je vais partager -[00:50:04] VOIX: parce que -[00:50:04] VOIX: je suis en train -[00:50:06] VOIX: de me connecter -[00:50:08] VOIX: à l'appli -[00:50:09] VOIX: et -[00:50:15] VOIX: on va voir -[00:50:18] VOIX: est-ce que ça va marcher ? -[00:50:20] VOIX: mon téléphone -[00:50:22] VOIX: mais sécurisé -[00:50:24] VOIX: ça fait plaisir -[00:50:34] VOIX: 742 -[00:50:36] VOIX: 743 -[00:50:38] VOIX: oui c'est -[00:50:39] VOIX: authenticateur -[00:50:42] VOIX: donc ça c'est -[00:50:44] VOIX: le contrôle -[00:50:44] VOIX: duquel on parle -[00:50:45] VOIX: et on va voir -[00:50:50] VOIX: suivre -[00:50:51] VOIX: ce serait plus simple -[00:50:52] VOIX: et soit -[00:50:54] VOIX: je mets -[00:50:54] VOIX: 1 -[00:50:55] VOIX: c'était -[00:50:56] VOIX: l'OGC1 -[00:50:57] VOIX: oui c'est ça -[00:50:57] VOIX: c'est ça -[00:51:14] VOIX: comment j'ai pile poil -[00:51:15] VOIX: en face -[00:51:16] VOIX: qu'il me faut -[00:51:16] VOIX: attends je prends -[00:51:17] VOIX: une photo -[00:51:17] VOIX: d'accord -[00:51:19] VOIX: oui mais il y a -[00:51:20] VOIX: une erreur -[00:51:21] VOIX: que je vois -[00:51:23] VOIX: et c'est -[00:51:24] VOIX: c'est quoi -[00:51:24] VOIX: l'erreur -[00:51:27] VOIX: en haut -[00:51:28] VOIX: c'est marqué -[00:51:30] VOIX: K802 -[00:51:30] VOIX: ouais -[00:51:32] VOIX: en bas -[00:51:33] VOIX: c'est marqué -[00:51:34] VOIX: K851 -[00:51:36] VOIX: donc du coup -[00:51:37] VOIX: moi je vais -[00:51:37] VOIX: faire une copie -[00:51:38] VOIX: d'écran -[00:51:41] VOIX: et je vais -[00:51:41] VOIX: l'envoyer -[00:51:42] VOIX: à Jordan -[00:51:50] VOIX: est-ce que -[00:51:50] VOIX: je sais -[00:51:51] VOIX: qu'on a eu -[00:51:51] VOIX: un problème -[00:51:52] VOIX: sur -[00:51:52] VOIX: donc du coup -[00:51:53] VOIX: je ne sais plus -[00:51:54] VOIX: quel est le codage -[00:51:55] VOIX: initial -[00:51:55] VOIX: de l'établissement -[00:52:02] VOIX: c'est ballon -[00:52:08] VOIX: tu vois -[00:52:08] VOIX: on a -[00:52:09] VOIX: ce type -[00:52:09] VOIX: de problème -[00:52:10] VOIX: ouais -[00:52:19] VOIX: donc -[00:52:23] VOIX: K802 -[00:52:25] VOIX: en haut -[00:52:26] VOIX: et K -[00:52:29] VOIX: 851 -[00:52:30] VOIX: en bas -[00:52:47] VOIX: et tous -[00:52:48] VOIX: nos argumentaires -[00:52:49] VOIX: en fait -[00:52:50] VOIX: on les fait -[00:52:51] VOIX: avec l'OGC -[00:52:52] VOIX: parce que c'est -[00:52:53] VOIX: ce qui est censé -[00:52:54] VOIX: être anonyme -[00:52:55] VOIX: voilà -[00:52:57] VOIX: donc du coup -[00:52:58] VOIX: si je reviens -[00:52:59] VOIX: à l'appli -[00:53:00] VOIX: la fiche -[00:53:01] VOIX: où j'essaye -[00:53:02] VOIX: c'est ça -[00:53:03] VOIX: et tu vois -[00:53:04] VOIX: quand je t'ai dit -[00:53:04] VOIX: tout à l'heure -[00:53:05] VOIX: la préparation -[00:53:10] VOIX: ça c'est -[00:53:15] VOIX: tu m'entends -[00:53:16] VOIX: oui -[00:53:16] VOIX: oui voilà -[00:53:18] VOIX: en fait -[00:53:19] VOIX: là ça c'est -[00:53:20] VOIX: la préparation -[00:53:21] VOIX: ça c'est le travail -[00:53:22] VOIX: du team -[00:53:23] VOIX: oui -[00:53:24] VOIX: d'accord -[00:53:25] VOIX: team -[00:53:25] VOIX: du team -[00:53:26] VOIX: c'est tout le monde -[00:53:27] VOIX: il passe -[00:53:28] VOIX: d'accord -[00:53:30] VOIX: ici sur ton programme -[00:53:31] VOIX: en fait -[00:53:31] VOIX: le contrôleur -[00:53:32] VOIX: en fait -[00:53:32] VOIX: c'est quand tu as -[00:53:34] VOIX: un contrôle -[00:53:35] VOIX: tes deux ans -[00:53:36] VOIX: ou c'est quelque chose -[00:53:37] VOIX: de plus -[00:53:37] VOIX: ça c'est le retour -[00:53:39] VOIX: ça c'est les fiches -[00:53:40] VOIX: du médecin contrôleur -[00:53:41] VOIX: c'est le retour -[00:53:43] VOIX: du médecin contrôleur -[00:53:44] VOIX: ouais -[00:53:44] VOIX: il te laisse un peu -[00:53:46] VOIX: de temps -[00:53:47] VOIX: pour revoir -[00:53:49] VOIX: c'est là -[00:53:49] VOIX: où l'argumentaire -[00:53:51] VOIX: de -[00:53:51] VOIX: parce que là -[00:53:53] VOIX: tu fais tes argumentaires -[00:53:54] VOIX: tu prépares ton podage -[00:53:57] VOIX: je vais te montrer -[00:53:58] VOIX: un autre dossier -[00:53:59] VOIX: le dossier -[00:54:00] VOIX: qu'on met -[00:54:00] VOIX: dans -[00:54:02] VOIX: dans la démo -[00:54:03] VOIX: c'est le -[00:54:03] VOIX: 399 -[00:54:04] VOIX: il me semble -[00:54:05] VOIX: parce qu'il est plus -[00:54:08] VOIX: par exemple -[00:54:08] VOIX: l'établissement -[00:54:09] VOIX: il a préparé -[00:54:11] VOIX: il a dit -[00:54:12] VOIX: ce code là -[00:54:13] VOIX: je ne vais pas -[00:54:15] VOIX: pouvoir le défendre -[00:54:16] VOIX: d'accord -[00:54:18] VOIX: et -[00:54:20] VOIX: le médecin contrôleur -[00:54:22] VOIX: quand il a fait -[00:54:23] VOIX: sa fiche -[00:54:24] VOIX: il a dit -[00:54:25] VOIX: non seulement -[00:54:26] VOIX: je ne retiens pas -[00:54:27] VOIX: ce code -[00:54:28] VOIX: mais en plus -[00:54:29] VOIX: je vous modifie -[00:54:30] VOIX: le DP -[00:54:31] VOIX: ok -[00:54:32] VOIX: tu vois -[00:54:33] VOIX: et donc là -[00:54:34] VOIX: nous -[00:54:35] VOIX: on avait préparé -[00:54:36] VOIX: à -[00:54:37] VOIX: à perdre -[00:54:38] VOIX: 1500 -[00:54:39] VOIX: lui -[00:54:40] VOIX: il nous a fait -[00:54:41] VOIX: un codage -[00:54:42] VOIX: qui nous fait perdre -[00:54:42] VOIX: presque 4000 euros -[00:54:44] VOIX: quand j'ai vu -[00:54:46] VOIX: son truc -[00:54:47] VOIX: parce que -[00:54:48] VOIX: dans la prépa -[00:54:49] VOIX: on ne s'est pas -[00:54:50] VOIX: attendu -[00:54:50] VOIX: à ce qu'il nous -[00:54:51] VOIX: change ça -[00:54:53] VOIX: donc on n'a pas -[00:54:54] VOIX: argumenté -[00:54:55] VOIX: sur ce changement -[00:54:56] VOIX: de DP -[00:54:58] VOIX: on a argumenté -[00:55:00] VOIX: sur ça -[00:55:00] VOIX: peut-être -[00:55:01] VOIX: on l'avait préparé -[00:55:02] VOIX: mais donc -[00:55:02] VOIX: je suis obligé -[00:55:03] VOIX: de avant -[00:55:04] VOIX: la concertation -[00:55:05] VOIX: la concertation -[00:55:06] VOIX: c'est la clôture -[00:55:08] VOIX: finale -[00:55:09] VOIX: du dossier -[00:55:10] VOIX: avant la clôture -[00:55:11] VOIX: finale -[00:55:12] VOIX: du dossier -[00:55:12] VOIX: il -[00:55:15] VOIX: c'est comme -[00:55:16] VOIX: si là -[00:55:16] VOIX: moi -[00:55:16] VOIX: le dim -[00:55:17] VOIX: je fais -[00:55:17] VOIX: une contre-proposition -[00:55:19] VOIX: je ne me rappelle -[00:55:20] VOIX: plus -[00:55:20] VOIX: de l'argument -[00:55:21] VOIX: mais tu vois -[00:55:23] VOIX: par exemple -[00:55:23] VOIX: au début -[00:55:25] VOIX: ça c'est un bug -[00:55:28] VOIX: qu'on a -[00:55:29] VOIX: tu vois -[00:55:29] VOIX: tous les arguments -[00:55:31] VOIX: qu'on écrit -[00:55:31] VOIX: le médecin -[00:55:32] VOIX: il a fait ça -[00:55:33] VOIX: il a écrit ça -[00:55:34] VOIX: bla bla bla -[00:55:39] VOIX: ah j'ai un bug -[00:55:41] VOIX: j'ai un bug -[00:55:42] VOIX: dans l'appli -[00:55:43] VOIX: depuis hier -[00:55:45] VOIX: ça je l'ai dit -[00:55:46] VOIX: à Jorban -[00:55:46] VOIX: je ne pourrais pas -[00:55:47] VOIX: te montrer correctement -[00:55:48] VOIX: c'est pas grave -[00:55:49] VOIX: ça s'affiche -[00:55:50] VOIX: tu vois -[00:55:51] VOIX: en fait -[00:55:53] VOIX: nous -[00:55:53] VOIX: on avait -[00:55:54] VOIX: on avait -[00:55:55] VOIX: enlevé -[00:55:56] VOIX: le diagnostic associé -[00:55:57] VOIX: mais on n'avait -[00:55:58] VOIX: pas préparé ça -[00:55:59] VOIX: ça -[00:56:00] VOIX: je l'ai fait -[00:56:00] VOIX: a posteriori -[00:56:01] VOIX: pour dire -[00:56:02] VOIX: parce que je ne savais pas -[00:56:04] VOIX: qu'il allait modifier -[00:56:04] VOIX: le DP -[00:56:06] VOIX: mais si mon argumentaire -[00:56:08] VOIX: au début -[00:56:08] VOIX: initial -[00:56:09] VOIX: il est fait par l'IA -[00:56:11] VOIX: il évoque aussi -[00:56:13] VOIX: ce risque -[00:56:13] VOIX: il pourra me préparer -[00:56:15] VOIX: à dire -[00:56:16] VOIX: attention -[00:56:16] VOIX: le DP peut être -[00:56:18] VOIX: modifié -[00:56:19] VOIX: pourquoi -[00:56:19] VOIX: etc -[00:56:20] VOIX: voici votre argumentaire -[00:56:22] VOIX: pour le défendre -[00:56:23] VOIX: par exemple -[00:56:24] VOIX: tu vois -[00:56:25] VOIX: et lui -[00:56:26] VOIX: le contrôleur -[00:56:27] VOIX: il m'a fait sauter -[00:56:30] VOIX: il fait son argumentaire -[00:56:32] VOIX: pour dire -[00:56:32] VOIX: non mais moi -[00:56:33] VOIX: je garde -[00:56:34] VOIX: la règle -[00:56:34] VOIX: SA -[00:56:35] VOIX: le DP -[00:56:36] VOIX: c'est ça -[00:56:37] VOIX: ça c'est un cas -[00:56:38] VOIX: un petit peu plus -[00:56:39] VOIX: complexe -[00:56:39] VOIX: que le cas -[00:56:40] VOIX: de la chirurgie -[00:56:41] VOIX: donc -[00:56:42] VOIX: mais ça -[00:56:42] VOIX: on le travaillera aussi -[00:56:43] VOIX: parce que -[00:56:44] VOIX: c'est des cas -[00:56:45] VOIX: qui existent en vrai -[00:56:47] VOIX: c'est un jeu -[00:56:49] VOIX: avec ça -[00:56:49] VOIX: c'est à dire que -[00:56:51] VOIX: lui il peut dire -[00:56:52] VOIX: j'utilise -[00:56:53] VOIX: la règle -[00:56:53] VOIX: T3 -[00:56:54] VOIX: et moi -[00:56:55] VOIX: je lui dis -[00:56:55] VOIX: ben non -[00:56:56] VOIX: j'utilise la règle -[00:56:57] VOIX: SA -[00:56:58] VOIX: oui oui -[00:56:59] VOIX: j'ai saisi -[00:57:01] VOIX: tu vois -[00:57:01] VOIX: c'est ça -[00:57:02] VOIX: les contre-argumentaires -[00:57:04] VOIX: c'est à dire que -[00:57:04] VOIX: parfois nous -[00:57:05] VOIX: dans les formations -[00:57:06] VOIX: on fait des synthèses -[00:57:07] VOIX: de ces règles -[00:57:08] VOIX: tu vois -[00:57:09] VOIX: t'as des exemples -[00:57:10] VOIX: t'as etc -[00:57:11] VOIX: et c'est vrai -[00:57:12] VOIX: qu'à la lecture -[00:57:13] VOIX: c'est pas l'impi -[00:57:14] VOIX: t'as des dossiers -[00:57:16] VOIX: qui sont clairs -[00:57:17] VOIX: mais typiquement -[00:57:18] VOIX: celui-là -[00:57:20] VOIX: il me dit -[00:57:21] VOIX: le patient -[00:57:21] VOIX: il vient pour -[00:57:22] VOIX: semaine d'éducation -[00:57:25] VOIX: anti-diabétique -[00:57:26] VOIX: au reau -[00:57:27] VOIX: il est traité -[00:57:27] VOIX: machin -[00:57:28] VOIX: bon on va pas -[00:57:28] VOIX: tout lire -[00:57:29] VOIX: mais il me dit -[00:57:30] VOIX: pas de modification -[00:57:31] VOIX: erreur du DP -[00:57:32] VOIX: tu vois -[00:57:33] VOIX: pas de modification -[00:57:34] VOIX: du schéma thérapeutique -[00:57:36] VOIX: non justifié -[00:57:37] VOIX: il s'agit d'un cas -[00:57:38] VOIX: de surveillance négative -[00:57:39] VOIX: le séjour -[00:57:40] VOIX: on n'a pas mis -[00:57:41] VOIX: en évidence -[00:57:42] VOIX: l'affection nouvelle -[00:57:43] VOIX: donc je code -[00:57:44] VOIX: le Z92 -[00:57:45] VOIX: et t'as vu -[00:57:46] VOIX: le Z092 -[00:57:49] VOIX: il me dévalorise -[00:57:50] VOIX: de 4000 euros -[00:57:51] VOIX: et moi -[00:57:52] VOIX: j'ai essayé -[00:57:53] VOIX: à défaut -[00:57:54] VOIX: de pouvoir -[00:57:55] VOIX: garder -[00:57:55] VOIX: mon premier DP -[00:57:57] VOIX: de proposer -[00:57:58] VOIX: un autre DP -[00:57:59] VOIX: qui casse -[00:58:00] VOIX: moins le dossier -[00:58:01] VOIX: tu vois -[00:58:02] VOIX: au moins -[00:58:02] VOIX: ça casse moi -[00:58:03] VOIX: au lieu de 4000 -[00:58:04] VOIX: bah t'es à moins -[00:58:05] VOIX: 1600 -[00:58:05] VOIX: et j'essaie -[00:58:07] VOIX: essentiellement -[00:58:08] VOIX: autour de l'éducation -[00:58:09] VOIX: la réponse -[00:58:10] VOIX: de docteur -[00:58:10] VOIX: machin -[00:58:11] VOIX: sur Agora -[00:58:12] VOIX: c'est le forum -[00:58:12] VOIX: de la TIH -[00:58:13] VOIX: c'était marqué -[00:58:15] VOIX: qu'on vient bien -[00:58:15] VOIX: en effet à l'éducation -[00:58:16] VOIX: thérapeutique -[00:58:17] VOIX: s'il en tient compte -[00:58:18] VOIX: de la notion -[00:58:19] VOIX: de conseil diététique -[00:58:20] VOIX: de son intitulé -[00:58:21] VOIX: et je lui casse -[00:58:22] VOIX: son code -[00:58:22] VOIX: en disant -[00:58:23] VOIX: le Z092 -[00:58:24] VOIX: ou situation -[00:58:25] VOIX: de surveillance négative -[00:58:27] VOIX: en lui incluant -[00:58:28] VOIX: les examens médicaux -[00:58:30] VOIX: de recherche -[00:58:30] VOIX: notamment -[00:58:31] VOIX: les complications -[00:58:31] VOIX: machin -[00:58:32] VOIX: donc -[00:58:33] VOIX: l'éducation -[00:58:34] VOIX: thérapeutique -[00:58:35] VOIX: est bien été -[00:58:36] VOIX: dans le dossier -[00:58:36] VOIX: patient -[00:58:37] VOIX: avec les différents -[00:58:37] VOIX: intervenants -[00:58:39] VOIX: voilà -[00:58:40] VOIX: donc je défends -[00:58:41] VOIX: j'essaie de -[00:58:42] VOIX: de défendre -[00:58:43] VOIX: un nouveau code -[00:58:44] VOIX: pour moi -[00:58:45] VOIX: parce que je n'ai pas -[00:58:47] VOIX: d'argument -[00:58:47] VOIX: pour maintenir -[00:58:48] VOIX: le codage initial -[00:58:50] VOIX: donc ça n'a pas été -[00:58:52] VOIX: en tout cas nous -[00:58:53] VOIX: on n'a pas repéré -[00:58:54] VOIX: lors de la préparation -[00:58:56] VOIX: mais comme le médecin contrôleur -[00:58:58] VOIX: propose quelque chose -[00:58:59] VOIX: qui cache beaucoup -[00:59:00] VOIX: le dossier -[00:59:00] VOIX: nous on va s'adapter -[00:59:02] VOIX: dans le nouveau argumentaire -[00:59:04] VOIX: à la décision du contrôleur -[00:59:06] VOIX: et la concertation -[00:59:08] VOIX: c'est la décision finale -[00:59:09] VOIX: est-ce que j'arrive -[00:59:10] VOIX: à le convaincre -[00:59:11] VOIX: ou est-ce qu'il garde -[00:59:12] VOIX: son avis initial -[00:59:13] VOIX: d'accord -[00:59:14] VOIX: bon j'ai parfaitement -[00:59:15] VOIX: en fait -[00:59:15] VOIX: saisi -[00:59:16] VOIX: ok -[00:59:16] VOIX: voilà -[00:59:18] VOIX: et donc -[00:59:19] VOIX: pour le dossier 1 -[00:59:24] VOIX: je ne sais pas -[00:59:25] VOIX: comment -[00:59:26] VOIX: je vais savoir -[00:59:28] VOIX: comment il était -[00:59:30] VOIX: il était codé -[00:59:32] VOIX: en fait -[00:59:34] VOIX: mais après -[00:59:35] VOIX: peu importe -[00:59:36] VOIX: tu vois -[00:59:37] VOIX: c'est la logique -[00:59:37] VOIX: qui compte -[00:59:38] VOIX: ouais -[00:59:39] VOIX: ouais -[00:59:44] VOIX: donc -[00:59:45] VOIX: justificatif -[00:59:47] VOIX: retour -[00:59:48] VOIX: je t'en as -[00:59:52] VOIX: donc -[00:59:53] VOIX: est-ce que -[00:59:54] VOIX: je vais -[00:59:55] VOIX: donc -[00:59:59] VOIX: tu vois -[00:59:59] VOIX: je n'ai pas de fiche -[01:00:01] VOIX: pour le 8 -[01:00:01] VOIX: donc -[01:00:02] VOIX: le médecin contrôleur -[01:00:03] VOIX: il a validé -[01:00:04] VOIX: le codage initial -[01:00:05] VOIX: de l'établissement -[01:00:06] VOIX: mais par exemple -[01:00:07] VOIX: pour le dossier 8 -[01:00:11] VOIX: nous on avait codé -[01:00:13] VOIX: on pourra regarder -[01:00:14] VOIX: le dossier 8 -[01:00:16] VOIX: on a codé -[01:00:17] VOIX: calcul de la vésicule -[01:00:18] VOIX: biliaire -[01:00:18] VOIX: et lui -[01:00:18] VOIX: il dit -[01:00:19] VOIX: non -[01:00:19] VOIX: c'est pas le K800 -[01:00:20] VOIX: c'est le K -[01:00:23] VOIX: le K805 -[01:00:25] VOIX: et tu vois -[01:00:26] VOIX: il n'a pas changé -[01:00:27] VOIX: les actes -[01:00:28] VOIX: mais le groupage -[01:00:29] VOIX: tu vois -[01:00:29] VOIX: ça passe -[01:00:30] VOIX: de 07713 -[01:00:31] VOIX: A 07714 -[01:00:33] VOIX: et si on regarde -[01:00:35] VOIX: l'appli -[01:00:35] VOIX: le dossier 8 -[01:00:40] VOIX: on va voir -[01:00:42] VOIX: nous -[01:00:44] VOIX: on avait déjà repéré -[01:00:45] VOIX: dans la prépa -[01:00:46] VOIX: que le patient -[01:00:48] VOIX: que le codage -[01:00:49] VOIX: n'était pas bon -[01:00:50] VOIX: parce qu'il n'y avait pas -[01:00:51] VOIX: de cholycystite aiguë -[01:00:53] VOIX: donc on avait proposé -[01:00:54] VOIX: un code -[01:00:55] VOIX: qui est le 801 -[01:00:56] VOIX: différent de 800 -[01:00:57] VOIX: mais qui faisait perdre -[01:00:58] VOIX: déjà 800 euros -[01:01:00] VOIX: le contrôleur -[01:01:01] VOIX: il a dit -[01:01:01] VOIX: le 805 -[01:01:02] VOIX: mais c'est pareil -[01:01:04] VOIX: sa groupe -[01:01:05] VOIX: pareil -[01:01:06] VOIX: il dit -[01:01:07] VOIX: bon moi -[01:01:07] VOIX: je ne sais pas -[01:01:07] VOIX: pourquoi -[01:01:08] VOIX: j'ai essayé de changer -[01:01:09] VOIX: ça ne change rien du tout -[01:01:10] VOIX: et que finalement -[01:01:11] VOIX: on était tous d'accord -[01:01:12] VOIX: à mon avis -[01:01:13] VOIX: lui -[01:01:13] VOIX: il est en accord -[01:01:14] VOIX: parce que -[01:01:15] VOIX: on ne peut rien tirer -[01:01:16] VOIX: nous -[01:01:17] VOIX: d'emblée -[01:01:18] VOIX: dans la prépa -[01:01:19] VOIX: on a dit -[01:01:20] VOIX: il n'y a pas de choix -[01:01:21] VOIX: par contre -[01:01:22] VOIX: on va le lire ensemble -[01:01:23] VOIX: le dossier 8 -[01:01:24] VOIX: tu vois -[01:01:25] VOIX: je vais sortir -[01:01:26] VOIX: pour que tu vois -[01:01:28] VOIX: parce qu'il ressemble -[01:01:28] VOIX: un peu -[01:01:29] VOIX: à ce dossier là -[01:01:31] VOIX: et on va voir -[01:01:35] VOIX: ça va ? -[01:01:36] VOIX: alors ça fait -[01:01:37] VOIX: ça fait beaucoup d'informations -[01:01:40] VOIX: mais bon -[01:01:41] VOIX: tu as répondu -[01:01:42] VOIX: à une partie de mes questions -[01:01:43] VOIX: il va falloir -[01:01:43] VOIX: que je t'en pose -[01:01:46] VOIX: quelques-unes -[01:01:46] VOIX: mais pour moi -[01:01:48] VOIX: en fait -[01:01:48] VOIX: c'est bon -[01:01:49] VOIX: en fait -[01:01:49] VOIX: ce qu'il me faut -[01:01:50] VOIX: ce qu'il me faudrait -[01:01:51] VOIX: en fait -[01:01:51] VOIX: pour faire un truc carré -[01:01:53] VOIX: pour moi -[01:01:54] VOIX: je parle en fait -[01:01:54] VOIX: pour que je puisse avoir -[01:01:55] VOIX: ma vision -[01:01:56] VOIX: il me faudrait en fait -[01:01:57] VOIX: discuter -[01:01:58] VOIX: en fait -[01:01:58] VOIX: on prend -[01:01:59] VOIX: un ou deux -[01:02:00] VOIX: je t'écoute -[01:02:01] VOIX: je prends juste mon ordi -[01:02:02] VOIX: ouais -[01:02:03] VOIX: c'est de prendre -[01:02:03] VOIX: c'est de prendre -[01:02:04] VOIX: un ou deux -[01:02:06] VOIX: compte rendu -[01:02:08] VOIX: et de faire -[01:02:09] VOIX: toute la chaîne -[01:02:10] VOIX: en fait -[01:02:11] VOIX: c'est -[01:02:11] VOIX: voilà -[01:02:12] VOIX: on l'a codé comme ça -[01:02:13] VOIX: et puis après -[01:02:14] VOIX: en fait -[01:02:15] VOIX: on a un contrôle -[01:02:16] VOIX: le contrôleur -[01:02:17] VOIX: il dit ça -[01:02:18] VOIX: ou ça -[01:02:18] VOIX: et tu vois -[01:02:19] VOIX: en fait -[01:02:19] VOIX: faire toute la chaîne -[01:02:20] VOIX: sur un ou deux dossiers -[01:02:22] VOIX: de telle manière -[01:02:22] VOIX: avoir quelque chose -[01:02:23] VOIX: en fait -[01:02:23] VOIX: de probant -[01:02:25] VOIX: voilà -[01:02:26] VOIX: alors regarde -[01:02:27] VOIX: le dossier -[01:02:28] VOIX: pas le 8 -[01:02:29] VOIX: parce qu'on était d'accord -[01:02:30] VOIX: de A à Z -[01:02:33] VOIX: on va -[01:02:33] VOIX: parce que j'ai mis -[01:02:35] VOIX: 399 -[01:02:35] VOIX: mais en fait -[01:02:36] VOIX: c'est le dossier -[01:02:38] VOIX: 339 -[01:02:38] VOIX: alors je vais pas -[01:02:40] VOIX: l'avoir -[01:02:40] VOIX: figure toi -[01:02:40] VOIX: parce qu'en fait -[01:02:42] VOIX: je travaille que sur -[01:02:43] VOIX: 250 dossiers -[01:02:45] VOIX: et à mon avis -[01:02:45] VOIX: le 300 -[01:02:47] VOIX: 39 -[01:02:48] VOIX: OGC -[01:02:49] VOIX: 339 -[01:02:50] VOIX: non -[01:02:50] VOIX: je l'ai pas -[01:02:53] VOIX: je l'ai pas -[01:02:54] VOIX: à la limite -[01:02:55] VOIX: en fait -[01:02:55] VOIX: si tu l'as sous la main -[01:02:56] VOIX: envoie-moi-le -[01:02:56] VOIX: comme ça -[01:02:57] VOIX: en fait -[01:02:57] VOIX: je l'aurai -[01:02:57] VOIX: ou alors -[01:02:58] VOIX: je vais le chercher -[01:02:59] VOIX: parce que -[01:02:59] VOIX: j'ai été obligé -[01:03:00] VOIX: d'arrêter -[01:03:00] VOIX: mon navigateur -[01:03:01] VOIX: et je peux pas -[01:03:02] VOIX: aller le chercher -[01:03:03] VOIX: de suite -[01:03:04] VOIX: quoi que -[01:03:05] VOIX: allez -[01:03:05] VOIX: 50 -[01:03:06] VOIX: en anglais -[01:03:12] VOIX: où est-ce que -[01:03:13] VOIX: je t'avais déposé -[01:03:14] VOIX: tout ça -[01:03:15] VOIX: dans -[01:03:17] VOIX: chère fille -[01:03:17] VOIX: attends -[01:03:18] VOIX: je vais peut-être -[01:03:19] VOIX: je vais le trouver -[01:03:20] VOIX: c'est bon -[01:03:24] VOIX: ok -[01:03:28] VOIX: et moi -[01:03:29] VOIX: j'ai le codage -[01:03:30] VOIX: dans l'appli -[01:03:30] VOIX: les argumentaires -[01:03:31] VOIX: dans l'appli -[01:03:32] VOIX: donc on le lit -[01:03:33] VOIX: ensemble -[01:03:34] VOIX: et on -[01:03:37] VOIX: regarde l'appli -[01:03:38] VOIX: après -[01:03:39] VOIX: ensemble -[01:03:39] VOIX: allez -[01:03:40] VOIX: le partage -[01:03:42] VOIX: document post -[01:03:44] VOIX: contrôle -[01:03:45] VOIX: je pense que -[01:03:46] VOIX: c'était celui -[01:03:46] VOIX: non -[01:03:47] VOIX: attends -[01:03:47] VOIX: ça c'est -[01:03:47] VOIX: guillière -[01:03:49] VOIX: hop là -[01:03:53] VOIX: voilà -[01:03:53] VOIX: je l'ai -[01:03:54] VOIX: voilà -[01:03:55] VOIX: et tu m'as dit -[01:03:56] VOIX: alors attends -[01:03:56] VOIX: parce que -[01:03:57] VOIX: c'est -[01:03:58] VOIX: j'avais -[01:03:58] VOIX: trouvé que -[01:03:59] VOIX: 250 dossiers -[01:04:01] VOIX: et en fait -[01:04:01] VOIX: il faut cliquer -[01:04:02] VOIX: 10 fois -[01:04:03] VOIX: pour voir -[01:04:03] VOIX: en fait -[01:04:03] VOIX: alors tu m'as dit -[01:04:04] VOIX: 339 -[01:04:06] VOIX: oui -[01:04:09] VOIX: 339 -[01:04:12] VOIX: voilà -[01:04:13] VOIX: c'est -[01:04:13] VOIX: 339 -[01:04:14] VOIX: 23 07 -[01:04:16] VOIX: 2740 -[01:04:17] VOIX: c'est ça -[01:04:18] VOIX: je vérifie -[01:04:19] VOIX: oui -[01:04:22] VOIX: 2740 -[01:04:22] VOIX: oui -[01:04:23] VOIX: allez -[01:04:23] VOIX: je le télécharge -[01:04:25] VOIX: on est sur -[01:04:25] VOIX: le bon dossier -[01:04:27] VOIX: hop là -[01:04:28] VOIX: alors -[01:04:28] VOIX: ah oui -[01:04:29] VOIX: mais ça c'est -[01:04:29] VOIX: le tracker -[01:04:31] VOIX: il me faut tout -[01:04:34] VOIX: mais t'auras -[01:04:35] VOIX: que le tracker -[01:04:36] VOIX: mais on va voir -[01:04:36] VOIX: que dans le tracker -[01:04:37] VOIX: il peut y avoir -[01:04:38] VOIX: beaucoup de choses -[01:04:39] VOIX: d'accord -[01:04:39] VOIX: peut-être que -[01:04:40] VOIX: t'as qu'un tracker -[01:04:41] VOIX: mais le tracker -[01:04:42] VOIX: c'est -[01:04:43] VOIX: il concatène -[01:04:45] VOIX: un peu -[01:04:45] VOIX: les documents -[01:04:46] VOIX: qui sort -[01:04:46] VOIX: ensemble -[01:04:48] VOIX: hop là -[01:04:49] VOIX: donc j'ai le document -[01:04:51] VOIX: je te partage -[01:04:52] VOIX: mon écran -[01:04:52] VOIX: oui -[01:04:53] VOIX: en premier -[01:04:54] VOIX: oui -[01:04:56] VOIX: voilà -[01:04:57] VOIX: tu te vois -[01:05:00] VOIX: oui -[01:05:02] VOIX: allez -[01:05:03] VOIX: voilà -[01:05:04] VOIX: donc on va tout en bout -[01:05:08] VOIX: exactement -[01:05:09] VOIX: donc on refait -[01:05:10] VOIX: la même -[01:05:12] VOIX: mais j'ai l'impression -[01:05:13] VOIX: que c'est le même -[01:05:14] VOIX: non c'est le même -[01:05:14] VOIX: celui-là c'est le même -[01:05:15] VOIX: alors attends -[01:05:16] VOIX: ouais non c'est pas le bon -[01:05:18] VOIX: je vais le fermer -[01:05:20] VOIX: je vais fermer ça -[01:05:22] VOIX: et je vais recommencer -[01:05:24] VOIX: à ouvrir en fait -[01:05:25] VOIX: le document -[01:05:26] VOIX: que je viens de télécharger -[01:05:29] VOIX: voilà c'est ça -[01:05:31] VOIX: il était ouvert -[01:05:32] VOIX: mais dans un autre logiciel -[01:05:33] VOIX: oui -[01:05:34] VOIX: augmente -[01:05:34] VOIX: augmente un peu -[01:05:35] VOIX: s'il te plaît -[01:05:39] VOIX: ça te va comme ça -[01:05:40] VOIX: voilà -[01:05:40] VOIX: en fait -[01:05:40] VOIX: tracker -[01:05:42] VOIX: d'abord -[01:05:43] VOIX: c'est -[01:05:45] VOIX: la solution -[01:05:46] VOIX: de tracker -[01:05:47] VOIX: de sortir les documents -[01:05:48] VOIX: il faut vraiment pas -[01:05:49] VOIX: que tu penses -[01:05:50] VOIX: que tous les documents -[01:05:51] VOIX: on va les avoir comme ça -[01:05:52] VOIX: ça sera complètement différent -[01:05:54] VOIX: ça il faut l'avoir en tête -[01:05:56] VOIX: donc déjà -[01:05:58] VOIX: tu regardes toujours -[01:06:01] VOIX: les noms -[01:06:02] VOIX: tu vois le dossier -[01:06:05] VOIX: les dates -[01:06:06] VOIX: là je vois que la date d'admission -[01:06:08] VOIX: c'est le 13 avril -[01:06:10] VOIX: date de sortie 21 avril -[01:06:11] VOIX: donc tout ça -[01:06:13] VOIX: parce que je vais le croiser -[01:06:14] VOIX: aussi bien -[01:06:15] VOIX: avec le dossier administratif -[01:06:18] VOIX: avec le dossier du codage -[01:06:19] VOIX: etc -[01:06:21] VOIX: si tu descends -[01:06:22] VOIX: j'ai pas la main -[01:06:24] VOIX: donc c'est -[01:06:26] VOIX: tu vois -[01:06:27] VOIX: il vient -[01:06:28] VOIX: le 13 -[01:06:31] VOIX: il passe aux urgences -[01:06:32] VOIX: très bien -[01:06:32] VOIX: bien vu -[01:06:33] VOIX: tu vois -[01:06:33] VOIX: là tu vois que -[01:06:35] VOIX: t'as la trace -[01:06:36] VOIX: du passage aux urgences -[01:06:37] VOIX: par véhicule personnel -[01:06:40] VOIX: les horaires -[01:06:42] VOIX: circulant -[01:06:43] VOIX: douleur à dos -[01:06:44] VOIX: tu vois -[01:06:44] VOIX: ça c'est quoi -[01:06:45] VOIX: priorité 4 -[01:06:46] VOIX: c'est une classification -[01:06:47] VOIX: des urgences -[01:06:48] VOIX: nous on l'utilise pas -[01:06:50] VOIX: mais après -[01:06:51] VOIX: tu sais -[01:06:53] VOIX: les différentes échelles -[01:06:56] VOIX: etc -[01:06:57] VOIX: on pourra -[01:06:58] VOIX: peut-être -[01:06:59] VOIX: petit à petit -[01:07:00] VOIX: les intégrer -[01:07:00] VOIX: mais c'est pas -[01:07:01] VOIX: quelque chose -[01:07:01] VOIX: que nous -[01:07:03] VOIX: priorité 4 -[01:07:04] VOIX: peut-être eux -[01:07:04] VOIX: tu sais -[01:07:05] VOIX: ils classent -[01:07:06] VOIX: est-ce que -[01:07:07] VOIX: d'accord -[01:07:07] VOIX: ça c'est dans -[01:07:08] VOIX: est-ce que -[01:07:10] VOIX: je le fais -[01:07:10] VOIX: rentrer en premier -[01:07:11] VOIX: ou est-ce qu'il peut -[01:07:12] VOIX: attendre un peu -[01:07:13] VOIX: ok -[01:07:13] VOIX: je connais pas -[01:07:15] VOIX: est-ce que c'est une -[01:07:17] VOIX: classification nationale -[01:07:18] VOIX: ou juste interne -[01:07:20] VOIX: ça je peux pas te dire -[01:07:20] VOIX: d'accord -[01:07:21] VOIX: ça c'est un truc -[01:07:21] VOIX: oui c'est un truc -[01:07:23] VOIX: interne -[01:07:23] VOIX: bon ok -[01:07:24] VOIX: ok -[01:07:24] VOIX: et après en fait -[01:07:25] VOIX: on va partir du motif -[01:07:27] VOIX: de prise en charge -[01:07:28] VOIX: et là on sait que -[01:07:29] VOIX: c'est le motif -[01:07:30] VOIX: au niveau des urgences -[01:07:31] VOIX: d'accord -[01:07:32] VOIX: on voit douleur abdo -[01:07:34] VOIX: on voit les observations -[01:07:36] VOIX: de l'IDE -[01:07:36] VOIX: des urgences -[01:07:37] VOIX: il dit -[01:07:38] VOIX: abrécé par médecin -[01:07:39] VOIX: pour douleur abdo -[01:07:40] VOIX: avec diarrhée persistante -[01:07:41] VOIX: personne suivie -[01:07:43] VOIX: par docteur machin -[01:07:44] VOIX: pour anémie -[01:07:44] VOIX: fériprive -[01:07:45] VOIX: dans un contexte -[01:07:47] VOIX: de maladie -[01:07:48] VOIX: de bière mère -[01:07:49] VOIX: donc là -[01:07:50] VOIX: c'est un -[01:07:53] VOIX: on va dire -[01:07:54] VOIX: c'est un antécédent -[01:07:55] VOIX: on le note -[01:07:57] VOIX: dans un coin -[01:07:58] VOIX: du cerveau -[01:07:59] VOIX: pour vérifier -[01:08:00] VOIX: si pendant -[01:08:01] VOIX: le séjour -[01:08:03] VOIX: il y a une prise en charge -[01:08:05] VOIX: diagnostique -[01:08:06] VOIX: thérapeutique -[01:08:06] VOIX: ou surveillance -[01:08:08] VOIX: de sa maladie -[01:08:09] VOIX: de bière mère -[01:08:09] VOIX: et de son anémie -[01:08:10] VOIX: fériprive -[01:08:11] VOIX: si elle prend -[01:08:12] VOIX: un traitement -[01:08:13] VOIX: si ça a été surveillé -[01:08:14] VOIX: si ça a été évoqué -[01:08:16] VOIX: on va le coder -[01:08:17] VOIX: en diagnostic associé -[01:08:18] VOIX: mais si personne -[01:08:20] VOIX: n'en parle -[01:08:20] VOIX: ou ne s'en préoccupe -[01:08:22] VOIX: ça va être difficile -[01:08:23] VOIX: de le coder -[01:08:24] VOIX: parce qu'on va considérer -[01:08:25] VOIX: que c'est juste -[01:08:26] VOIX: un antécédent -[01:08:28] VOIX: qui n'a pas été -[01:08:29] VOIX: prise en charge -[01:08:29] VOIX: pendant -[01:08:30] VOIX: l'hospite -[01:08:31] VOIX: il n'y a pas -[01:08:32] VOIX: d'histoire de maladie -[01:08:33] VOIX: là -[01:08:33] VOIX: remonte -[01:08:33] VOIX: remonte -[01:08:34] VOIX: t'es allé un peu vite -[01:08:34] VOIX: non non mais je sais -[01:08:35] VOIX: pourquoi j'ai fait ça -[01:08:36] VOIX: mais vas-y -[01:08:37] VOIX: oui -[01:08:39] VOIX: donc tu vois -[01:08:40] VOIX: là -[01:08:40] VOIX: les urgences -[01:08:41] VOIX: machin -[01:08:42] VOIX: il l'hospitalise -[01:08:43] VOIX: en médecine -[01:08:45] VOIX: donc on voit -[01:08:46] VOIX: qu'elle est remontée -[01:08:47] VOIX: l'avant-dernière ligne -[01:08:48] VOIX: les deux dernières lignes -[01:08:49] VOIX: tu vois -[01:08:50] VOIX: qu'elle est remontée -[01:08:50] VOIX: en gastro-entérologie -[01:08:52] VOIX: là en tout cas -[01:08:53] VOIX: ce document -[01:08:54] VOIX: des urgences -[01:08:55] VOIX: il trace -[01:08:55] VOIX: un petit peu -[01:08:56] VOIX: ce qui s'est passé -[01:08:57] VOIX: donc c'est plutôt -[01:08:58] VOIX: pas mal -[01:08:59] VOIX: c'est pas le cas partout -[01:09:00] VOIX: dans tous les établissements -[01:09:02] VOIX: et avec tous les -[01:09:03] VOIX: des pays -[01:09:03] VOIX: encore une fois -[01:09:05] VOIX: c'est pour ça que -[01:09:06] VOIX: c'est des adaptations -[01:09:08] VOIX: t'as des informations -[01:09:10] VOIX: aussi en dehors du tableau -[01:09:12] VOIX: dans taille -[01:09:13] VOIX: poids -[01:09:13] VOIX: tu vois -[01:09:18] VOIX: il faut les garder -[01:09:21] VOIX: dans un coin -[01:09:22] VOIX: ça fait partie des infos -[01:09:24] VOIX: et -[01:09:25] VOIX: on descend -[01:09:26] VOIX: on voit -[01:09:27] VOIX: ce qui s'est passé -[01:09:28] VOIX: après -[01:09:30] VOIX: qu'est-ce qu'il y a -[01:09:30] VOIX: ah oui -[01:09:31] VOIX: là -[01:09:31] VOIX: tu as -[01:09:32] VOIX: ce qu'on appelle -[01:09:33] VOIX: les constantes -[01:09:35] VOIX: la température -[01:09:36] VOIX: le pouls -[01:09:37] VOIX: la fréquence cardiaque -[01:09:38] VOIX: la pression arternelle -[01:09:39] VOIX: etc -[01:09:40] VOIX: jour par jour -[01:09:41] VOIX: et heure par heure -[01:09:42] VOIX: en tout cas -[01:09:43] VOIX: quand ils l'ont prise -[01:09:44] VOIX: et ça parfois -[01:09:45] VOIX: on l'exploite -[01:09:48] VOIX: pour un certain codage -[01:09:51] VOIX: là typiquement -[01:09:53] VOIX: moi d'emblée -[01:09:54] VOIX: à la -[01:09:54] VOIX: 1, 2, 3, 4, 5, 6 -[01:09:57] VOIX: la 6ème colonne -[01:09:58] VOIX: c'est-à-dire -[01:09:59] VOIX: le 20 avril -[01:10:00] VOIX: à 15h51 -[01:10:01] VOIX: je repère -[01:10:03] VOIX: que la patiente -[01:10:05] VOIX: elle avait -[01:10:05] VOIX: 38 de fièvre -[01:10:07] VOIX: et 95 de -[01:10:09] VOIX: battements -[01:10:10] VOIX: par l'équipe -[01:10:11] VOIX: ce qui fait beaucoup -[01:10:12] VOIX: et le 20 avril -[01:10:14] VOIX: à 8h -[01:10:15] VOIX: c'est-à-dire -[01:10:15] VOIX: 3 colonnes après -[01:10:16] VOIX: c'est pareil -[01:10:18] VOIX: il y a 38 -[01:10:19] VOIX: et 93 -[01:10:20] VOIX: et on a -[01:10:21] VOIX: une règle -[01:10:22] VOIX: du codage -[01:10:22] VOIX: qui dit -[01:10:24] VOIX: qu'on peut -[01:10:24] VOIX: coder -[01:10:25] VOIX: un niveau 2 -[01:10:26] VOIX: un R65 -[01:10:27] VOIX: en chante fou -[01:10:28] VOIX: quand on a -[01:10:29] VOIX: la fièvre -[01:10:30] VOIX: supérieure à 38 -[01:10:31] VOIX: ou la fréquence cardiaque -[01:10:32] VOIX: supérieure à 90 -[01:10:33] VOIX: donc quand on est -[01:10:35] VOIX: à court de diagnostic -[01:10:36] VOIX: ça on l'analyse aussi -[01:10:38] VOIX: tu vois -[01:10:40] VOIX: parfois dans -[01:10:41] VOIX: certains services -[01:10:42] VOIX: pour certains suppléments -[01:10:43] VOIX: etc -[01:10:44] VOIX: on va regarder -[01:10:45] VOIX: la pression artérielle -[01:10:46] VOIX: on va regarder -[01:10:47] VOIX: la diurèse -[01:10:48] VOIX: on va regarder -[01:10:49] VOIX: donc -[01:10:49] VOIX: on l'analyse pas -[01:10:51] VOIX: systématiquement -[01:10:52] VOIX: mais en fonction -[01:10:53] VOIX: des besoins -[01:10:53] VOIX: on peut avoir -[01:10:55] VOIX: on peut avoir besoin -[01:10:57] VOIX: d'analyser -[01:10:58] VOIX: ces éléments là -[01:10:59] VOIX: ok -[01:11:01] VOIX: mais pas -[01:11:02] VOIX: mais pas a priori -[01:11:03] VOIX: tu vois -[01:11:04] VOIX: ça j'y vais -[01:11:05] VOIX: après -[01:11:05] VOIX: là -[01:11:06] VOIX: c'est parce qu'on passe -[01:11:07] VOIX: comme ça -[01:11:07] VOIX: et pour que j'oublie pas -[01:11:09] VOIX: pour te dire -[01:11:10] VOIX: qu'on l'utilise -[01:11:11] VOIX: mais je vais pas -[01:11:12] VOIX: commencer par analyser ça -[01:11:13] VOIX: mais là je vois -[01:11:14] VOIX: que j'ai un R65 -[01:11:15] VOIX: en niveau 2 -[01:11:16] VOIX: que je peux aller chercher -[01:11:18] VOIX: si vraiment -[01:11:18] VOIX: je suis à court -[01:11:19] VOIX: de niveau de sévérité -[01:11:21] VOIX: ok -[01:11:23] VOIX: si tu descends -[01:11:26] VOIX: parce qu'après -[01:11:27] VOIX: on les concentre -[01:11:28] VOIX: là -[01:11:28] VOIX: tu vois le type -[01:11:30] VOIX: d'observation -[01:11:31] VOIX: on va parcourir -[01:11:32] VOIX: là tu es dans -[01:11:34] VOIX: les observations médicales -[01:11:35] VOIX: descends -[01:11:36] VOIX: tu as l'histoire -[01:11:37] VOIX: de la maladie -[01:11:38] VOIX: on va regarder -[01:11:38] VOIX: que les gros titres -[01:11:39] VOIX: tu as la note d'évolution -[01:11:42] VOIX: continue -[01:11:44] VOIX: continue en bas -[01:11:45] VOIX: note d'évolution -[01:11:46] VOIX: donc ça -[01:11:46] VOIX: on est dans -[01:11:47] VOIX: le médical -[01:11:48] VOIX: les observations médicales -[01:11:49] VOIX: tout ça -[01:11:51] VOIX: c'est les observations médicales -[01:11:52] VOIX: on va voir -[01:11:53] VOIX: est-ce qu'on a -[01:11:54] VOIX: d'autres types -[01:11:55] VOIX: de documents -[01:11:56] VOIX: on a -[01:11:57] VOIX: attends stop -[01:11:58] VOIX: on a la sur -[01:11:59] VOIX: remonte un tout petit peu -[01:12:00] VOIX: on a la surveillance -[01:12:03] VOIX: psychiatrie -[01:12:03] VOIX: mais ça c'est bizarre -[01:12:05] VOIX: oui -[01:12:05] VOIX: c'est des constantes -[01:12:06] VOIX: je te répète -[01:12:07] VOIX: tu vois ces pièges -[01:12:09] VOIX: on a les notes paramédicales -[01:12:11] VOIX: ça c'est les infirmières -[01:12:12] VOIX: notes idéales -[01:12:14] VOIX: attend parce qu'en fait -[01:12:14] VOIX: la personne elle est rentrée -[01:12:15] VOIX: attends attends -[01:12:16] VOIX: tu permets une seconde -[01:12:18] VOIX: bien sûr -[01:12:19] VOIX: pourquoi en fait -[01:12:20] VOIX: elle a mal -[01:12:28] VOIX: pourquoi ta question -[01:12:29] VOIX: non non non -[01:12:30] VOIX: c'est parce que -[01:12:31] VOIX: j'ai essayé de comprendre -[01:12:32] VOIX: pourquoi en fait -[01:12:32] VOIX: on part en fait -[01:12:33] VOIX: la prise en charge -[01:12:34] VOIX: en fait -[01:12:34] VOIX: avec une douleur abdominale -[01:12:36] VOIX: et elle finit en psychiatrie -[01:12:38] VOIX: enfin pourquoi en fait -[01:12:38] VOIX: il y a une note -[01:12:39] VOIX: je t'ai dit -[01:12:40] VOIX: c'est un piège -[01:12:41] VOIX: c'est une erreur -[01:12:42] VOIX: parce que tu vois -[01:12:43] VOIX: psychiatrie -[01:12:43] VOIX: et t'as que des constantes -[01:12:45] VOIX: ouais d'accord -[01:12:45] VOIX: donc c'est pas normal -[01:12:47] VOIX: que la surveillance psychiatrique -[01:12:49] VOIX: se fait avec des constantes -[01:12:51] VOIX: ça n'a pas de sens -[01:12:52] VOIX: oui ça n'a pas de sens -[01:12:53] VOIX: d'accord ok -[01:12:54] VOIX: mais tu vois -[01:12:54] VOIX: il faut faire gaffe -[01:12:55] VOIX: aux termes -[01:12:56] VOIX: parce que là -[01:12:57] VOIX: évidemment -[01:12:58] VOIX: je ne sais pas pourquoi -[01:12:59] VOIX: t'as des notes -[01:13:01] VOIX: paramédicales -[01:13:02] VOIX: tu vois -[01:13:02] VOIX: c'est des notes -[01:13:03] VOIX: de l'infirmière -[01:13:03] VOIX: tu vois -[01:13:04] VOIX: que t'as la note -[01:13:05] VOIX: du kinésithérapeute -[01:13:07] VOIX: tu vois tout ça -[01:13:08] VOIX: tu sais qui l'a vu -[01:13:09] VOIX: il y a le médecin -[01:13:11] VOIX: qui l'a vu -[01:13:11] VOIX: il y a l'IDE -[01:13:12] VOIX: il y a le kiné -[01:13:14] VOIX: descend -[01:13:15] VOIX: tu vois -[01:13:16] VOIX: descend parce que -[01:13:16] VOIX: c'est important -[01:13:17] VOIX: qu'on voit -[01:13:20] VOIX: combien un dossier -[01:13:21] VOIX: il peut être lent -[01:13:22] VOIX: et tu peux avoir -[01:13:23] VOIX: beaucoup de notes -[01:13:23] VOIX: et tout ça -[01:13:24] VOIX: on le lit -[01:13:24] VOIX: pour info -[01:13:26] VOIX: t'as la note -[01:13:27] VOIX: de la aide soignante -[01:13:28] VOIX: note IDE -[01:13:28] VOIX: stop -[01:13:29] VOIX: t'as -[01:13:30] VOIX: après t'as -[01:13:31] VOIX: les administrations -[01:13:33] VOIX: médicamenteuses -[01:13:34] VOIX: continue -[01:13:35] VOIX: continue -[01:13:38] VOIX: parfois -[01:13:39] VOIX: on va aller chercher -[01:13:41] VOIX: pour vérifier -[01:13:42] VOIX: est-ce qu'il y a -[01:13:42] VOIX: de la sueline -[01:13:43] VOIX: pour un diabète -[01:13:44] VOIX: est-ce qu'il y a -[01:13:44] VOIX: des antibiotiques -[01:13:46] VOIX: est-ce que -[01:13:46] VOIX: tu vois -[01:13:47] VOIX: ça dépend après -[01:13:48] VOIX: du niveau de -[01:13:50] VOIX: de vigilance -[01:13:51] VOIX: ou de lecture -[01:13:52] VOIX: de la personne -[01:13:52] VOIX: qui code -[01:13:53] VOIX: mais c'est des choses -[01:13:54] VOIX: qu'on peut aller chercher -[01:13:55] VOIX: après t'as la biologie -[01:13:57] VOIX: t'as la radiologie -[01:13:59] VOIX: tu vois -[01:13:59] VOIX: donc au moins -[01:14:00] VOIX: les prescriptions -[01:14:01] VOIX: de radio -[01:14:02] VOIX: tu vas vérifier -[01:14:03] VOIX: si elle a le scan -[01:14:04] VOIX: l'écho -[01:14:05] VOIX: la radio du thorax -[01:14:06] VOIX: est-ce que t'as -[01:14:07] VOIX: les comptes rendus -[01:14:07] VOIX: en fait -[01:14:08] VOIX: quand je te dis -[01:14:09] VOIX: parfois -[01:14:10] VOIX: dans un dossier -[01:14:11] VOIX: et ça c'est un élément -[01:14:12] VOIX: vraiment important -[01:14:13] VOIX: quand on va faire -[01:14:14] VOIX: l'IA vision -[01:14:14] VOIX: sur le DPI -[01:14:16] VOIX: c'est identifier -[01:14:17] VOIX: ce qui a été fait -[01:14:18] VOIX: comme document -[01:14:19] VOIX: et s'assurer -[01:14:20] VOIX: qu'on a tous -[01:14:21] VOIX: les comptes rendus -[01:14:22] VOIX: ça c'est quelque chose -[01:14:22] VOIX: qui est capital -[01:14:23] VOIX: là typiquement -[01:14:24] VOIX: on voit -[01:14:25] VOIX: qu'il y a eu -[01:14:26] VOIX: une prescription -[01:14:26] VOIX: de radio -[01:14:27] VOIX: de scan -[01:14:27] VOIX: etc -[01:14:28] VOIX: on doit vérifier -[01:14:30] VOIX: est-ce qu'elle a eu -[01:14:32] VOIX: est-ce qu'elle a eu -[01:14:33] VOIX: vraiment son scan -[01:14:34] VOIX: c'est marqué -[01:14:34] VOIX: réalisé -[01:14:35] VOIX: réalisé -[01:14:35] VOIX: il faut normalement -[01:14:37] VOIX: que je puisse récupérer -[01:14:38] VOIX: les comptes rendus -[01:14:40] VOIX: sinon je dois indiquer -[01:14:42] VOIX: que la patiente -[01:14:42] VOIX: a un scan -[01:14:43] VOIX: une écho -[01:14:43] VOIX: mais finalement -[01:14:44] VOIX: je n'ai pas -[01:14:44] VOIX: les comptes rendus -[01:14:45] VOIX: tu vois -[01:14:46] VOIX: et Amina -[01:14:47] VOIX: en fait -[01:14:47] VOIX: dans ce cas là -[01:14:48] VOIX: alors -[01:14:50] VOIX: normalement -[01:14:50] VOIX: s'il a les comptes rendus -[01:14:52] VOIX: ils sont en fait -[01:14:53] VOIX: dans le même dossier -[01:14:54] VOIX: ou alors -[01:14:55] VOIX: normalement -[01:14:55] VOIX: un dossier bien fait -[01:14:56] VOIX: il y a tout -[01:14:57] VOIX: tu vas tout retrouver -[01:14:59] VOIX: tu vas retrouver -[01:15:00] VOIX: tous les comptes rendus -[01:15:01] VOIX: d'accord -[01:15:02] VOIX: ou bien -[01:15:03] VOIX: ils te disent -[01:15:03] VOIX: les comptes rendus -[01:15:04] VOIX: d'imagerie -[01:15:05] VOIX: sont dans tel logiciel -[01:15:06] VOIX: donc on peut aller -[01:15:07] VOIX: les chercher -[01:15:07] VOIX: dans un autre logiciel -[01:15:09] VOIX: ok -[01:15:09] VOIX: tu vois -[01:15:10] VOIX: mais -[01:15:11] VOIX: le plus important -[01:15:13] VOIX: c'est de -[01:15:13] VOIX: c'est de dire -[01:15:15] VOIX: là on a vu -[01:15:16] VOIX: les observations médicales -[01:15:18] VOIX: on a vu -[01:15:18] VOIX: les transmissions -[01:15:19] VOIX: des infirmières -[01:15:20] VOIX: on a vu -[01:15:21] VOIX: une note du kiné -[01:15:22] VOIX: mais je ne sais pas -[01:15:23] VOIX: est-ce que j'ai -[01:15:23] VOIX: un compte rendu opératoire -[01:15:24] VOIX: est-ce que j'ai -[01:15:25] VOIX: un compte rendu -[01:15:26] VOIX: est-ce que j'ai -[01:15:28] VOIX: les comptes rendus -[01:15:28] VOIX: de l'imagerie -[01:15:30] VOIX: là la biologie -[01:15:31] VOIX: il faut que je m'assure -[01:15:32] VOIX: que j'ai les résultats -[01:15:33] VOIX: de mes prescriptions -[01:15:35] VOIX: biologiques -[01:15:36] VOIX: bactériologiques -[01:15:36] VOIX: parce que je vois -[01:15:37] VOIX: qu'il y a eu -[01:15:37] VOIX: de l'anapate -[01:15:38] VOIX: il y a eu -[01:15:39] VOIX: des hémocultures -[01:15:41] VOIX: typiquement -[01:15:41] VOIX: dans le premier dossier -[01:15:42] VOIX: que tu m'as montré -[01:15:43] VOIX: il fallait qu'on aille -[01:15:44] VOIX: chercher le résultat -[01:15:45] VOIX: de l'anapate -[01:15:45] VOIX: je ne sais pas -[01:15:46] VOIX: s'il y était -[01:15:47] VOIX: dans le traquer -[01:15:47] VOIX: ou pas -[01:15:48] VOIX: mais le résultat -[01:15:49] VOIX: de l'anapate -[01:15:49] VOIX: il est clé -[01:15:50] VOIX: dans le codage -[01:15:52] VOIX: il est important -[01:15:54] VOIX: il y a des dossiers -[01:15:54] VOIX: qu'on ne code -[01:15:55] VOIX: que sur l'anapate -[01:15:58] VOIX: donc en fait -[01:16:00] VOIX: c'est ça -[01:16:00] VOIX: en fait -[01:16:02] VOIX: quand je te dis -[01:16:03] VOIX: l'exhaustivité -[01:16:04] VOIX: des documents -[01:16:05] VOIX: indépendamment -[01:16:06] VOIX: de qu'est-ce -[01:16:07] VOIX: qu'il y a dedans -[01:16:07] VOIX: c'est déjà -[01:16:08] VOIX: est-ce que j'ai -[01:16:09] VOIX: mes documents -[01:16:10] VOIX: et ça remonte -[01:16:11] VOIX: un tout petit peu -[01:16:11] VOIX: remonte un tout petit peu -[01:16:12] VOIX: tu vois là -[01:16:14] VOIX: régime diététique -[01:16:15] VOIX: est-ce qu'elle a vu -[01:16:17] VOIX: une diète -[01:16:18] VOIX: tu vois -[01:16:19] VOIX: qu'ils ont fait -[01:16:19] VOIX: le CG -[01:16:20] VOIX: ils ont -[01:16:21] VOIX: tu vois -[01:16:22] VOIX: à peu près -[01:16:23] VOIX: ça te donne -[01:16:23] VOIX: une idée -[01:16:24] VOIX: des soins -[01:16:25] VOIX: de ce qui a été fait -[01:16:26] VOIX: pour cette patiente -[01:16:27] VOIX: qui est resté -[01:16:27] VOIX: quand même -[01:16:27] VOIX: quelques jours -[01:16:29] VOIX: et si -[01:16:30] VOIX: tu descends -[01:16:36] VOIX: ça c'est du -[01:16:36] VOIX: blablabla -[01:16:37] VOIX: c'est toutes les -[01:16:38] VOIX: administrations -[01:16:39] VOIX: médicamenteuses -[01:16:40] VOIX: qui vont se répéter -[01:16:41] VOIX: parce que -[01:16:42] VOIX: ils ont fait -[01:16:43] VOIX: comme ils pouvaient -[01:16:44] VOIX: pour sortir -[01:16:45] VOIX: des éléments -[01:16:45] VOIX: pour le contrôle -[01:16:46] VOIX: nous si on fait -[01:16:48] VOIX: une extraction -[01:16:49] VOIX: IA -[01:16:50] VOIX: on va travailler -[01:16:51] VOIX: tout ça -[01:16:52] VOIX: de manière intelligente -[01:16:53] VOIX: si tu veux -[01:16:55] VOIX: ok -[01:16:58] VOIX: et alors lui -[01:16:59] VOIX: sur ce dossier là -[01:17:00] VOIX: en fait -[01:17:00] VOIX: toi tu peux me montrer -[01:17:01] VOIX: en fait -[01:17:01] VOIX: ce que vous aviez codé -[01:17:03] VOIX: comment vous l'avez traité -[01:17:05] VOIX: en amont -[01:17:06] VOIX: avant -[01:17:06] VOIX: regarde -[01:17:07] VOIX: résultat de radiologie -[01:17:09] VOIX: tu vois ce que je t'ai dit -[01:17:10] VOIX: tout à l'heure -[01:17:10] VOIX: il était caché -[01:17:12] VOIX: le compte rendu -[01:17:12] VOIX: du scanner -[01:17:14] VOIX: il descend encore -[01:17:15] VOIX: on va voir -[01:17:16] VOIX: parce que tout ça -[01:17:17] VOIX: on va le relire -[01:17:17] VOIX: là je ne fais que -[01:17:19] VOIX: la partie -[01:17:20] VOIX: quel document -[01:17:21] VOIX: qu'est-ce que je dois -[01:17:22] VOIX: avoir comme résultat -[01:17:23] VOIX: etc -[01:17:24] VOIX: mais on va -[01:17:24] VOIX: on va regarder -[01:17:25] VOIX: le contenu -[01:17:27] VOIX: descend encore -[01:17:28] VOIX: j'ai le compte rendu -[01:17:31] VOIX: de quoi -[01:17:32] VOIX: de l'échographie -[01:17:33] VOIX: tout à l'heure -[01:17:33] VOIX: on a vu qu'elle a eu -[01:17:34] VOIX: un scan -[01:17:34] VOIX: une écho -[01:17:35] VOIX: donc j'ai les comptes rendus -[01:17:36] VOIX: j'ai le compte rendu -[01:17:37] VOIX: de la radio -[01:17:39] VOIX: descend encore -[01:17:41] VOIX: j'ai le compte rendu -[01:17:43] VOIX: de la radio -[01:17:44] VOIX: écho -[01:17:45] VOIX: parce qu'elle en a eu -[01:17:46] VOIX: plusieurs je pense -[01:17:46] VOIX: à des dates différentes -[01:17:49] VOIX: parce qu'une fois -[01:17:49] VOIX: elle dit -[01:17:50] VOIX: voilà -[01:17:51] VOIX: j'ai les résultats -[01:17:52] VOIX: du labo -[01:17:54] VOIX: quelques résultats -[01:17:55] VOIX: est-ce que j'ai tout -[01:17:56] VOIX: est-ce que j'ai une partie -[01:17:57] VOIX: je n'en sais rien -[01:17:59] VOIX: descend on va voir -[01:18:02] VOIX: bon -[01:18:03] VOIX: j'ai quelques résultats -[01:18:04] VOIX: et tout ça -[01:18:06] VOIX: on le lit -[01:18:07] VOIX: en fait -[01:18:08] VOIX: on le regarde -[01:18:09] VOIX: mais avec un ordre -[01:18:10] VOIX: de priorité -[01:18:12] VOIX: et -[01:18:12] VOIX: je vais t'expliquer -[01:18:13] VOIX: l'ordre de priorité -[01:18:15] VOIX: et qu'est-ce que tu recherches -[01:18:17] VOIX: c'est important -[01:18:17] VOIX: que tu me dises -[01:18:18] VOIX: en fait ce que tu recherches -[01:18:19] VOIX: tu le regardes -[01:18:20] VOIX: mais qu'est-ce que tu -[01:18:21] VOIX: c'est quoi en fait -[01:18:22] VOIX: que tu regardes en premier -[01:18:23] VOIX: là -[01:18:23] VOIX: donc je vais remonter -[01:18:24] VOIX: aux observations médicales -[01:18:27] VOIX: tu sais -[01:18:27] VOIX: on a vu les urgences -[01:18:28] VOIX: on a compris -[01:18:29] VOIX: que la patiente -[01:18:30] VOIX: était venue -[01:18:30] VOIX: pour douleur -[01:18:31] VOIX: et pour diarrhée -[01:18:34] VOIX: tu te rappelles -[01:18:34] VOIX: oui oui -[01:18:36] VOIX: ou pas -[01:18:36] VOIX: donc je vais remonter -[01:18:38] VOIX: au tout début -[01:18:40] VOIX: quand elle a commencé -[01:18:42] VOIX: quand on a commencé -[01:18:43] VOIX: les observations médicales -[01:18:45] VOIX: là -[01:18:45] VOIX: je note -[01:18:46] VOIX: dans ma tête -[01:18:47] VOIX: elle vient aux urgences -[01:18:49] VOIX: pour douleur -[01:18:49] VOIX: diarrhée -[01:18:50] VOIX: et elle est suivie -[01:18:51] VOIX: pour une anémiférie -[01:18:53] VOIX: prive -[01:18:53] VOIX: et maladie du bien-mère -[01:18:54] VOIX: c'est tout ce que je note -[01:18:55] VOIX: soit dans un papier -[01:18:56] VOIX: les types -[01:18:57] VOIX: parfois -[01:18:58] VOIX: un papier -[01:18:58] VOIX: un crayon -[01:18:59] VOIX: et elle se note -[01:19:00] VOIX: deux trois trucs -[01:19:01] VOIX: pour dire -[01:19:02] VOIX: est-ce que je le code -[01:19:02] VOIX: après ou pas -[01:19:03] VOIX: pour ne pas oublier -[01:19:04] VOIX: tu descends -[01:19:05] VOIX: tu descends -[01:19:06] VOIX: on va avoir -[01:19:07] VOIX: les observations médicales -[01:19:10] VOIX: après ce truc là -[01:19:11] VOIX: il y avait -[01:19:11] VOIX: les observations -[01:19:12] VOIX: là -[01:19:12] VOIX: donc là -[01:19:14] VOIX: c'est hyper important -[01:19:16] VOIX: de voir -[01:19:17] VOIX: d'abord -[01:19:18] VOIX: les observations médicales -[01:19:20] VOIX: parfois -[01:19:21] VOIX: tu as -[01:19:21] VOIX: ceux qui commencent -[01:19:22] VOIX: du haut en bas -[01:19:23] VOIX: et parfois -[01:19:24] VOIX: tu as ceux -[01:19:25] VOIX: qui commencent -[01:19:26] VOIX: de bas en haut -[01:19:26] VOIX: d'accord -[01:19:28] VOIX: et la date -[01:19:28] VOIX: elle est -[01:19:29] VOIX: hyper importante -[01:19:30] VOIX: le 21 -[01:19:31] VOIX: c'est la fin -[01:19:32] VOIX: du séjour -[01:19:36] VOIX: ok -[01:19:36] VOIX: ok -[01:19:38] VOIX: il faut que je commence -[01:20:06] VOIX: par -[01:20:07] VOIX: par ça -[01:20:08] VOIX: 9h21 -[01:20:09] VOIX: qu'est-ce qu'il dit -[01:20:10] VOIX: au début -[01:20:10] VOIX: il vient pour douleur -[01:20:12] VOIX: et là -[01:20:13] VOIX: des antécédents -[01:20:14] VOIX: DT c'est diabète -[01:20:16] VOIX: de type 2 -[01:20:19] VOIX: ok -[01:20:22] VOIX: dyslipidémie -[01:20:22] VOIX: HTA -[01:20:23] VOIX: anémie -[01:20:24] VOIX: phériprive chronique -[01:20:25] VOIX: suivie -[01:20:26] VOIX: maladie de bière -[01:20:27] VOIX: mère -[01:20:27] VOIX: il te dit -[01:20:28] VOIX: qu'elle a déjà -[01:20:29] VOIX: fait des biopsies -[01:20:30] VOIX: elle a de la vitamine -[01:20:31] VOIX: P12 -[01:20:31] VOIX: elle est transfusée -[01:20:33] VOIX: elle a du fer -[01:20:34] VOIX: injecte -[01:20:34] VOIX: dernière transfusion -[01:20:35] VOIX: etc -[01:20:36] VOIX: mais on ne sait pas -[01:20:37] VOIX: si elle a un traitement -[01:20:38] VOIX: poursuivi ou pas -[01:20:39] VOIX: parce que jusque là -[01:20:40] VOIX: tout ce qu'il raconte -[01:20:41] VOIX: sur l'anémie -[01:20:42] VOIX: phériprive chronique -[01:20:43] VOIX: c'est un antécédent -[01:20:45] VOIX: je ne sais pas -[01:20:46] VOIX: est-ce qu'ils ont -[01:20:46] VOIX: continué le traitement -[01:20:49] VOIX: pendant le séjour -[01:20:50] VOIX: par exemple -[01:20:51] VOIX: c'est là où parfois -[01:20:53] VOIX: on va aller -[01:20:53] VOIX: voir les médicaments -[01:20:55] VOIX: et on va -[01:20:56] VOIX: on va dire -[01:20:57] VOIX: si la patiente -[01:20:58] VOIX: a de la vitamine -[01:20:58] VOIX: P12 -[01:20:59] VOIX: comme traitement -[01:21:00] VOIX: peut-être que je vais -[01:21:01] VOIX: coder la maladie -[01:21:02] VOIX: de bière mère -[01:21:03] VOIX: en diagnostic associé -[01:21:05] VOIX: ok -[01:21:06] VOIX: c'est pas en diagnostic -[01:21:07] VOIX: principal -[01:21:08] VOIX: parce qu'elle n'est pas -[01:21:08] VOIX: venue pour ça -[01:21:09] VOIX: le diagnostic principal -[01:21:10] VOIX: c'est toujours -[01:21:11] VOIX: pourquoi la patiente -[01:21:13] VOIX: est venue -[01:21:13] VOIX: en hospitalisation -[01:21:14] VOIX: elle n'est pas venue -[01:21:15] VOIX: pour son anémie -[01:21:15] VOIX: ni pour sa maladie -[01:21:16] VOIX: de bière mère -[01:21:18] VOIX: il dit qu'elle a -[01:21:19] VOIX: une cirrhose hépatique -[01:21:21] VOIX: ok -[01:21:21] VOIX: les traitements -[01:21:23] VOIX: elle a -[01:21:25] VOIX: vas-y -[01:21:25] VOIX: vas-y -[01:21:25] VOIX: vas-y -[01:21:26] VOIX: elle a un certain -[01:21:28] VOIX: type de traitement -[01:21:33] VOIX: bisoprolol -[01:21:33] VOIX: cardégique -[01:21:34] VOIX: bétformine -[01:21:35] VOIX: moi dans le diabète -[01:21:36] VOIX: de type 2 -[01:21:37] VOIX: je regarde -[01:21:37] VOIX: si elle a de l'insuline -[01:21:38] VOIX: ou pas -[01:21:39] VOIX: mais là pour l'instant -[01:21:40] VOIX: je ne vois pas -[01:21:42] VOIX: allergie -[01:21:43] VOIX: il n'y en a pas -[01:21:45] VOIX: MDV -[01:21:45] VOIX: mode de vie -[01:21:46] VOIX: MDV c'est mode de vie -[01:21:48] VOIX: il vit seul -[01:21:49] VOIX: autonome -[01:21:50] VOIX: pas d'aide -[01:21:50] VOIX: utilise un fauteuil roulant -[01:21:52] VOIX: depuis peu -[01:21:52] VOIX: pour les grandes distances -[01:21:54] VOIX: car -[01:21:55] VOIX: asthénie -[01:21:56] VOIX: tabagisme -[01:21:56] VOIX: pas de OH -[01:21:57] VOIX: chronique -[01:21:58] VOIX: OH c'est alcool -[01:21:59] VOIX: pas d'alcoolisme -[01:22:00] VOIX: chronique -[01:22:01] VOIX: histoire de la maladie -[01:22:02] VOIX: HDM -[01:22:03] VOIX: patiente de 68 ans -[01:22:05] VOIX: suivi par docteur -[01:22:06] VOIX: pour anémiférie -[01:22:07] VOIX: près de chronique -[01:22:08] VOIX: sans cause digestif -[01:22:08] VOIX: franche -[01:22:09] VOIX: imputable -[01:22:11] VOIX: elle doit voir -[01:22:12] VOIX: l'hémato -[01:22:13] VOIX: tu vois -[01:22:14] VOIX: elle est en cours -[01:22:14] VOIX: de bilan -[01:22:15] VOIX: quand même -[01:22:15] VOIX: de son anémie -[01:22:18] VOIX: elle a -[01:22:19] VOIX: été traité -[01:22:20] VOIX: pour infection -[01:22:21] VOIX: urinaire -[01:22:24] VOIX: en fait -[01:22:25] VOIX: il ne faut jamais -[01:22:26] VOIX: perdre de vue -[01:22:27] VOIX: la date -[01:22:28] VOIX: de l'entrée -[01:22:29] VOIX: pour savoir -[01:22:29] VOIX: est-ce que c'est des choses -[01:22:30] VOIX: qui sont encore en cours -[01:22:32] VOIX: ou est-ce que -[01:22:33] VOIX: c'est des choses -[01:22:33] VOIX: qui sont passées -[01:22:34] VOIX: parce que là -[01:22:35] VOIX: le 7 avril -[01:22:37] VOIX: l'antibiogramme -[01:22:38] VOIX: disparaissant de la douleur -[01:22:39] VOIX: abdo -[01:22:42] VOIX: donc c'est passé -[01:22:45] VOIX: elle a des douleurs -[01:22:47] VOIX: diarrhées depuis 15 jours -[01:22:48] VOIX: asthénie -[01:22:49] VOIX: pas de vomissement -[01:22:50] VOIX: etc -[01:22:50] VOIX: elle a la biologie -[01:22:52] VOIX: du 11 -[01:22:52] VOIX: donc avant de rentrer -[01:22:53] VOIX: elle avait une IRA -[01:22:55] VOIX: une insuffisance rénale -[01:22:56] VOIX: aiguë -[01:22:57] VOIX: parce que IRA -[01:22:58] VOIX: tu peux la voir -[01:23:00] VOIX: en insuffisance rénale -[01:23:01] VOIX: aiguë -[01:23:01] VOIX: ou en insuffisance respiratoire -[01:23:02] VOIX: aiguë -[01:23:05] VOIX: immédiatement -[01:23:07] VOIX: regarde en haut -[01:23:07] VOIX: le gras -[01:23:08] VOIX: en haut de la page -[01:23:09] VOIX: il y a un truc gras -[01:23:11] VOIX: créate -[01:23:11] VOIX: ça veut dire -[01:23:12] VOIX: rénale -[01:23:13] VOIX: c'est en lien -[01:23:14] VOIX: avec le rein -[01:23:15] VOIX: donc du coup -[01:23:15] VOIX: je sais que c'est -[01:23:16] VOIX: insuffisance rénale -[01:23:17] VOIX: aiguë -[01:23:18] VOIX: il n'y a pas -[01:23:19] VOIX: d'hyperlococytose -[01:23:20] VOIX: pas d'anémie -[01:23:20] VOIX: etc -[01:23:22] VOIX: en fait -[01:23:22] VOIX: ce que je suis en train -[01:23:23] VOIX: de chercher -[01:23:24] VOIX: c'est quelle est leur logique -[01:23:26] VOIX: qu'est-ce qu'ils vont faire -[01:23:28] VOIX: parce qu'elle vient -[01:23:29] VOIX: pour douleur diarrhée -[01:23:30] VOIX: est-ce que -[01:23:31] VOIX: ça sera un bilan -[01:23:32] VOIX: diagnostique -[01:23:33] VOIX: est-ce que ça sera -[01:23:33] VOIX: un bilan thérapeutique -[01:23:34] VOIX: qu'est-ce qu'ils vont trouver -[01:23:35] VOIX: donc elle n'a pas -[01:23:37] VOIX: de fièvre -[01:23:38] VOIX: langue grottie -[01:23:39] VOIX: consciente orientée -[01:23:41] VOIX: telle amène -[01:23:43] VOIX: jolie -[01:23:44] VOIX: sans -[01:23:45] VOIX: pas de haute voix -[01:23:46] VOIX: tachycardie -[01:23:47] VOIX: tension artérielle limite -[01:23:48] VOIX: pas de marbrures -[01:23:50] VOIX: pas de souffle -[01:23:51] VOIX: pas d'hérisipel -[01:23:53] VOIX: pas de toux -[01:23:54] VOIX: tu vois -[01:23:54] VOIX: il n'y a pas -[01:23:55] VOIX: papa -[01:23:55] VOIX: et là -[01:23:56] VOIX: l'abdomen -[01:23:56] VOIX: pléthorique -[01:23:57] VOIX: gonflé -[01:23:58] VOIX: mais sans -[01:23:58] VOIX: tintanisme -[01:23:59] VOIX: auréficiaire -[01:24:00] VOIX: mignonne -[01:24:00] VOIX: sensibilité -[01:24:01] VOIX: au hippocondre gauche -[01:24:02] VOIX: donc il a le ventre sensible -[01:24:04] VOIX: mais voilà -[01:24:05] VOIX: bruit -[01:24:06] VOIX: hydrosairie -[01:24:06] VOIX: discrète -[01:24:07] VOIX: pas de douleur -[01:24:08] VOIX: pas d'hypthère -[01:24:09] VOIX: jusque là -[01:24:09] VOIX: tout ça -[01:24:10] VOIX: moi je ne cherche pas -[01:24:11] VOIX: à coder -[01:24:12] VOIX: tout ça -[01:24:13] VOIX: parce que -[01:24:13] VOIX: l'objectif -[01:24:14] VOIX: c'est pas de -[01:24:15] VOIX: trouver des codes -[01:24:16] VOIX: pour -[01:24:17] VOIX: les symptômes -[01:24:18] VOIX: l'objectif -[01:24:18] VOIX: c'est de trouver -[01:24:19] VOIX: une pathologie -[01:24:20] VOIX: en fait -[01:24:21] VOIX: ok -[01:24:22] VOIX: donc au total -[01:24:23] VOIX: ils disent -[01:24:25] VOIX: insuffisance rénale -[01:24:26] VOIX: aiguë -[01:24:26] VOIX: probablement -[01:24:27] VOIX: fonctionnelle -[01:24:28] VOIX: sur déshydratation -[01:24:30] VOIX: majeure -[01:24:30] VOIX: en contexte -[01:24:31] VOIX: de diarrhée -[01:24:32] VOIX: et des faux -[01:24:32] VOIX: d'hydratation -[01:24:35] VOIX: peut-être -[01:24:36] VOIX: que mon -[01:24:37] VOIX: DP -[01:24:38] VOIX: je peux -[01:24:39] VOIX: être tenté -[01:24:40] VOIX: par mettre -[01:24:40] VOIX: la déshydratation -[01:24:42] VOIX: en diagnostic -[01:24:42] VOIX: principal -[01:24:44] VOIX: pourquoi ? -[01:24:45] VOIX: parce que -[01:24:47] VOIX: elle a -[01:24:47] VOIX: une insuffisance -[01:24:48] VOIX: rénale -[01:24:48] VOIX: aiguë -[01:24:49] VOIX: mais il me dit -[01:24:50] VOIX: qu'elle est -[01:24:51] VOIX: plutôt fonctionnelle -[01:24:52] VOIX: et la cause -[01:24:53] VOIX: c'est la déshydratation -[01:24:55] VOIX: et c'est moi -[01:24:56] VOIX: qui partage -[01:24:57] VOIX: mon écran -[01:24:58] VOIX: pour te montrer -[01:24:58] VOIX: une règle -[01:24:59] VOIX: d'accord -[01:25:01] VOIX: tu peux ? -[01:25:02] VOIX: vas-y -[01:25:02] VOIX: vas-y -[01:25:04] VOIX: tu peux -[01:25:07] VOIX: allez -[01:25:08] VOIX: dans le guide -[01:25:09] VOIX: méthodo -[01:25:10] VOIX: que je t'ai -[01:25:10] VOIX: montré -[01:25:11] VOIX: tout à l'heure -[01:25:11] VOIX: si je reviens -[01:25:12] VOIX: à mon sommaire -[01:25:15] VOIX: on avait parlé -[01:25:17] VOIX: des situations -[01:25:17] VOIX: cliniques -[01:25:18] VOIX: on a dit -[01:25:19] VOIX: hospitalisation -[01:25:20] VOIX: pour diagnostic -[01:25:20] VOIX: pour traitement -[01:25:21] VOIX: pour surveillance -[01:25:22] VOIX: là -[01:25:25] VOIX: je suis -[01:25:26] VOIX: sur -[01:25:27] VOIX: à l'heure -[01:25:28] VOIX: où on parle -[01:25:29] VOIX: sur une situation -[01:25:29] VOIX: de diagnostic -[01:25:30] VOIX: elle vient -[01:25:31] VOIX: pour douleur -[01:25:32] VOIX: diarrhée -[01:25:32] VOIX: avec insuffisance -[01:25:33] VOIX: rénale aiguë -[01:25:35] VOIX: ils disent -[01:25:35] VOIX: que la patiente -[01:25:36] VOIX: elle avait -[01:25:37] VOIX: une diarrhée -[01:25:37] VOIX: elle a fait -[01:25:38] VOIX: une déshydratation -[01:25:39] VOIX: qui a provoqué -[01:25:40] VOIX: l'insuffisance -[01:25:41] VOIX: rénale aiguë -[01:25:41] VOIX: fonctionnelle -[01:25:43] VOIX: donc je peux -[01:25:45] VOIX: être tentée -[01:25:45] VOIX: en DP -[01:25:46] VOIX: soit par la diarrhée -[01:25:47] VOIX: soit par la déshydratation -[01:25:48] VOIX: tu vois -[01:25:49] VOIX: c'est pas -[01:25:49] VOIX: aussi formel -[01:25:50] VOIX: et donc -[01:25:52] VOIX: j'ai une règle -[01:25:53] VOIX: ici -[01:25:54] VOIX: qui me dit -[01:25:55] VOIX: que le diagnostic -[01:25:56] VOIX: s'accompagne -[01:25:57] VOIX: au nom -[01:25:57] VOIX: d'un traitement -[01:25:58] VOIX: oui pardon -[01:26:00] VOIX: hospitalisation -[01:26:01] VOIX: pour diagnostic -[01:26:02] VOIX: la situation -[01:26:03] VOIX: est celle -[01:26:03] VOIX: d'un patient -[01:26:04] VOIX: hospitalisé -[01:26:05] VOIX: en raison -[01:26:06] VOIX: d'une symptomatologie -[01:26:07] VOIX: pour un diagnostic -[01:26:08] VOIX: éthiologique -[01:26:09] VOIX: donc le motif -[01:26:10] VOIX: d'hospitalisation -[01:26:11] VOIX: il a un symptôme -[01:26:12] VOIX: une douleur -[01:26:12] VOIX: une diarrhée -[01:26:13] VOIX: etc -[01:26:13] VOIX: et je vais essayer -[01:26:15] VOIX: de comprendre -[01:26:16] VOIX: pourquoi -[01:26:20] VOIX: le diagnostic -[01:26:21] VOIX: qu'il soit traité -[01:26:23] VOIX: ou pas -[01:26:23] VOIX: à partir du moment -[01:26:24] VOIX: où je fais le diagnostic -[01:26:27] VOIX: sauf en cas -[01:26:28] VOIX: de chirurgie -[01:26:28] VOIX: parce que là -[01:26:29] VOIX: il y a le cas -[01:26:29] VOIX: de traitement chirurgical -[01:26:30] VOIX: qui change -[01:26:32] VOIX: cette règle -[01:26:35] VOIX: j'ai deux cas -[01:26:37] VOIX: soit -[01:26:38] VOIX: le séjour -[01:26:39] VOIX: a permis -[01:26:39] VOIX: le diagnostic -[01:26:40] VOIX: de l'infection causale -[01:26:41] VOIX: soit -[01:26:41] VOIX: ils vont faire -[01:26:42] VOIX: une biologie -[01:26:43] VOIX: une bactériale -[01:26:44] VOIX: de la déshydratation -[01:26:45] VOIX: il va me dire -[01:26:47] VOIX: diarrhée -[01:26:48] VOIX: liée à telle bactérie -[01:26:49] VOIX: ou tel virus -[01:26:50] VOIX: peut-être que -[01:26:51] VOIX: dans ce cas-là -[01:26:51] VOIX: je poterai la diarrhée -[01:26:53] VOIX: soit -[01:26:53] VOIX: je me contenterai -[01:26:55] VOIX: de la déshydratation -[01:26:56] VOIX: mais déjà -[01:26:56] VOIX: je commence à réfléchir -[01:26:58] VOIX: tu vois -[01:26:59] VOIX: et ça -[01:27:00] VOIX: c'est les règles -[01:27:01] VOIX: que j'utilise -[01:27:01] VOIX: c'est-à-dire -[01:27:02] VOIX: si le séjour -[01:27:03] VOIX: a permis -[01:27:03] VOIX: le diagnostic -[01:27:04] VOIX: de l'infection causale -[01:27:05] VOIX: elle va être le DP -[01:27:06] VOIX: par exemple -[01:27:08] VOIX: hospice -[01:27:08] VOIX: pour une confusion -[01:27:09] VOIX: mais on découvre -[01:27:10] VOIX: une tumeur cérébrale -[01:27:11] VOIX: le DP -[01:27:12] VOIX: c'est la tumeur cérébrale -[01:27:13] VOIX: hospice -[01:27:14] VOIX: pour douleurs thoraciques -[01:27:15] VOIX: le diagnostic -[01:27:16] VOIX: de l'angine de poitrine -[01:27:17] VOIX: le DP -[01:27:17] VOIX: c'est l'angine de poitrine -[01:27:19] VOIX: hospice -[01:27:20] VOIX: pour anémie -[01:27:21] VOIX: occlusion -[01:27:22] VOIX: découverte -[01:27:23] VOIX: d'un cancer colique -[01:27:24] VOIX: c'est le cancer colique -[01:27:25] VOIX: là on a l'exemple -[01:27:26] VOIX: de hospice -[01:27:27] VOIX: pour douleurs -[01:27:29] VOIX: diarrhées -[01:27:29] VOIX: insuffisance rénale -[01:27:31] VOIX: donc soit -[01:27:32] VOIX: c'est la diarrhée -[01:27:32] VOIX: qui cause tout ça -[01:27:33] VOIX: soit c'est -[01:27:34] VOIX: il y a une infection -[01:27:36] VOIX: etc -[01:27:37] VOIX: soit c'est la déshydratation -[01:27:39] VOIX: qui cause -[01:27:39] VOIX: l'insuffisance rénale -[01:27:40] VOIX: je peux avoir -[01:27:41] VOIX: plusieurs niveaux -[01:27:42] VOIX: de cause à effet -[01:27:43] VOIX: en tout cas -[01:27:44] VOIX: ça c'est une règle -[01:27:44] VOIX: par contre -[01:27:46] VOIX: si je ne trouve pas -[01:27:47] VOIX: de cause -[01:27:48] VOIX: dans ce cas là -[01:27:50] VOIX: hospice pour céphalée -[01:27:51] VOIX: je garde des céphalées -[01:27:53] VOIX: hospice pour douleurs -[01:27:54] VOIX: abdos -[01:27:54] VOIX: mais je ne trouve pas -[01:27:55] VOIX: pourquoi -[01:27:56] VOIX: ça sera douleurs abdos -[01:27:58] VOIX: j'ai l'état de choc -[01:27:59] VOIX: mais je ne trouve pas -[01:27:59] VOIX: pourquoi -[01:28:00] VOIX: le DP reste -[01:28:01] VOIX: l'état de choc -[01:28:03] VOIX: syndrome inflammatoire -[01:28:04] VOIX: je ne sais pas pourquoi -[01:28:05] VOIX: ça reste -[01:28:05] VOIX: le syndrome inflammatoire -[01:28:07] VOIX: donc c'est un peu ça -[01:28:08] VOIX: les règles -[01:28:11] VOIX: tu vois -[01:28:11] VOIX: tu n'as pas la cause -[01:28:12] VOIX: et bien là -[01:28:13] VOIX: la douleur abdos -[01:28:14] VOIX: il a une polyarthrite -[01:28:15] VOIX: de l'homatoïde -[01:28:16] VOIX: en raison de douleurs -[01:28:17] VOIX: abdos -[01:28:18] VOIX: disparition des douleurs -[01:28:19] VOIX: en 48 heures -[01:28:20] VOIX: mais pas de cause -[01:28:20] VOIX: trouvée -[01:28:22] VOIX: c'est les douleurs -[01:28:23] VOIX: abdos -[01:28:23] VOIX: parce qu'on n'a pas -[01:28:24] VOIX: compris pourquoi -[01:28:26] VOIX: ok -[01:28:29] VOIX: ça c'est les règles -[01:28:30] VOIX: du guide -[01:28:30] VOIX: tu as vu -[01:28:31] VOIX: que je reviens -[01:28:31] VOIX: aux règles -[01:28:32] VOIX: parce que moi -[01:28:33] VOIX: perso -[01:28:34] VOIX: et j'ai formé -[01:28:35] VOIX: Pauline à ça -[01:28:36] VOIX: les gens -[01:28:37] VOIX: qui réfléchissent -[01:28:39] VOIX: et qui sont capables -[01:28:40] VOIX: d'argumenter le dossier -[01:28:41] VOIX: c'est les gens -[01:28:43] VOIX: qui rapportent -[01:28:43] VOIX: la logique -[01:28:44] VOIX: au guide méthodologique -[01:28:50] VOIX: est-ce que jusque là -[01:28:51] VOIX: ça va -[01:28:51] VOIX: oui ça va -[01:28:52] VOIX: mais ça fait -[01:28:52] VOIX: beaucoup -[01:28:53] VOIX: beaucoup -[01:28:53] VOIX: beaucoup -[01:28:53] VOIX: d'informations -[01:28:54] VOIX: en fait -[01:28:54] VOIX: que je dois -[01:28:57] VOIX: retenir -[01:28:57] VOIX: tu vas enregistrer -[01:28:59] VOIX: enregistre -[01:29:00] VOIX: oui mais je t'enregistre -[01:29:01] VOIX: mais après -[01:29:01] VOIX: en fait -[01:29:01] VOIX: il faut que je traduise -[01:29:02] VOIX: tout ce que tu viens -[01:29:03] VOIX: de me dire -[01:29:03] VOIX: il faut que je la traduise -[01:29:04] VOIX: dans un autre langage -[01:29:05] VOIX: que je ne pourrais pas -[01:29:07] VOIX: en fait apprendre -[01:29:07] VOIX: surtout aujourd'hui -[01:29:08] VOIX: donc on peut s'arrêter là -[01:29:09] VOIX: déjà -[01:29:10] VOIX: oui je pense en fait -[01:29:10] VOIX: que oui -[01:29:11] VOIX: j'ai peut-être -[01:29:11] VOIX: 2-3 questions -[01:29:12] VOIX: à te poser -[01:29:13] VOIX: si tu permets -[01:29:13] VOIX: parce que c'est important -[01:29:14] VOIX: pour moi -[01:29:15] VOIX: non -[01:29:16] VOIX: en fait -[01:29:17] VOIX: attends -[01:29:18] VOIX: ne bouge pas -[01:29:20] VOIX: je regarde mon écran -[01:29:21] VOIX: donc je ne te vois plus -[01:29:22] VOIX: c'est dommage -[01:29:23] VOIX: mais bon -[01:29:24] VOIX: c'est pas grave -[01:29:25] VOIX: donc tu m'as répondu -[01:29:26] VOIX: je regarde en fait -[01:29:27] VOIX: parce que je me suis noté -[01:29:27] VOIX: plein plein de questions -[01:29:28] VOIX: et tu m'as répondu -[01:29:30] VOIX: en fait -[01:29:30] VOIX: à une grosse partie -[01:29:31] VOIX: quel est le critère -[01:29:32] VOIX: pour toi -[01:29:33] VOIX: oui -[01:29:33] VOIX: si tu as en fait -[01:29:35] VOIX: 2 d'épées plausibles -[01:29:37] VOIX: qu'est-ce que c'est ton critère -[01:29:42] VOIX: mon critère -[01:29:43] VOIX: je repartage mon écran -[01:29:46] VOIX: non mais si tu peux me le dire -[01:29:47] VOIX: comme ça -[01:29:48] VOIX: c'est celui qui a mobilisé -[01:29:50] VOIX: le plus d'efforts -[01:29:54] VOIX: je ne sais pas comment -[01:29:55] VOIX: en quelle règle -[01:29:57] VOIX: en fait petit à petit -[01:29:59] VOIX: il faudra qu'on intègre -[01:30:02] VOIX: dans la notion de règle -[01:30:03] VOIX: dans la logique -[01:30:05] VOIX: c'est hyper important -[01:30:12] VOIX: ok -[01:30:12] VOIX: ok ça je l'ai intégré -[01:30:14] VOIX: mais en fait -[01:30:15] VOIX: on reverra -[01:30:16] VOIX: de toute façon -[01:30:16] VOIX: il va falloir -[01:30:17] VOIX: que je pense -[01:30:18] VOIX: c'est pas fini -[01:30:22] VOIX: c'est celui -[01:30:23] VOIX: en fait -[01:30:25] VOIX: soit celui -[01:30:25] VOIX: qui a mobilisé -[01:30:26] VOIX: le plus d'efforts -[01:30:27] VOIX: et il te dit -[01:30:28] VOIX: si les deux -[01:30:28] VOIX: ont mobilisé -[01:30:30] VOIX: le même niveau d'efforts -[01:30:31] VOIX: ça reste -[01:30:32] VOIX: au choix -[01:30:33] VOIX: de l'établissement -[01:30:34] VOIX: d'accord -[01:30:34] VOIX: parfois on simule -[01:30:36] VOIX: et on choisit -[01:30:36] VOIX: celui qui valorise -[01:30:37] VOIX: le mieux -[01:30:38] VOIX: et on peut le défendre -[01:30:39] VOIX: alors attendez -[01:30:41] VOIX: mobiliser -[01:30:45] VOIX: je vais regarder -[01:30:46] VOIX: moi si je te trouve -[01:30:47] VOIX: la règle -[01:30:47] VOIX: ça sera plus simple -[01:31:08] VOIX: en même temps -[01:31:09] VOIX: en fait -[01:31:09] VOIX: je peux te poser -[01:31:10] VOIX: des questions -[01:31:11] VOIX: ben bien sûr -[01:31:17] VOIX: une question -[01:31:18] VOIX: en fait -[01:31:18] VOIX: c'est -[01:31:19] VOIX: c'est des symptômes -[01:31:20] VOIX: versus -[01:31:21] VOIX: causes -[01:31:21] VOIX: pardon pardon -[01:31:22] VOIX: je peux t'embêter -[01:31:23] VOIX: je vais juste partager -[01:31:25] VOIX: un truc -[01:31:25] VOIX: pour te montrer -[01:31:27] VOIX: toujours un peu -[01:31:27] VOIX: la logique -[01:31:28] VOIX: tu vois -[01:31:29] VOIX: tu m'as posé -[01:31:30] VOIX: la question -[01:31:30] VOIX: que faire -[01:31:32] VOIX: si l'analyse -[01:31:32] VOIX: en termes -[01:31:33] VOIX: de situation clinique -[01:31:33] VOIX: propose plus -[01:31:34] VOIX: d'un diagnostic -[01:31:35] VOIX: principal -[01:31:36] VOIX: tu vois -[01:31:37] VOIX: le DP -[01:31:38] VOIX: étant le problème -[01:31:39] VOIX: de santé -[01:31:39] VOIX: qui a motivé -[01:31:40] VOIX: l'admission -[01:31:41] VOIX: une telle circonstance -[01:31:42] VOIX: ne peut être que rare -[01:31:43] VOIX: le DP déterminé -[01:31:44] VOIX: à la sortie -[01:31:45] VOIX: est alors celui -[01:31:45] VOIX: du problème -[01:31:46] VOIX: qui a mobilisé -[01:31:47] VOIX: l'essentiel -[01:31:47] VOIX: des efforts -[01:31:48] VOIX: de soins -[01:31:48] VOIX: si tu en as plus -[01:31:50] VOIX: dans les cas -[01:31:51] VOIX: où les deux problèmes -[01:31:52] VOIX: auraient mobilisé -[01:31:52] VOIX: des efforts -[01:31:53] VOIX: d'importance -[01:31:53] VOIX: comparables -[01:31:54] VOIX: c'est à dire -[01:31:55] VOIX: dans le cas -[01:31:55] VOIX: de prise en charge -[01:31:57] VOIX: équivalente -[01:31:57] VOIX: et dans ce cas -[01:31:58] VOIX: seulement -[01:31:58] VOIX: le choix du DP -[01:31:59] VOIX: parmi les ex-ecos -[01:32:00] VOIX: est laissé à l'établissement -[01:32:01] VOIX: de santé -[01:32:02] VOIX: d'accord -[01:32:02] VOIX: ok -[01:32:03] VOIX: tu vois -[01:32:04] VOIX: et c'est la règle -[01:32:05] VOIX: M2 -[01:32:05] VOIX: ok -[01:32:09] VOIX: tiens -[01:32:09] VOIX: que si tu vois -[01:32:10] VOIX: je pourrais te dire -[01:32:11] VOIX: non mais c'est -[01:32:12] VOIX: c'est -[01:32:13] VOIX: c'est la règle -[01:32:13] VOIX: M2 -[01:32:14] VOIX: et c'est telle réponse -[01:32:16] VOIX: ouais -[01:32:16] VOIX: c'est impeccable -[01:32:17] VOIX: pour moi -[01:32:17] VOIX: alors donc -[01:32:18] VOIX: ça c'est fait -[01:32:19] VOIX: le symptôme -[01:32:20] VOIX: versus cause -[01:32:22] VOIX: c'est -[01:32:23] VOIX: dans quel cas -[01:32:24] VOIX: en fait -[01:32:24] VOIX: le symptôme -[01:32:25] VOIX: reste un DP -[01:32:27] VOIX: si le bilan -[01:32:29] VOIX: ne trouve -[01:32:30] VOIX: aucun -[01:32:30] VOIX: aucune cause -[01:32:31] VOIX: d'accord -[01:32:32] VOIX: le symptôme -[01:32:33] VOIX: il reste le DP -[01:32:35] VOIX: si par exemple -[01:32:36] VOIX: la douleur -[01:32:37] VOIX: abdo -[01:32:37] VOIX: mais finalement -[01:32:38] VOIX: ils font le scan -[01:32:39] VOIX: ils font l'écho -[01:32:40] VOIX: etc -[01:32:40] VOIX: il n'y a aucune cause -[01:32:41] VOIX: à la douleur -[01:32:42] VOIX: ils vont traiter -[01:32:43] VOIX: la douleur -[01:32:44] VOIX: et ils disent -[01:32:45] VOIX: douleur -[01:32:46] VOIX: sans cause précisée -[01:32:52] VOIX: ok -[01:32:55] VOIX: ça c'était -[01:32:56] VOIX: la règle -[01:32:58] VOIX: D1 -[01:32:59] VOIX: je pense -[01:33:01] VOIX: on va regarder -[01:33:05] VOIX: situation -[01:33:06] VOIX: attends -[01:33:06] VOIX: je vais revenir -[01:33:08] VOIX: un peu avant -[01:33:09] VOIX: parce que c'est important -[01:33:10] VOIX: que je te montre -[01:33:11] VOIX: en fonction -[01:33:12] VOIX: de tes questions -[01:33:14] VOIX: attends -[01:33:14] VOIX: je partage -[01:33:17] VOIX: par exemple -[01:33:18] VOIX: là c'était -[01:33:19] VOIX: la situation -[01:33:20] VOIX: le séjour -[01:33:21] VOIX: a permis -[01:33:21] VOIX: le diagnostic -[01:33:22] VOIX: de l'infection -[01:33:23] VOIX: causale -[01:33:24] VOIX: il n'a pas été -[01:33:26] VOIX: découvert -[01:33:26] VOIX: de cause -[01:33:27] VOIX: à la symptomatologie -[01:33:28] VOIX: c'est le cas -[01:33:28] VOIX: dont tu parles -[01:33:29] VOIX: d'accord -[01:33:31] VOIX: il te dit -[01:33:32] VOIX: il n'a pas été -[01:33:32] VOIX: découvert -[01:33:33] VOIX: de cause -[01:33:33] VOIX: à la symptomatologie -[01:33:34] VOIX: lorsqu'il n'a pas -[01:33:35] VOIX: été découvert -[01:33:36] VOIX: de cause -[01:33:36] VOIX: à la symptomatologie -[01:33:38] VOIX: et les DP -[01:33:38] VOIX: c'est le symptôme -[01:33:39] VOIX: c'est à dire -[01:33:41] VOIX: au speech -[01:33:41] VOIX: le patient -[01:33:42] VOIX: il vient pour -[01:33:42] VOIX: céphalée -[01:33:43] VOIX: mais la conclusion -[01:33:45] VOIX: c'est céphalée -[01:33:46] VOIX: sans cause -[01:33:46] VOIX: trouvée -[01:33:47] VOIX: le DP -[01:33:48] VOIX: reste -[01:33:48] VOIX: les céphalées -[01:33:49] VOIX: tu vois -[01:33:51] VOIX: c'est clair -[01:33:52] VOIX: c'est clair -[01:33:53] VOIX: ou tu veux plus -[01:33:54] VOIX: d'exemple -[01:33:54] VOIX: non mais -[01:33:55] VOIX: là pour l'instant -[01:33:55] VOIX: la douleur abdo -[01:33:56] VOIX: c'est ce que je t'ai montré -[01:33:57] VOIX: ici -[01:33:58] VOIX: il vient pour douleur abdo -[01:34:00] VOIX: mais disparaissant des douleurs -[01:34:01] VOIX: à 48 heures -[01:34:02] VOIX: pas de cause -[01:34:03] VOIX: trouvée -[01:34:04] VOIX: le DP -[01:34:05] VOIX: c'est les douleurs abdo -[01:34:06] VOIX: tout à l'heure -[01:34:08] VOIX: dans ton dossier -[01:34:09] VOIX: on avait -[01:34:10] VOIX: la pancréatite -[01:34:11] VOIX: on avait -[01:34:12] VOIX: la lithiase -[01:34:12] VOIX: de la vésicule biliaire -[01:34:14] VOIX: etc -[01:34:14] VOIX: c'est ce qui expliquait -[01:34:16] VOIX: la douleur -[01:34:16] VOIX: là -[01:34:17] VOIX: ta douleur -[01:34:18] VOIX: mais il fait le bilan -[01:34:19] VOIX: il trouve rien -[01:34:20] VOIX: ça reste la douleur -[01:34:22] VOIX: ça c'est fréquent aussi -[01:34:23] VOIX: surtout dans les services -[01:34:24] VOIX: des urgences -[01:34:25] VOIX: etc -[01:34:27] VOIX: ok -[01:34:30] VOIX: ok -[01:34:30] VOIX: ok -[01:34:31] VOIX: c'est la D2 -[01:34:32] VOIX: la D2 -[01:34:33] VOIX: pour -[01:34:34] VOIX: cette deuxième réponse -[01:34:36] VOIX: c'est noté -[01:34:38] VOIX: c'est noté -[01:34:38] VOIX: quel rôle -[01:34:39] VOIX: jouent les actes -[01:34:40] VOIX: traitement -[01:34:41] VOIX: en fait -[01:34:41] VOIX: dans la décision -[01:34:42] VOIX: en fait -[01:34:42] VOIX: du DP -[01:34:45] VOIX: en fait -[01:34:45] VOIX: ça rejouit -[01:34:46] VOIX: à peu près -[01:34:46] VOIX: ce que tu viens -[01:34:46] VOIX: de me dire -[01:34:47] VOIX: aussi -[01:34:47] VOIX: tu passes en fait -[01:34:48] VOIX: sur le guide -[01:34:49] VOIX: méthodologique -[01:34:50] VOIX: ok -[01:34:50] VOIX: ta question -[01:34:51] VOIX: j'ai pas forcément -[01:34:52] VOIX: compris -[01:34:52] VOIX: le rôle -[01:34:53] VOIX: des actes -[01:34:54] VOIX: voilà -[01:34:57] VOIX: actes et traitement -[01:34:58] VOIX: dans la décision -[01:34:59] VOIX: du DP -[01:34:59] VOIX: ah -[01:35:00] VOIX: c'était la situation -[01:35:01] VOIX: de traitement -[01:35:02] VOIX: en fait -[01:35:03] VOIX: dans les actes -[01:35:03] VOIX: alors -[01:35:04] VOIX: je vais faire -[01:35:04] VOIX: une petite pause -[01:35:05] VOIX: parce que ça -[01:35:06] VOIX: tu ne le trouveras pas -[01:35:07] VOIX: dans le guide méthodo -[01:35:08] VOIX: parce que je ne t'ai pas -[01:35:08] VOIX: donné -[01:35:09] VOIX: toute la logique -[01:35:11] VOIX: sur les actes -[01:35:12] VOIX: tu n'as pas encore -[01:35:13] VOIX: la réglementation -[01:35:14] VOIX: là-dessus -[01:35:14] VOIX: mais dans les actes -[01:35:16] VOIX: CCAM -[01:35:17] VOIX: il y en a -[01:35:18] VOIX: qu'on appelle -[01:35:19] VOIX: les actes -[01:35:19] VOIX: classants -[01:35:20] VOIX: et les actes -[01:35:21] VOIX: non classants -[01:35:22] VOIX: les actes -[01:35:23] VOIX: classants -[01:35:24] VOIX: par exemple -[01:35:25] VOIX: la cholycystectomie -[01:35:26] VOIX: tu vas lui donner -[01:35:28] VOIX: la cholycystite -[01:35:29] VOIX: ou la pancréatite -[01:35:30] VOIX: etc -[01:35:31] VOIX: elle va t'orienter -[01:35:32] VOIX: vers un GHM -[01:35:33] VOIX: chirurgical -[01:35:34] VOIX: puisque -[01:35:35] VOIX: l'acte -[01:35:36] VOIX: il joue -[01:35:37] VOIX: un rôle -[01:35:37] VOIX: dans le groupage -[01:35:39] VOIX: dans le GHM -[01:35:40] VOIX: et dans le GHS -[01:35:41] VOIX: et dans le tarif -[01:35:42] VOIX: les actes -[01:35:43] VOIX: non classants -[01:35:44] VOIX: c'est toutes -[01:35:45] VOIX: les imageries -[01:35:46] VOIX: l'acte -[01:35:50] VOIX: d'anapath -[01:35:51] VOIX: le scanner -[01:35:52] VOIX: l'IRM -[01:35:53] VOIX: les échos -[01:35:53] VOIX: etc -[01:35:54] VOIX: c'est des actes -[01:35:55] VOIX: d'imagerie -[01:35:56] VOIX: mais qui ne vont pas -[01:35:56] VOIX: influencer -[01:35:57] VOIX: le groupage -[01:35:59] VOIX: mais par contre -[01:36:00] VOIX: ils peuvent avoir -[01:36:01] VOIX: un rôle -[01:36:01] VOIX: parce que -[01:36:02] VOIX: peut-être que c'est -[01:36:03] VOIX: dans le scanner -[01:36:04] VOIX: ou dans l'échographie -[01:36:05] VOIX: etc -[01:36:05] VOIX: le compte-rendu -[01:36:06] VOIX: qui va te dire -[01:36:08] VOIX: au fait -[01:36:09] VOIX: le scan -[01:36:10] VOIX: il retrouve -[01:36:11] VOIX: une tumeur -[01:36:12] VOIX: du côlon -[01:36:13] VOIX: et elle peut -[01:36:14] VOIX: expliquer très bien -[01:36:15] VOIX: la douleur abdominale -[01:36:16] VOIX: et dans ce cas-là -[01:36:17] VOIX: c'est la tumeur -[01:36:18] VOIX: du côlon -[01:36:19] VOIX: qu'il faut suivre -[01:36:20] VOIX: pour donner -[01:36:23] VOIX: de la précision -[01:36:24] VOIX: au codage -[01:36:24] VOIX: du diagnostic -[01:36:25] VOIX: donc -[01:36:26] VOIX: c'est soit -[01:36:27] VOIX: à travers -[01:36:27] VOIX: l'acte classant -[01:36:30] VOIX: tu sais que -[01:36:31] VOIX: si ton acte -[01:36:34] VOIX: c'est une cholycystectomie -[01:36:35] VOIX: tu vas regarder -[01:36:37] VOIX: le résultat -[01:36:38] VOIX: le compte-rendu -[01:36:39] VOIX: opératoire -[01:36:40] VOIX: tu vas regarder -[01:36:41] VOIX: pourquoi il a enlevé -[01:36:42] VOIX: il a fait une cholycystectomie -[01:36:45] VOIX: et ça t'oriente -[01:36:47] VOIX: vers le DP -[01:36:48] VOIX: avec la règle T -[01:36:49] VOIX: alors je ne sais plus -[01:36:50] VOIX: comment il s'appelle -[01:36:51] VOIX: traitement unique -[01:36:51] VOIX: chirurgical -[01:36:52] VOIX: et la deuxième situation -[01:36:55] VOIX: c'est tous les résultats -[01:36:56] VOIX: des examens -[01:36:58] VOIX: des actes -[01:36:59] VOIX: des examens -[01:37:00] VOIX: surtout les examens -[01:37:01] VOIX: d'imagerie -[01:37:01] VOIX: qui peuvent être codés -[01:37:04] VOIX: parce qu'ils montrent -[01:37:05] VOIX: des diagnostics -[01:37:06] VOIX: qu'ils posent des diagnostics -[01:37:07] VOIX: qu'ils aident -[01:37:07] VOIX: à poser des diagnostics -[01:37:10] VOIX: et bien ok -[01:37:12] VOIX: moi qui croyais en fait -[01:37:13] VOIX: que mon métier -[01:37:13] VOIX: était super compliqué -[01:37:18] VOIX: en fait c'est simple -[01:37:19] VOIX: et je comprends -[01:37:20] VOIX: pourquoi tu as envie -[01:37:20] VOIX: en fait d'apprendre -[01:37:21] VOIX: en fait à lire -[01:37:23] VOIX: j'ai compris -[01:37:24] VOIX: j'ai compris le bordel -[01:37:25] VOIX: pardon -[01:37:27] VOIX: non non mais -[01:37:28] VOIX: tu as raison -[01:37:29] VOIX: c'est un vrai bordel -[01:37:29] VOIX: et c'est pour ça -[01:37:31] VOIX: que quand je tempère -[01:37:32] VOIX: parfois -[01:37:33] VOIX: certaines situations -[01:37:34] VOIX: en disant -[01:37:34] VOIX: c'est pas -[01:37:35] VOIX: A plus B -[01:37:36] VOIX: égal à C -[01:37:37] VOIX: il faut quand même -[01:37:38] VOIX: qu'on réfléchisse -[01:37:39] VOIX: proprement -[01:37:41] VOIX: pour ne pas -[01:37:42] VOIX: se planter -[01:37:43] VOIX: quitte -[01:37:44] VOIX: par exemple -[01:37:45] VOIX: dans -[01:37:45] VOIX: on peut procéder -[01:37:48] VOIX: aussi -[01:37:49] VOIX: par méthode -[01:37:50] VOIX: c'est à dire -[01:37:51] VOIX: on prend les séjours -[01:37:52] VOIX: chirurgicaux -[01:37:53] VOIX: et -[01:37:54] VOIX: tu travailles -[01:37:55] VOIX: sur une règle -[01:37:57] VOIX: sur une règle -[01:37:58] VOIX: traitement unique -[01:37:59] VOIX: chirurgical -[01:38:00] VOIX: pour dire -[01:38:03] VOIX: il n'y en a pas -[01:38:03] VOIX: beaucoup dans ce contrôle -[01:38:05] VOIX: par contre -[01:38:05] VOIX: parce qu'il y a -[01:38:06] VOIX: quelques-uns -[01:38:07] VOIX: mais c'est pas très fréquent -[01:38:08] VOIX: mais si on avait -[01:38:09] VOIX: un établissement MCO -[01:38:10] VOIX: ça serait très intéressant -[01:38:12] VOIX: d'exploiter -[01:38:13] VOIX: cette notion -[01:38:14] VOIX: de dire -[01:38:14] VOIX: je prends les séjours -[01:38:15] VOIX: de chirurgie -[01:38:16] VOIX: parce que tu vas voir -[01:38:17] VOIX: le patient qui a une -[01:38:18] VOIX: prothèse de hanche -[01:38:20] VOIX: il a deux -[01:38:21] VOIX: trois diagnostics -[01:38:21] VOIX: c'est soit il a fait -[01:38:22] VOIX: une fracture -[01:38:22] VOIX: soit qu'il a une arthrose -[01:38:24] VOIX: de la hanche -[01:38:25] VOIX: voilà -[01:38:26] VOIX: ça va être vite vu -[01:38:28] VOIX: si tu veux -[01:38:28] VOIX: et ça permet -[01:38:29] VOIX: d'entraîner -[01:38:30] VOIX: aussi à dire -[01:38:31] VOIX: avec tel acte -[01:38:32] VOIX: je vais avoir -[01:38:33] VOIX: souvent -[01:38:34] VOIX: tel diagnostic -[01:38:35] VOIX: avec la règle -[01:38:36] VOIX: traitement unique -[01:38:37] VOIX: chirurgical -[01:38:37] VOIX: tu vois -[01:38:38] VOIX: parce que ça -[01:38:39] VOIX: c'est les séjours -[01:38:41] VOIX: les plus faciles -[01:38:42] VOIX: à coder -[01:38:42] VOIX: c'est les traitements -[01:38:43] VOIX: uniques chirurgical -[01:38:44] VOIX: quand on forme -[01:38:45] VOIX: les teams -[01:38:45] VOIX: on commence par -[01:38:46] VOIX: les séjours -[01:38:47] VOIX: de zéro jour -[01:38:47] VOIX: et puis par -[01:38:49] VOIX: la chirurgie -[01:38:49] VOIX: parce que la chirurgie -[01:38:50] VOIX: la logique -[01:38:51] VOIX: elle est plus facile -[01:38:52] VOIX: que dire -[01:38:52] VOIX: le symptôme -[01:38:53] VOIX: la cause -[01:38:54] VOIX: machin -[01:38:54] VOIX: en plus -[01:38:55] VOIX: elles n'ont pas -[01:38:55] VOIX: de notion médical -[01:38:56] VOIX: donc c'est pas évident -[01:38:59] VOIX: il faut avoir -[01:39:00] VOIX: quand même -[01:39:01] VOIX: médical -[01:39:02] VOIX: oui j'ai l'impression -[01:39:04] VOIX: j'ai un peu l'impression -[01:39:08] VOIX: je continue mes questions -[01:39:11] VOIX: bien sûr -[01:39:12] VOIX: alors -[01:39:13] VOIX: parce que je suis organisé -[01:39:14] VOIX: ça se voit pas -[01:39:15] VOIX: mais je suis un garçon -[01:39:16] VOIX: organisé -[01:39:17] VOIX: c'est parfait ça -[01:39:19] VOIX: ça on le verra -[01:39:19] VOIX: après en fait -[01:39:20] VOIX: c'est des problèmes -[01:39:21] VOIX: que je suis en train -[01:39:22] VOIX: non mais ça va aller -[01:39:24] VOIX: peut-être un peu plus vite -[01:39:25] VOIX: tu vois -[01:39:29] VOIX: alors ça -[01:39:30] VOIX: est-ce que je te le pose -[01:39:31] VOIX: là aujourd'hui -[01:39:33] VOIX: non -[01:39:34] VOIX: vas-y vas-y -[01:39:35] VOIX: non non -[01:39:36] VOIX: c'est pas une question -[01:39:36] VOIX: de peur -[01:39:37] VOIX: en fait -[01:39:37] VOIX: si tu veux -[01:39:38] VOIX: tu m'as donné -[01:39:38] VOIX: tellement d'informations -[01:39:39] VOIX: sur des sujets -[01:39:40] VOIX: que j'avais prévu -[01:39:41] VOIX: un peu plus tard -[01:39:42] VOIX: c'est pour ça en fait -[01:39:43] VOIX: que j'aime bien -[01:39:44] VOIX: cadrer un tout petit peu -[01:39:45] VOIX: c'est parce qu'il faut -[01:39:46] VOIX: que j'avance -[01:39:47] VOIX: en fait sur le truc -[01:39:47] VOIX: mais -[01:39:49] VOIX: tu m'as répondu -[01:39:50] VOIX: en fait -[01:39:50] VOIX: sur plein de choses -[01:39:51] VOIX: donc il y a plein -[01:39:52] VOIX: de réponses -[01:39:53] VOIX: en fait que j'ai déjà -[01:39:55] VOIX: moi j'ai des problèmes -[01:39:56] VOIX: en fait -[01:40:06] VOIX: ah si -[01:40:07] VOIX: pour morbidité -[01:40:09] VOIX: je peux te donner -[01:40:10] VOIX: une réponse -[01:40:10] VOIX: vas-y vas-y -[01:40:11] VOIX: tout à l'heure -[01:40:11] VOIX: on a vu le patient -[01:40:13] VOIX: qui est suivi -[01:40:13] VOIX: pour anémie -[01:40:14] VOIX: pour maladie de bière-mère -[01:40:17] VOIX: ma logique -[01:40:17] VOIX: je vais dire -[01:40:19] VOIX: soit -[01:40:19] VOIX: il a un traitement -[01:40:21] VOIX: pour son anémie -[01:40:22] VOIX: et pour sa maladie -[01:40:23] VOIX: de bière-mère -[01:40:23] VOIX: je vais regarder -[01:40:25] VOIX: les médicaments -[01:40:25] VOIX: je vais voir -[01:40:26] VOIX: est-ce qu'il a dû faire -[01:40:28] VOIX: est-ce qu'il a de la vitamine -[01:40:29] VOIX: B12 -[01:40:29] VOIX: etc -[01:40:30] VOIX: soit -[01:40:30] VOIX: je vais avoir -[01:40:31] VOIX: un examen complémentaire -[01:40:33] VOIX: où ils disent -[01:40:33] VOIX: je fais le bilan -[01:40:34] VOIX: de la maladie de bière-mère -[01:40:36] VOIX: pour voir son hémoglobine -[01:40:37] VOIX: etc -[01:40:37] VOIX: donc en fait -[01:40:39] VOIX: je vais chercher -[01:40:39] VOIX: dans le dossier -[01:40:40] VOIX: est-ce que cet antécédent -[01:40:42] VOIX: ou cette maladie chronique -[01:40:44] VOIX: qui est -[01:40:44] VOIX: qui est avec le patient -[01:40:47] VOIX: a été prise en charge -[01:40:49] VOIX: d'un point de vue -[01:40:51] VOIX: diagnostic thérapeutique -[01:40:52] VOIX: aux surveillances -[01:40:53] VOIX: pendant le séjour -[01:40:55] VOIX: si elle a été prise en charge -[01:40:57] VOIX: je peux la coder -[01:40:59] VOIX: en diagnostic associé -[01:41:00] VOIX: qui représente -[01:41:01] VOIX: les comorbidités -[01:41:02] VOIX: s'il n'y a eu -[01:41:03] VOIX: aucune prise en charge -[01:41:05] VOIX: ou aucune mention -[01:41:06] VOIX: et qu'il est écrite -[01:41:07] VOIX: exclusivement -[01:41:08] VOIX: comme antécédent -[01:41:09] VOIX: je ne vais pas la coder -[01:41:12] VOIX: ok -[01:41:14] VOIX: et je -[01:41:15] VOIX: je -[01:41:16] VOIX: ou les dates -[01:41:16] VOIX: dont on a parlé -[01:41:17] VOIX: au tout début -[01:41:18] VOIX: il te dit -[01:41:18] VOIX: en 1928 -[01:41:20] VOIX: il a eu -[01:41:21] VOIX: une fracture -[01:41:23] VOIX: ben là -[01:41:23] VOIX: tu sais que tu ne vas pas -[01:41:25] VOIX: coder sa fracture -[01:41:27] VOIX: tu peux coder -[01:41:28] VOIX: peut-être des conséquences -[01:41:29] VOIX: des choses qui sont -[01:41:30] VOIX: encore actives -[01:41:32] VOIX: aujourd'hui -[01:41:32] VOIX: en fait il faut regarder -[01:41:34] VOIX: la notion de comorbidité -[01:41:35] VOIX: elle est liée au séjour -[01:41:37] VOIX: c'est qu'est-ce qui a été fait -[01:41:39] VOIX: pendant le séjour -[01:41:40] VOIX: qu'on est en train de coder -[01:41:41] VOIX: c'est pas -[01:41:42] VOIX: le patient -[01:41:43] VOIX: son histoire -[01:41:44] VOIX: c'est pas -[01:41:47] VOIX: c'est un codage -[01:41:47] VOIX: le codage -[01:41:48] VOIX: c'est vraiment -[01:41:48] VOIX: lié au séjour -[01:41:51] VOIX: ok -[01:41:52] VOIX: est-ce que -[01:41:53] VOIX: alors attends -[01:41:54] VOIX: je continue -[01:41:55] VOIX: mes questions -[01:41:56] VOIX: et on a bien avancé -[01:41:58] VOIX: c'est une bonne question -[01:41:59] VOIX: en fait -[01:41:59] VOIX: ah mais je te remercie -[01:42:02] VOIX: ah oui -[01:42:02] VOIX: non mais c'est vrai -[01:42:03] VOIX: ça montre que -[01:42:04] VOIX: voilà -[01:42:05] VOIX: c'est -[01:42:06] VOIX: la logique -[01:42:07] VOIX: elle est en train de -[01:42:08] VOIX: alors -[01:42:11] VOIX: donc -[01:42:11] VOIX: j'aime beaucoup -[01:42:15] VOIX: cette manière -[01:42:16] VOIX: de -[01:42:17] VOIX: de se projeter -[01:42:18] VOIX: dans le codage -[01:42:19] VOIX: et qu'on ne soit pas -[01:42:20] VOIX: que sur des mots clés -[01:42:21] VOIX: au fait -[01:42:22] VOIX: ah mais disons que -[01:42:23] VOIX: si tu veux -[01:42:24] VOIX: moi la partie -[01:42:24] VOIX: mots clés -[01:42:25] VOIX: pour tout te dire -[01:42:26] VOIX: en fait moi je l'ai terminé -[01:42:27] VOIX: les mots clés -[01:42:28] VOIX: tu sais -[01:42:28] VOIX: en fait avec le -[01:42:30] VOIX: SL -[01:42:31] VOIX: NMP -[01:42:31] VOIX: en fait j'arrive -[01:42:32] VOIX: à les choper -[01:42:33] VOIX: tout ce qui est négation -[01:42:34] VOIX: machin et tout ça -[01:42:35] VOIX: tout ça en fait j'y arrive -[01:42:36] VOIX: et c'est justement -[01:42:37] VOIX: ce que l'on est en train de faire -[01:42:38] VOIX: qui va amener -[01:42:40] VOIX: en fait la plus value -[01:42:40] VOIX: en fait au système -[01:42:41] VOIX: parce qu'en réalité -[01:42:42] VOIX: tout le reste -[01:42:43] VOIX: en fait ce sont des règles -[01:42:44] VOIX: mais il y a en fait -[01:42:45] VOIX: des petites subtilités -[01:42:46] VOIX: que tu m'as -[01:42:47] VOIX: que tu m'as donné -[01:42:48] VOIX: et ce sont en fait -[01:42:49] VOIX: non mais -[01:42:50] VOIX: quand je dis petite -[01:42:51] VOIX: c'est pas nécessairement -[01:42:52] VOIX: c'est justement -[01:42:53] VOIX: c'est le grain de sable -[01:42:54] VOIX: en fait qui fait bloquer -[01:42:55] VOIX: en fait l'ensemble -[01:42:55] VOIX: d'une machine -[01:42:56] VOIX: qui est cadrée -[01:42:57] VOIX: le fait que les contrôles -[01:42:58] VOIX: c'est plus -[01:42:59] VOIX: c'est le bazar absolu -[01:43:00] VOIX: c'est les petits grains de sable -[01:43:02] VOIX: et l'interprétation -[01:43:03] VOIX: et comment t'orientes -[01:43:05] VOIX: ton dossier -[01:43:05] VOIX: c'est toute la nuance -[01:43:07] VOIX: avec l'avocat hier -[01:43:08] VOIX: on s'est dit -[01:43:09] VOIX: bon on va faire ça comme ça -[01:43:10] VOIX: on va jouer -[01:43:11] VOIX: parce que c'est vrai -[01:43:14] VOIX: ben oui -[01:43:14] VOIX: mais bien sûr -[01:43:15] VOIX: mais là aussi -[01:43:16] VOIX: tu vois en fait -[01:43:17] VOIX: dans ce que tu m'as -[01:43:18] VOIX: ce que tu m'as dit -[01:43:19] VOIX: tout à l'heure -[01:43:20] VOIX: en fait -[01:43:20] VOIX: le premier tableau -[01:43:23] VOIX: en fait que tu m'as montré -[01:43:24] VOIX: sur ta vision -[01:43:24] VOIX: en fait des choses -[01:43:25] VOIX: avec en fait -[01:43:25] VOIX: ce que on fait aujourd'hui -[01:43:27] VOIX: ce qu'on veut faire -[01:43:28] VOIX: en fait demain -[01:43:29] VOIX: il y a déjà -[01:43:30] VOIX: tout ce que -[01:43:31] VOIX: c'est ce que j'ai montré -[01:43:32] VOIX: l'autre jour -[01:43:32] VOIX: j'ai déjà pris en compte -[01:43:34] VOIX: une partie -[01:43:34] VOIX: une grosse partie -[01:43:35] VOIX: de ce que tu m'as montré -[01:43:36] VOIX: et notamment -[01:43:37] VOIX: en fait le fait -[01:43:38] VOIX: que quand on fait -[01:43:39] VOIX: le codage -[01:43:39] VOIX: d'accord -[01:43:40] VOIX: quand on fait le codage -[01:43:41] VOIX: on fait pas en fait -[01:43:42] VOIX: simplement en fait -[01:43:43] VOIX: le codage primaire -[01:43:45] VOIX: c'est en fait -[01:43:46] VOIX: on va justifier -[01:43:47] VOIX: parce que la machine -[01:43:48] VOIX: te permet de faire ça -[01:43:49] VOIX: en même temps -[01:43:50] VOIX: pourquoi tu as mis ça -[01:43:51] VOIX: c'est pour ça -[01:43:52] VOIX: que je pose des questions -[01:43:53] VOIX: peut-être -[01:43:53] VOIX: peut-être très bêtes -[01:43:54] VOIX: c'est parfait -[01:43:55] VOIX: voilà -[01:43:55] VOIX: parce que -[01:43:56] VOIX: c'est pas bête -[01:43:57] VOIX: parce que dans la logique -[01:43:59] VOIX: je vais te donner -[01:43:59] VOIX: une astuce aussi -[01:44:00] VOIX: de raisonnement -[01:44:02] VOIX: qui est utile -[01:44:03] VOIX: dans un des mails -[01:44:04] VOIX: je sais pas -[01:44:05] VOIX: si tu t'en rappelles -[01:44:06] VOIX: je t'ai dit -[01:44:07] VOIX: il y a la règle -[01:44:08] VOIX: de la TIH -[01:44:10] VOIX: il y a la règle -[01:44:11] VOIX: médicale -[01:44:11] VOIX: c'est-à-dire -[01:44:12] VOIX: si tu donnes -[01:44:14] VOIX: par exemple -[01:44:15] VOIX: le site de l'HAS -[01:44:16] VOIX: ou des sites -[01:44:17] VOIX: de recommandations -[01:44:17] VOIX: médicales -[01:44:18] VOIX: lui va te dire -[01:44:20] VOIX: par exemple -[01:44:21] VOIX: j'ai lu -[01:44:21] VOIX: dans un de tes documents -[01:44:22] VOIX: l'insuffisance rénale -[01:44:24] VOIX: aiguë -[01:44:25] VOIX: obstruxive -[01:44:25] VOIX: elle est liée -[01:44:26] VOIX: souvent -[01:44:27] VOIX: à l'obstruction -[01:44:29] VOIX: urétérale -[01:44:30] VOIX: ou je sais pas quoi -[01:44:32] VOIX: médicalement parlant -[01:44:32] VOIX: ça tient la route -[01:44:33] VOIX: mais c'est quelque chose -[01:44:35] VOIX: qui est peut-être -[01:44:36] VOIX: non écrit -[01:44:37] VOIX: par le médecin -[01:44:37] VOIX: par exemple -[01:44:38] VOIX: le patient -[01:44:39] VOIX: il va rentrer -[01:44:39] VOIX: pour douleur -[01:44:41] VOIX: je vais essayer -[01:44:42] VOIX: de trouver des dossiers -[01:44:43] VOIX: dans cette logique là -[01:44:43] VOIX: le médecin lui -[01:44:44] VOIX: il va dire -[01:44:45] VOIX: est-ce que c'est mon cause -[01:44:46] VOIX: ma cause -[01:44:46] VOIX: c'est une tumeur -[01:44:47] VOIX: est-ce que c'est une infection -[01:44:48] VOIX: est-ce que -[01:44:49] VOIX: c'est une inflammation -[01:44:50] VOIX: est-ce que -[01:44:51] VOIX: donc dans sa tête lui -[01:44:52] VOIX: il a fait -[01:44:53] VOIX: mais il va pas -[01:44:54] VOIX: les écrire forcément -[01:44:55] VOIX: il y a des médecins -[01:44:56] VOIX: qui vont les écrire -[01:44:57] VOIX: et il y a des médecins -[01:44:57] VOIX: qui vont pas les écrire -[01:44:58] VOIX: et leur réflexion -[01:45:01] VOIX: va rejoindre -[01:45:02] VOIX: la logique médicale -[01:45:03] VOIX: les recommandations médicales -[01:45:04] VOIX: devant tel symptôme -[01:45:05] VOIX: voici la démarche -[01:45:07] VOIX: pour éliminer -[01:45:08] VOIX: la première hypothèse -[01:45:09] VOIX: il faut faire -[01:45:10] VOIX: tel bilan -[01:45:10] VOIX: tel bilan -[01:45:11] VOIX: éliminer -[01:45:12] VOIX: ou confirmer -[01:45:13] VOIX: confirmer ou infirmer -[01:45:14] VOIX: la deuxième hypothèse -[01:45:15] VOIX: tel bilan -[01:45:15] VOIX: tel bilan -[01:45:16] VOIX: c'est comme ça -[01:45:16] VOIX: qu'ils arrivent -[01:45:17] VOIX: à la démarche diagnostique -[01:45:19] VOIX: et en fait -[01:45:20] VOIX: parfois -[01:45:22] VOIX: ils vont pas -[01:45:23] VOIX: forcément -[01:45:24] VOIX: nommer -[01:45:25] VOIX: la conclusion -[01:45:28] VOIX: parce qu'elle est -[01:45:29] VOIX: tellement évidente -[01:45:30] VOIX: pour eux -[01:45:30] VOIX: et nous -[01:45:32] VOIX: on se retrouve -[01:45:32] VOIX: coincés -[01:45:33] VOIX: même si elle est -[01:45:34] VOIX: évidente -[01:45:35] VOIX: à la lecture -[01:45:35] VOIX: mais si elle n'est -[01:45:37] VOIX: pas écrite -[01:45:38] VOIX: tu peux pas -[01:45:39] VOIX: ça compte ou pas -[01:45:41] VOIX: sauf si -[01:45:42] VOIX: t'as un argument -[01:45:44] VOIX: médical -[01:45:45] VOIX: indiscutable -[01:45:46] VOIX: c'est à dire -[01:45:46] VOIX: tu dis -[01:45:47] VOIX: c'est pas le plus probable -[01:45:48] VOIX: tu dis -[01:45:49] VOIX: j'ai tel argument -[01:45:50] VOIX: tel symptôme -[01:45:51] VOIX: tel bilan -[01:45:52] VOIX: tel ça -[01:45:53] VOIX: ça ça veut dire -[01:45:54] VOIX: que c'est telle maladie -[01:45:55] VOIX: et t'arrives à te défendre -[01:45:57] VOIX: tu peux aller au tribunal -[01:45:59] VOIX: avec -[01:45:59] VOIX: il y a personne -[01:45:59] VOIX: qui te dit -[01:46:00] VOIX: non c'est pas telle pathologie -[01:46:02] VOIX: donc ça c'est des arguments -[01:46:04] VOIX: qu'on peut avancer -[01:46:05] VOIX: ça veut pas dire -[01:46:05] VOIX: qu'ils sont acceptés -[01:46:07] VOIX: lors des contrôles -[01:46:08] VOIX: mais ça non plus -[01:46:09] VOIX: que c'est des arguments -[01:46:10] VOIX: qu'on utilise -[01:46:11] VOIX: les médecins -[01:46:11] VOIX: ils les utilisent -[01:46:12] VOIX: beaucoup d'ailleurs -[01:46:14] VOIX: ok -[01:46:14] VOIX: alors je reprends -[01:46:17] VOIX: mes petites notes -[01:46:18] VOIX: oui -[01:46:18] VOIX: les questions -[01:46:19] VOIX: alors les diagnostics -[01:46:20] VOIX: alors attends -[01:46:22] VOIX: je t'ai parlé -[01:46:22] VOIX: les comorbidités -[01:46:25] VOIX: et ensuite -[01:46:25] VOIX: des diagnostics -[01:46:27] VOIX: fréquents -[01:46:27] VOIX: dans les CHR -[01:46:29] VOIX: mais rarement -[01:46:30] VOIX: en DP -[01:46:34] VOIX: est-ce que tu en as -[01:46:35] VOIX: en tête -[01:46:36] VOIX: quelques-uns -[01:46:37] VOIX: oui -[01:46:40] VOIX: tu peux -[01:46:41] VOIX: par exemple -[01:46:43] VOIX: le sepsis -[01:46:45] VOIX: la dénutrition -[01:46:46] VOIX: le patient -[01:46:48] VOIX: ne vient pas -[01:46:48] VOIX: pour ça -[01:46:50] VOIX: ça revient souvent -[01:46:52] VOIX: dénutrition -[01:46:54] VOIX: oui -[01:46:55] VOIX: parce que le patient -[01:46:56] VOIX: il peut être suivi -[01:46:57] VOIX: pour un problème -[01:46:58] VOIX: oncologique -[01:46:59] VOIX: il vient -[01:47:01] VOIX: parce que -[01:47:01] VOIX: il fait une complication -[01:47:03] VOIX: de son traitement -[01:47:04] VOIX: de sa chimiothérapie -[01:47:05] VOIX: par exemple -[01:47:06] VOIX: il fait une aplasie -[01:47:09] VOIX: médulaire -[01:47:09] VOIX: je dis n'importe quoi -[01:47:10] VOIX: ou une anémie -[01:47:11] VOIX: etc -[01:47:12] VOIX: il vient pour ça -[01:47:14] VOIX: en même temps -[01:47:15] VOIX: on va se rendre compte -[01:47:17] VOIX: qu'il a une dénutrition -[01:47:18] VOIX: on peut la prendre -[01:47:19] VOIX: en charge -[01:47:20] VOIX: elle va être plus ou moins -[01:47:22] VOIX: traitée -[01:47:23] VOIX: plus ou moins prise en charge -[01:47:25] VOIX: plus ou moins tracée -[01:47:25] VOIX: c'est pas traité -[01:47:26] VOIX: il peut être traité -[01:47:27] VOIX: mais pas forcément -[01:47:28] VOIX: tracé dans le dossier -[01:47:29] VOIX: donc dans ce cas là -[01:47:31] VOIX: la dénutrition -[01:47:32] VOIX: il faut qu'elle soit codée -[01:47:33] VOIX: pour la coder -[01:47:35] VOIX: il y a des règles -[01:47:36] VOIX: dans le guide méthodologique -[01:47:38] VOIX: qui définissent -[01:47:39] VOIX: la dénutrition -[01:47:40] VOIX: donc il faut que le dossier -[01:47:41] VOIX: puisse retrouver -[01:47:42] VOIX: si je t'ai envoyé -[01:47:44] VOIX: un linkedin -[01:47:45] VOIX: sur parallèle -[01:47:47] VOIX: qui parlait -[01:47:47] VOIX: de la dénutrition -[01:47:48] VOIX: parce que voilà -[01:47:49] VOIX: ils ont fait un travail -[01:47:59] VOIX: avec risque -[01:48:01] VOIX: ou avec une potentielle -[01:48:02] VOIX: dénutrition -[01:48:03] VOIX: mais peut-être qu'avec -[01:48:05] VOIX: une meilleure traçabilité -[01:48:06] VOIX: on pourra mieux -[01:48:07] VOIX: les valoriser -[01:48:07] VOIX: tu vois ça -[01:48:08] VOIX: c'est les contrôles -[01:48:09] VOIX: après qu'on pourra -[01:48:10] VOIX: intégrer -[01:48:11] VOIX: ou à défaut -[01:48:12] VOIX: de pouvoir les coder -[01:48:14] VOIX: tu peux me donner -[01:48:15] VOIX: là on a parlé -[01:48:16] VOIX: de la dénutrition -[01:48:17] VOIX: tu peux m'en donner -[01:48:18] VOIX: en fait un ou deux -[01:48:19] VOIX: de plus -[01:48:20] VOIX: comme ça -[01:48:21] VOIX: le sepsis -[01:48:22] VOIX: le sepsis -[01:48:23] VOIX: s-e-p-s-i-s -[01:48:33] VOIX: ok -[01:48:33] VOIX: le diabète -[01:48:37] VOIX: hypertension -[01:48:37] VOIX: tu vois -[01:48:38] VOIX: ah d'accord -[01:48:39] VOIX: mais c'est des trucs -[01:48:39] VOIX: fréquents en fait -[01:48:40] VOIX: ça -[01:48:40] VOIX: par exemple -[01:48:44] VOIX: l'hypertension -[01:48:45] VOIX: c'est hyper fréquemment -[01:48:46] VOIX: codé en diagnostics -[01:48:47] VOIX: associés -[01:48:47] VOIX: d'accord -[01:48:48] VOIX: c'est pas forcément -[01:48:49] VOIX: c'est rare -[01:48:50] VOIX: les patients -[01:48:51] VOIX: qui viennent -[01:48:51] VOIX: pour une poussée -[01:48:52] VOIX: d'hypertension -[01:48:54] VOIX: bon moi -[01:48:54] VOIX: je l'ai fait -[01:48:56] VOIX: oui -[01:48:57] VOIX: mais je suis -[01:48:58] VOIX: un oiseau rare -[01:49:01] VOIX: si le patient -[01:49:02] VOIX: il vient -[01:49:02] VOIX: pour la poussée -[01:49:03] VOIX: d'hypertension -[01:49:03] VOIX: dans ce cas là -[01:49:04] VOIX: ça sera le DP -[01:49:05] VOIX: et si on regarde -[01:49:07] VOIX: les diagnostics -[01:49:07] VOIX: associés -[01:49:08] VOIX: les plus fréquemment -[01:49:08] VOIX: codés -[01:49:09] VOIX: ça peut être -[01:49:10] VOIX: l'hypertension -[01:49:10] VOIX: par exemple -[01:49:14] VOIX: bon -[01:49:14] VOIX: t'es prête -[01:49:14] VOIX: pour l'autre question -[01:49:17] VOIX: alors -[01:49:18] VOIX: quels diagnostics -[01:49:19] VOIX: sont quasi -[01:49:20] VOIX: toujours -[01:49:21] VOIX: des DAS -[01:49:22] VOIX: sauf exception -[01:49:27] VOIX: tu as une règle -[01:49:30] VOIX: déjà -[01:49:31] VOIX: liée -[01:49:32] VOIX: au code -[01:49:32] VOIX: qui te dit -[01:49:33] VOIX: que ces codes là -[01:49:34] VOIX: ne pourraient pas -[01:49:34] VOIX: être acceptés -[01:49:35] VOIX: en DP -[01:49:36] VOIX: donc ça -[01:49:37] VOIX: c'est lié -[01:49:37] VOIX: à la CIN10 -[01:49:38] VOIX: mais on a un fichier -[01:49:39] VOIX: Excel référentiel -[01:49:40] VOIX: qui dit -[01:49:42] VOIX: diagnostic -[01:49:43] VOIX: accepté -[01:49:44] VOIX: ou pas -[01:49:44] VOIX: en DP -[01:49:45] VOIX: donc ça -[01:49:46] VOIX: c'est facile -[01:49:46] VOIX: c'est un référentiel -[01:49:55] VOIX: franchement -[01:49:56] VOIX: là -[01:49:56] VOIX: comme ça -[01:49:57] VOIX: c'est pas grave -[01:50:01] VOIX: là -[01:50:02] VOIX: comme ça -[01:50:02] VOIX: puisque -[01:50:03] VOIX: par définition -[01:50:06] VOIX: tu sais -[01:50:07] VOIX: le patient -[01:50:08] VOIX: il peut venir -[01:50:09] VOIX: pour un diabète -[01:50:10] VOIX: il peut venir -[01:50:10] VOIX: pour un -[01:50:11] VOIX: hypertension -[01:50:12] VOIX: il peut venir -[01:50:13] VOIX: peut-être -[01:50:14] VOIX: un peu -[01:50:14] VOIX: plus rarement -[01:50:15] VOIX: tout ce qui est -[01:50:16] VOIX: hypercholestérolémie -[01:50:17] VOIX: mais bon -[01:50:18] VOIX: c'est pas -[01:50:19] VOIX: interdit -[01:50:19] VOIX: c'est pas -[01:50:20] VOIX: impossible -[01:50:21] VOIX: il y a -[01:50:22] VOIX: les diagnostics -[01:50:23] VOIX: interdits -[01:50:23] VOIX: en DP -[01:50:24] VOIX: mais ça -[01:50:24] VOIX: c'est un référentiel -[01:50:25] VOIX: tu ne les trouveras -[01:50:26] VOIX: jamais en DP -[01:50:27] VOIX: donc ça -[01:50:27] VOIX: c'est facile -[01:50:29] VOIX: donc ça -[01:50:30] VOIX: en fait -[01:50:30] VOIX: c'est mes exceptions -[01:50:31] VOIX: c'est ça -[01:50:32] VOIX: ok -[01:50:33] VOIX: c'est peut-être -[01:50:34] VOIX: si tu cherches -[01:50:35] VOIX: l'exception -[01:50:36] VOIX: ça peut être -[01:50:36] VOIX: les diagnostics -[01:50:37] VOIX: interdits -[01:50:37] VOIX: en DP -[01:50:38] VOIX: mais ça -[01:50:38] VOIX: le référentiel -[01:50:39] VOIX: on peut te le donner -[01:50:40] VOIX: il n'y a pas de soucis -[01:50:47] VOIX: j'arrive -[01:50:49] VOIX: parce que -[01:50:49] VOIX: j'ai une autre question -[01:50:50] VOIX: et là -[01:50:51] VOIX: celle-là -[01:50:51] VOIX: elle m'intéresse -[01:50:52] VOIX: beaucoup -[01:50:53] VOIX: c'est -[01:50:54] VOIX: ce que j'appelle -[01:50:55] VOIX: en fait -[01:50:55] VOIX: des codes -[01:50:56] VOIX: toxiques -[01:50:57] VOIX: en contrôle -[01:50:58] VOIX: c'est en fait -[01:50:59] VOIX: qu'ils sont souvent -[01:51:00] VOIX: en fait -[01:51:00] VOIX: contestés -[01:51:02] VOIX: ah ben -[01:51:03] VOIX: ça c'est facile -[01:51:04] VOIX: attends -[01:51:05] VOIX: je vais les écrire -[01:51:07] VOIX: mais je peux te faire -[01:51:08] VOIX: avec les écrits -[01:51:08] VOIX: déjà -[01:51:09] VOIX: la dénutrition -[01:51:10] VOIX: les sepsis -[01:51:12] VOIX: on a eu -[01:51:13] VOIX: ce qu'on appelle -[01:51:14] VOIX: les T80 -[01:51:14] VOIX: les complications -[01:51:16] VOIX: les complications -[01:51:17] VOIX: post-opératoires -[01:51:19] VOIX: mais pas tous -[01:51:20] VOIX: c'est certain -[01:51:21] VOIX: mais je peux te donner -[01:51:22] VOIX: ces quelques codes -[01:51:23] VOIX: ouais -[01:51:23] VOIX: je veux -[01:51:24] VOIX: 74 2 -[01:51:26] VOIX: pas maintenant -[01:51:27] VOIX: en fait -[01:51:28] VOIX: mais -[01:51:28] VOIX: tu vois en fait -[01:51:29] VOIX: les questions -[01:51:30] VOIX: bêtes que je me pose -[01:51:31] VOIX: pour gérer le truc -[01:51:32] VOIX: parce qu'en fait -[01:51:33] VOIX: non non -[01:51:33] VOIX: elles ne sont pas bêtes -[01:51:34] VOIX: les questions -[01:51:35] VOIX: ne sont pas bêtes -[01:51:35] VOIX: Dominique -[01:51:36] VOIX: ce sont en fait -[01:51:37] VOIX: des questions -[01:51:37] VOIX: qui me permettent -[01:51:39] VOIX: je t'expliquerai quand même -[01:51:40] VOIX: comment ça marche -[01:51:41] VOIX: parce que -[01:51:42] VOIX: ouais -[01:51:43] VOIX: pour que tu comprennes -[01:51:43] VOIX: en fait -[01:51:44] VOIX: j'établis en fait -[01:51:45] VOIX: justement -[01:51:46] VOIX: il y a une partie -[01:51:47] VOIX: en fait -[01:51:47] VOIX: c'est l'IA qui le gère -[01:51:49] VOIX: mais une grosse partie -[01:51:50] VOIX: en fait -[01:51:50] VOIX: ce sont en fait -[01:51:51] VOIX: des règles -[01:51:52] VOIX: qui sont fixées -[01:51:54] VOIX: et on ne peut pas -[01:51:55] VOIX: y déroger -[01:51:56] VOIX: parce que de toute façon -[01:51:57] VOIX: en fait -[01:51:57] VOIX: ça rejouit -[01:51:58] VOIX: en fait -[01:51:58] VOIX: ce que -[01:51:59] VOIX: voilà -[01:51:59] VOIX: tu me renvoies -[01:52:00] VOIX: en fait -[01:52:00] VOIX: au document -[01:52:01] VOIX: en fait -[01:52:01] VOIX: de la TIH -[01:52:03] VOIX: à chaque fois -[01:52:04] VOIX: donc -[01:52:04] VOIX: mais par contre -[01:52:06] VOIX: j'ai chopé -[01:52:07] VOIX: en fait -[01:52:07] VOIX: alors j'ai chopé -[01:52:08] VOIX: on va le voir -[01:52:08] VOIX: on va le vérifier -[01:52:10] VOIX: ta logique -[01:52:11] VOIX: en fait -[01:52:11] VOIX: ta façon -[01:52:12] VOIX: en fait -[01:52:12] VOIX: parce que -[01:52:12] VOIX: tout ce que -[01:52:13] VOIX: tu as dit -[01:52:15] VOIX: je l'avais -[01:52:16] VOIX: entreaperçu -[01:52:16] VOIX: avec toutes les discussions -[01:52:17] VOIX: qu'on a eu -[01:52:18] VOIX: mais je les ai validées -[01:52:21] VOIX: donc bon -[01:52:21] VOIX: après -[01:52:23] VOIX: c'est pour ça -[01:52:24] VOIX: que je t'ai dit -[01:52:24] VOIX: on fera des sessions -[01:52:25] VOIX: autant de fois -[01:52:26] VOIX: que nécessaire -[01:52:27] VOIX: ah oui -[01:52:27] VOIX: non mais là -[01:52:28] VOIX: tu ne vas pas y couper -[01:52:32] VOIX: non parce que -[01:52:33] VOIX: alors attends -[01:52:34] VOIX: attends -[01:52:34] VOIX: parce qu'il me reste -[01:52:35] VOIX: une ou deux -[01:52:36] VOIX: et après -[01:52:36] VOIX: je suis tranquille -[01:52:38] VOIX: donc ok -[01:52:40] VOIX: ah oui -[01:52:41] VOIX: si -[01:52:41] VOIX: tu sais -[01:52:42] VOIX: oui pardon -[01:52:42] VOIX: juste -[01:52:43] VOIX: tes questions là -[01:52:44] VOIX: tu les as notées -[01:52:45] VOIX: ah bah oui -[01:52:46] VOIX: je suis en fait -[01:52:47] VOIX: mon protocole -[01:52:47] VOIX: parce que tu sais -[01:52:48] VOIX: moi tes questions -[01:52:50] VOIX: mais c'est une formation -[01:52:52] VOIX: que -[01:52:52] VOIX: c'est génial -[01:52:54] VOIX: tes questions -[01:52:54] VOIX: tu vois -[01:52:55] VOIX: quand tu -[01:52:56] VOIX: non mais c'est vrai -[01:52:57] VOIX: parce que -[01:52:58] VOIX: nous en formation -[01:52:59] VOIX: on travaille beaucoup -[01:53:00] VOIX: avec -[01:53:02] VOIX: est-ce que tu peux mettre ça -[01:53:03] VOIX: c'est quoi la règle -[01:53:05] VOIX: pour coder ça -[01:53:06] VOIX: etc -[01:53:07] VOIX: donc c'est une manière -[01:53:08] VOIX: aussi -[01:53:08] VOIX: tu vois -[01:53:08] VOIX: moi aussi -[01:53:10] VOIX: ta manière -[01:53:11] VOIX: de te poser les questions -[01:53:11] VOIX: m'intéresse énormément -[01:53:12] VOIX: parce que -[01:53:14] VOIX: je vais te l'envoyer -[01:53:15] VOIX: en fait -[01:53:15] VOIX: je vais t'envoyer -[01:53:16] VOIX: en fait -[01:53:16] VOIX: le truc -[01:53:17] VOIX: oui -[01:53:18] VOIX: je te l'envoie en fait -[01:53:19] VOIX: ben là quand on a fini -[01:53:20] VOIX: je t'envoie en fait -[01:53:20] VOIX: le document de vierge -[01:53:22] VOIX: c'est très intéressant -[01:53:24] VOIX: et en fait -[01:53:25] VOIX: si tu -[01:53:25] VOIX: et je t'expliquerai -[01:53:26] VOIX: en fait -[01:53:27] VOIX: comment je fonctionne -[01:53:27] VOIX: d'ailleurs en fait -[01:53:28] VOIX: une grosse discussion -[01:53:29] VOIX: en fait -[01:53:29] VOIX: de même -[01:53:29] VOIX: du mode de fonctionnement -[01:53:31] VOIX: de tout ce bazar -[01:53:32] VOIX: ok -[01:53:33] VOIX: j'aurais besoin -[01:53:34] VOIX: que tu m'envoies -[01:53:35] VOIX: en fait -[01:53:35] VOIX: si tu as la possibilité -[01:53:37] VOIX: attends -[01:53:38] VOIX: parce que -[01:53:38] VOIX: ça c'est fait -[01:53:43] VOIX: oui -[01:53:44] VOIX: il me faudra en fait avoir -[01:53:46] VOIX: mais ça je vais peut-être -[01:53:47] VOIX: demander à Jordan -[01:53:48] VOIX: ou à -[01:53:49] VOIX: ou à Pauline -[01:53:50] VOIX: si je pourrais avoir -[01:53:52] VOIX: en fait une fiche -[01:53:54] VOIX: attends -[01:53:55] VOIX: comment ça s'appelle -[01:53:56] VOIX: j'arrête pas en fait -[01:53:58] VOIX: de rouler -[01:53:59] VOIX: et de rouler -[01:53:59] VOIX: moi aussi je cherche -[01:54:00] VOIX: un document -[01:54:01] VOIX: je vais te le montrer aussi -[01:54:02] VOIX: oui -[01:54:03] VOIX: en fait -[01:54:04] VOIX: c'est un format -[01:54:05] VOIX: en fait -[01:54:06] VOIX: de synthèse -[01:54:06] VOIX: PMSI -[01:54:08] VOIX: idéal -[01:54:08] VOIX: ce que tu aimerais avoir -[01:54:17] VOIX: mais tu sais quoi -[01:54:19] VOIX: mais tu sais quoi -[01:54:22] VOIX: le format en fait -[01:54:23] VOIX: de synthèse -[01:54:24] VOIX: PMSI -[01:54:24] VOIX: alors -[01:54:26] VOIX: j'ai besoin -[01:54:26] VOIX: en fait de savoir -[01:54:27] VOIX: en fait -[01:54:27] VOIX: comment en fait -[01:54:28] VOIX: tu ferais -[01:54:29] VOIX: en fait -[01:54:30] VOIX: une synthèse -[01:54:31] VOIX: en fait -[01:54:31] VOIX: de PMSI -[01:54:35] VOIX: tu vois pas du tout -[01:54:37] VOIX: en fait -[01:54:38] VOIX: est-ce que -[01:54:39] VOIX: c'est une synthèse -[01:54:40] VOIX: parce que j'ai des diapos -[01:54:41] VOIX: de formation -[01:54:42] VOIX: PMSI -[01:54:44] VOIX: où je vais -[01:54:45] VOIX: expliquer -[01:54:46] VOIX: c'est quoi le DP -[01:54:46] VOIX: c'est quoi le DAS -[01:54:47] VOIX: etc -[01:54:49] VOIX: ou est-ce que -[01:54:49] VOIX: c'est un format -[01:54:52] VOIX: des codes -[01:54:54] VOIX: un peu comme -[01:54:54] VOIX: la fiche -[01:54:55] VOIX: OGC -[01:54:56] VOIX: qu'est-ce qu'il y a -[01:54:57] VOIX: dans -[01:54:57] VOIX: oui -[01:54:58] VOIX: voilà -[01:54:58] VOIX: en fait -[01:54:59] VOIX: c'est plutôt ça -[01:55:00] VOIX: plutôt la fiche -[01:55:02] VOIX: d'un séjour -[01:55:03] VOIX: oui -[01:55:03] VOIX: la fiche -[01:55:03] VOIX: d'un séjour -[01:55:05] VOIX: oui -[01:55:06] VOIX: mais -[01:55:07] VOIX: attends -[01:55:08] VOIX: attends -[01:55:08] VOIX: parce que -[01:55:09] VOIX: Guy -[01:55:09] VOIX: parle en même temps -[01:55:11] VOIX: la fiche -[01:55:12] VOIX: d'un séjour -[01:55:13] VOIX: t'as deux points -[01:55:14] VOIX: de vue -[01:55:14] VOIX: t'as le point -[01:55:15] VOIX: de vue PMSI -[01:55:16] VOIX: les données -[01:55:17] VOIX: la variable -[01:55:18] VOIX: un peu comme -[01:55:20] VOIX: on va le remontrer -[01:55:21] VOIX: là -[01:55:23] VOIX: pardon -[01:55:24] VOIX: vas-y -[01:55:24] VOIX: vas-y -[01:55:25] VOIX: la fiche -[01:55:26] VOIX: là qu'on a -[01:55:27] VOIX: vue tout à l'heure -[01:55:30] VOIX: j'espère que -[01:55:31] VOIX: je vais me reconnecter -[01:55:32] VOIX: la fiche -[01:55:35] VOIX: OGC -[01:55:36] VOIX: et là -[01:55:36] VOIX: tu vas voir -[01:55:38] VOIX: elle est intéressante -[01:55:39] VOIX: parce que -[01:55:39] VOIX: pratiquement -[01:55:41] VOIX: si on revient -[01:55:42] VOIX: à notre -[01:55:43] VOIX: 339 -[01:55:49] VOIX: la fiche -[01:55:50] VOIX: en fait -[01:55:50] VOIX: t'as -[01:55:52] VOIX: t'as -[01:55:52] VOIX: ici -[01:55:53] VOIX: le nombre -[01:55:54] VOIX: d'unités médicales -[01:55:56] VOIX: qu'il y a -[01:55:57] VOIX: là il y a une seule -[01:55:58] VOIX: mais on peut avoir -[01:55:59] VOIX: plusieurs unités médicales -[01:56:00] VOIX: avec leur propre date d'entrée -[01:56:02] VOIX: etc -[01:56:03] VOIX: t'as les données administratives -[01:56:05] VOIX: du patient -[01:56:05] VOIX: tu vois quand je te dis -[01:56:07] VOIX: l'âge -[01:56:09] VOIX: mode d'entrée -[01:56:10] VOIX: la durée de séjour -[01:56:11] VOIX: les dates -[01:56:14] VOIX: les modes de remboursement -[01:56:15] VOIX: bon ça c'est des données -[01:56:16] VOIX: de facturation -[01:56:18] VOIX: ça on le regarde après -[01:56:20] VOIX: attends -[01:56:20] VOIX: je lève le doigt -[01:56:21] VOIX: t'as vu -[01:56:21] VOIX: pardon -[01:56:22] VOIX: je vois pas -[01:56:23] VOIX: parce que -[01:56:24] VOIX: je -[01:56:25] VOIX: ouais -[01:56:25] VOIX: il est entré -[01:56:27] VOIX: est-ce que c'est une donnée -[01:56:28] VOIX: en fait -[01:56:28] VOIX: très importante -[01:56:30] VOIX: très importante -[01:56:31] VOIX: ou moyennement importante -[01:56:32] VOIX: il est entré -[01:56:33] VOIX: via domicile -[01:56:34] VOIX: et sorti -[01:56:34] VOIX: par domicile -[01:56:37] VOIX: oui -[01:56:37] VOIX: est importante -[01:56:39] VOIX: parce que -[01:56:42] VOIX: il est importante -[01:56:43] VOIX: parce qu'on peut faire -[01:56:44] VOIX: des contrôles là-dessus -[01:56:45] VOIX: les modes d'entrée -[01:56:47] VOIX: on a des contrôles aussi -[01:56:48] VOIX: sur -[01:56:49] VOIX: est-ce qu'il est passé -[01:56:50] VOIX: par les urgences -[01:56:51] VOIX: ou pas -[01:56:51] VOIX: est-ce que -[01:56:52] VOIX: il a été transféré -[01:56:54] VOIX: est-ce que -[01:56:55] VOIX: c'est pas important -[01:56:57] VOIX: au même niveau -[01:56:58] VOIX: que le codage -[01:56:59] VOIX: mais c'est des choses -[01:56:59] VOIX: qu'on contrôle -[01:57:00] VOIX: qu'on regarde -[01:57:01] VOIX: tout ça -[01:57:02] VOIX: le team -[01:57:02] VOIX: va le vérifier -[01:57:03] VOIX: par exemple -[01:57:04] VOIX: parfois on corrige ça -[01:57:05] VOIX: au moment -[01:57:06] VOIX: à la lecture du dossier -[01:57:07] VOIX: on se rend compte -[01:57:08] VOIX: ben finalement -[01:57:09] VOIX: c'était -[01:57:09] VOIX: il a été transféré -[01:57:10] VOIX: depuis un établissement -[01:57:11] VOIX: on va corriger -[01:57:13] VOIX: ok -[01:57:13] VOIX: d'accord -[01:57:16] VOIX: ou les dates -[01:57:16] VOIX: tu vois tout à l'heure -[01:57:17] VOIX: quand je t'ai dit -[01:57:18] VOIX: le 24 ou le 25 -[01:57:21] VOIX: si on se rend compte -[01:57:23] VOIX: qu'il y a une erreur -[01:57:23] VOIX: on peut la corriger -[01:57:26] VOIX: donc il y a cette vue là -[01:57:27] VOIX: de dire -[01:57:28] VOIX: la synthèse du séjour -[01:57:30] VOIX: d'un point de vue -[01:57:31] VOIX: codage PMSI -[01:57:32] VOIX: où je vais avoir -[01:57:34] VOIX: mes diagnostics -[01:57:36] VOIX: là typiquement -[01:57:37] VOIX: mes diagnostics -[01:57:38] VOIX: si ils disent -[01:57:39] VOIX: j'ai mes modes d'entrée -[01:57:41] VOIX: mes dates -[01:57:41] VOIX: parce qu'on peut -[01:57:42] VOIX: les modifier aussi -[01:57:43] VOIX: j'ai mes actes -[01:57:45] VOIX: et en haut -[01:57:46] VOIX: tu vois les actes -[01:57:47] VOIX: est-ce qu'il est classant -[01:57:48] VOIX: est-ce qu'il y a eu -[01:57:49] VOIX: de l'anesthésie -[01:57:50] VOIX: etc -[01:57:51] VOIX: et en haut -[01:57:52] VOIX: je vais avoir -[01:57:54] VOIX: le groupage -[01:57:55] VOIX: c'est-à-dire -[01:57:55] VOIX: comment il a été groupé -[01:57:57] VOIX: le GHF -[01:57:58] VOIX: le GHS -[01:57:59] VOIX: la valo -[01:58:00] VOIX: est-ce qu'il y a des suppléments -[01:58:01] VOIX: est-ce qu'il y a des bornes -[01:58:03] VOIX: parce que ça -[01:58:03] VOIX: là on rentre -[01:58:04] VOIX: dans la logique -[01:58:05] VOIX: de valorisation -[01:58:06] VOIX: du séjour -[01:58:07] VOIX: qui est encore -[01:58:07] VOIX: une autre histoire -[01:58:10] VOIX: est-ce que c'est -[01:58:11] VOIX: cette logique là -[01:58:12] VOIX: ou est-ce que -[01:58:13] VOIX: quand tu dis -[01:58:14] VOIX: la synthèse du séjour -[01:58:15] VOIX: c'est -[01:58:15] VOIX: ce qu'on a vu -[01:58:17] VOIX: tout à l'heure -[01:58:17] VOIX: dans Tracker -[01:58:19] VOIX: j'ai les observations -[01:58:21] VOIX: médicales -[01:58:21] VOIX: j'ai le courrier -[01:58:22] VOIX: j'ai les comptes rendus -[01:58:23] VOIX: opératoires -[01:58:24] VOIX: j'ai les comptes rendus -[01:58:24] VOIX: de biologie -[01:58:25] VOIX: non c'est plutôt ça -[01:58:26] VOIX: qu'est-ce que tu attends -[01:58:27] VOIX: c'était plutôt ça -[01:58:29] VOIX: en fait -[01:58:29] VOIX: et avec -[01:58:30] VOIX: mais -[01:58:32] VOIX: non non -[01:58:32] VOIX: c'est plutôt ce que tu viens -[01:58:33] VOIX: de me montrer -[01:58:35] VOIX: donc ce que je viens -[01:58:36] VOIX: de te montrer -[01:58:37] VOIX: demande à Jordan -[01:58:38] VOIX: de te montrer -[01:58:39] VOIX: ça c'est la fiche -[01:58:40] VOIX: où j'essaie -[01:58:41] VOIX: mais dans les autres -[01:58:42] VOIX: logiciels -[01:58:42] VOIX: Eva -[01:58:43] VOIX: on a un truc -[01:58:44] VOIX: qui s'appelle -[01:58:45] VOIX: fiche -[01:58:45] VOIX: fiche dossier -[01:58:47] VOIX: dans Eva -[01:58:49] VOIX: améliorer le codage -[01:58:50] VOIX: par exemple -[01:58:51] VOIX: et c'est ni plus -[01:58:52] VOIX: ni moins que ça -[01:58:53] VOIX: on s'est inspiré d'ailleurs -[01:58:54] VOIX: de cette fiche -[01:58:55] VOIX: pour créer les étapes -[01:58:56] VOIX: du contrôle T2A -[01:58:57] VOIX: et elle -[01:58:58] VOIX: elle n'a pas -[01:58:59] VOIX: les étapes -[01:59:00] VOIX: elle est de l'établissement -[01:59:02] VOIX: toute seule -[01:59:02] VOIX: elle est déjà faite -[01:59:04] VOIX: et elle est -[01:59:04] VOIX: voilà -[01:59:05] VOIX: c'est un module à part -[01:59:06] VOIX: et il peut te montrer -[01:59:07] VOIX: par exemple -[01:59:08] VOIX: quand il y a plusieurs unités -[01:59:09] VOIX: quand on a les suppléments -[01:59:10] VOIX: quand on a les variables -[01:59:11] VOIX: gradations -[01:59:12] VOIX: quand on a le bébé -[01:59:14] VOIX: la maman -[01:59:14] VOIX: pour les accouchements -[01:59:15] VOIX: il y a plusieurs cas -[01:59:16] VOIX: spécifiques -[01:59:17] VOIX: si tu veux -[01:59:17] VOIX: c'est pas juste -[01:59:18] VOIX: un modèle unique -[01:59:20] VOIX: parce que là -[01:59:21] VOIX: on est en train -[01:59:22] VOIX: de voir -[01:59:22] VOIX: les cas généraux -[01:59:23] VOIX: mais si tu rentres -[01:59:24] VOIX: dans l'obstétrique -[01:59:26] VOIX: si tu rentres -[01:59:26] VOIX: dans l'area -[01:59:28] VOIX: la surveillance continue -[01:59:29] VOIX: etc -[01:59:29] VOIX: on va avoir -[01:59:30] VOIX: des spécificités -[01:59:31] VOIX: de codage -[01:59:33] VOIX: de valorisation -[01:59:33] VOIX: etc -[01:59:34] VOIX: là on est dans -[01:59:35] VOIX: le cas de base -[01:59:36] VOIX: un team -[01:59:37] VOIX: qui est en train -[01:59:38] VOIX: d'apprendre -[01:59:38] VOIX: à coder -[01:59:40] VOIX: ok -[01:59:41] VOIX: bon -[01:59:42] VOIX: la 11 -[01:59:42] VOIX: ça tu m'as -[01:59:44] VOIX: déjà répondu -[01:59:44] VOIX: c'est les champs -[01:59:46] VOIX: en fait -[01:59:46] VOIX: c'est les motifs -[01:59:47] VOIX: d'admission -[01:59:47] VOIX: diagnostic -[01:59:48] VOIX: en fait -[01:59:48] VOIX: il y a la date -[01:59:49] VOIX: en fait -[01:59:49] VOIX: que j'ai remarqué -[01:59:51] VOIX: excuse-moi -[01:59:52] VOIX: je me parle -[01:59:52] VOIX: en même temps -[01:59:53] VOIX: et je dis -[01:59:53] VOIX: en fait -[01:59:53] VOIX: en même temps -[01:59:54] VOIX: oui -[01:59:54] VOIX: est-ce qu'on n'a pas -[01:59:55] VOIX: ici dans ces fiches -[01:59:57] VOIX: mais on l'a -[01:59:58] VOIX: dans les bases -[01:59:59] VOIX: c'est tout ce qui est lié -[02:00:01] VOIX: aux règles -[02:00:03] VOIX: de remboursement -[02:00:04] VOIX: et d'affiliation -[02:00:05] VOIX: du patient -[02:00:05] VOIX: à la caisse -[02:00:06] VOIX: est-ce qu'il est -[02:00:07] VOIX: assurance maladie -[02:00:08] VOIX: est-ce qu'il est -[02:00:09] VOIX: remboursable -[02:00:09] VOIX: est-ce qu'il est à 80% -[02:00:11] VOIX: est-ce qu'il est à 100% -[02:00:12] VOIX: c'est quelle caisse -[02:00:15] VOIX: voilà -[02:00:15] VOIX: c'est toutes les règles -[02:00:16] VOIX: liées à la facturation -[02:00:19] VOIX: bon ok -[02:00:19] VOIX: ok -[02:00:20] VOIX: c'est noté -[02:00:22] VOIX: moi il me faut -[02:00:24] VOIX: en fait -[02:00:24] VOIX: je rajoute -[02:00:25] VOIX: parce qu'en fait -[02:00:26] VOIX: c'est plus -[02:00:26] VOIX: la fiche -[02:00:29] VOIX: pour ma synthèse -[02:00:30] VOIX: ben moi je note -[02:00:31] VOIX: fiche EVA -[02:00:32] VOIX: au fait -[02:00:32] VOIX: oui -[02:00:33] VOIX: oui ça rejoint -[02:00:34] VOIX: c'est toujours -[02:00:35] VOIX: ouais -[02:00:35] VOIX: et ça -[02:00:36] VOIX: là je t'ai montré -[02:00:37] VOIX: la fiche OGC -[02:00:38] VOIX: mais -[02:00:39] VOIX: Jordan pourra te montrer -[02:00:41] VOIX: ou voilà -[02:00:42] VOIX: ou moi-même -[02:00:44] VOIX: si je me connecte -[02:00:47] VOIX: non mais on le -[02:00:48] VOIX: tu veux que je te la montre -[02:00:52] VOIX: ou pas -[02:00:52] VOIX: bon -[02:00:53] VOIX: ben vas-y -[02:00:54] VOIX: vas-y -[02:00:54] VOIX: puisque tu y es -[02:00:55] VOIX: vas-y -[02:00:57] VOIX: non j'y suis pas -[02:00:58] VOIX: tu peux encore connecter -[02:01:06] VOIX: je sais même pas si -[02:01:09] VOIX: ah c'est -[02:01:10] VOIX: c'est -[02:01:10] VOIX: c'est -[02:01:10] VOIX: c'est -[02:01:11] VOIX: le -[02:01:11] VOIX: le fameux code -[02:01:12] VOIX: qui est écrit -[02:01:13] VOIX: en fait -[02:01:13] VOIX: dans une feuille Excel -[02:01:14] VOIX: est-ce que -[02:01:17] VOIX: ouais -[02:01:17] VOIX: est-ce qu'il va marcher -[02:01:18] VOIX: ou pas -[02:01:19] VOIX: ça a marché de suite -[02:01:20] VOIX: ça a marché -[02:01:21] VOIX: tu as vu -[02:01:22] VOIX: la fiche Excel -[02:01:23] VOIX: de mon cerveau -[02:01:24] VOIX: pour celui-là -[02:01:26] VOIX: je me rappelle -[02:01:27] VOIX: je me rappelle -[02:01:33] VOIX: je me rappelle -[02:01:35] VOIX: je m'en fous -[02:01:35] VOIX: et -[02:01:36] VOIX: c'est là -[02:01:37] VOIX: où on va s'amuser -[02:01:38] VOIX: parce que -[02:01:39] VOIX: je me rappelle -[02:01:40] VOIX: plus -[02:01:40] VOIX: si -[02:01:41] VOIX: je me rappelle -[02:01:54] VOIX: le codage -[02:01:55] VOIX: le codage -[02:01:56] VOIX: ça c'est -[02:01:56] VOIX: du MCO -[02:01:57] VOIX: parce que -[02:01:57] VOIX: là on est en train -[02:01:58] VOIX: de travailler -[02:01:59] VOIX: le champ MCO -[02:02:00] VOIX: on n'a pas encore -[02:02:01] VOIX: touché le champ -[02:02:01] VOIX: et c'est là -[02:02:04] VOIX: je suis pas sûr -[02:02:09] VOIX: parce que -[02:02:10] VOIX: il y a trop -[02:02:10] VOIX: d'abréviations -[02:02:11] VOIX: il y a trop -[02:02:12] VOIX: de machin -[02:02:12] VOIX: là -[02:02:13] VOIX: bon -[02:02:13] VOIX: mais au fur et à mesure -[02:02:14] VOIX: ce que tu peux faire aussi -[02:02:16] VOIX: c'est sortir une liste -[02:02:18] VOIX: d'abréviations -[02:02:19] VOIX: qu'on peut nous -[02:02:20] VOIX: te compléter -[02:02:21] VOIX: oui -[02:02:22] VOIX: ben en fait -[02:02:22] VOIX: je les ai -[02:02:23] VOIX: en fait -[02:02:23] VOIX: à ces listes-là -[02:02:24] VOIX: mais c'est que -[02:02:25] VOIX: je suis -[02:02:25] VOIX: je n'arrive pas -[02:02:26] VOIX: à les retenir -[02:02:27] VOIX: c'est avec -[02:02:28] VOIX: l'expérience -[02:02:29] VOIX: avec le -[02:02:30] VOIX: au fur et à mesure -[02:02:31] VOIX: il y a des choses -[02:02:32] VOIX: que j'arrive à comprendre -[02:02:33] VOIX: et à retenir -[02:02:33] VOIX: bien sûr -[02:02:35] VOIX: tu vois ça -[02:02:36] VOIX: c'est un type -[02:02:36] VOIX: c'est les contrôles -[02:02:37] VOIX: qu'on fait -[02:02:38] VOIX: c'est les contrôles -[02:02:38] VOIX: qualité interne -[02:02:41] VOIX: par exemple -[02:02:41] VOIX: je prends celui-là -[02:02:43] VOIX: et je vais voir -[02:02:45] VOIX: un dossier -[02:02:46] VOIX: tu vois la fiche là -[02:02:49] VOIX: non ça c'est -[02:02:51] VOIX: maintenant -[02:02:51] VOIX: je vais prendre -[02:02:52] VOIX: un qui est -[02:02:55] VOIX: c'est un cauchement -[02:02:57] VOIX: tu vois là -[02:02:58] VOIX: par exemple -[02:02:59] VOIX: dans ce dossier -[02:03:01] VOIX: j'ai -[02:03:01] VOIX: le groupage -[02:03:03] VOIX: du séjour -[02:03:04] VOIX: complet -[02:03:04] VOIX: avec sa durée globale -[02:03:06] VOIX: et je vois -[02:03:07] VOIX: qu'il est passé -[02:03:08] VOIX: par deux unités -[02:03:09] VOIX: médicales -[02:03:10] VOIX: de chirurgie -[02:03:11] VOIX: l'unité 92 -[02:03:14] VOIX: et l'unité 91 -[02:03:16] VOIX: donc -[02:03:17] VOIX: il manque codage -[02:03:18] VOIX: la valeur globale -[02:03:20] VOIX: elle est liée -[02:03:20] VOIX: au séjour -[02:03:21] VOIX: mais le codage -[02:03:23] VOIX: par exemple -[02:03:24] VOIX: ce codage là -[02:03:25] VOIX: le diagnostic -[02:03:26] VOIX: des actes -[02:03:27] VOIX: il peut être différent -[02:03:28] VOIX: tu vois -[02:03:29] VOIX: il a changé -[02:03:30] VOIX: 11 diagnostic -[02:03:31] VOIX: machin -[02:03:31] VOIX: là si je fais ça -[02:03:33] VOIX: j'ai 26 diagnostics -[02:03:35] VOIX: et au fait -[02:03:36] VOIX: le groupage -[02:03:37] VOIX: c'est ce qui va -[02:03:38] VOIX: regrouper -[02:03:38] VOIX: l'ensemble -[02:03:40] VOIX: des codes -[02:03:41] VOIX: qui sont faits -[02:03:41] VOIX: dans chaque -[02:03:42] VOIX: unité médicale -[02:03:43] VOIX: donc ça la logique -[02:03:44] VOIX: de l'unité médicale -[02:03:45] VOIX: elle est aussi importante -[02:03:46] VOIX: ça on n'en a pas -[02:03:47] VOIX: beaucoup parlé -[02:03:48] VOIX: mais -[02:03:49] VOIX: ça -[02:03:51] VOIX: en psychiatrie -[02:03:52] VOIX: on va avoir -[02:03:53] VOIX: beaucoup de découpage -[02:03:54] VOIX: comme ça -[02:03:55] VOIX: le codage -[02:03:55] VOIX: il est lié -[02:03:56] VOIX: à l'unité médicale -[02:03:58] VOIX: et après -[02:03:58] VOIX: c'est le groupeur -[02:04:00] VOIX: qui va jouer -[02:04:01] VOIX: d'accord -[02:04:01] VOIX: mais là tu parles -[02:04:02] VOIX: d'accord -[02:04:03] VOIX: ok -[02:04:04] VOIX: tu parles en fait -[02:04:06] VOIX: le groupeur -[02:04:07] VOIX: en fait -[02:04:08] VOIX: le groupeur -[02:04:08] VOIX: on est d'accord -[02:04:09] VOIX: en fait -[02:04:09] VOIX: c'est essentiellement -[02:04:10] VOIX: en fait -[02:04:10] VOIX: c'est ce qui va donner -[02:04:11] VOIX: une tarification -[02:04:13] VOIX: oui -[02:04:13] VOIX: d'accord -[02:04:14] VOIX: c'est lui qui va te dire -[02:04:15] VOIX: ce DP -[02:04:17] VOIX: avec -[02:04:19] VOIX: ces actes là -[02:04:20] VOIX: avec les actes classants -[02:04:21] VOIX: je vais te montrer -[02:04:22] VOIX: la notion d'actes classants -[02:04:25] VOIX: je ne sais pas -[02:04:25] VOIX: si je peux -[02:04:26] VOIX: je ne peux pas grandir ça -[02:04:27] VOIX: non mais c'est -[02:04:28] VOIX: l'actes classants -[02:04:28] VOIX: tu vois -[02:04:29] VOIX: on a une étoile -[02:04:30] VOIX: pour dire -[02:04:30] VOIX: l'ostéosynthèse -[02:04:32] VOIX: ça c'est un acte classant -[02:04:34] VOIX: tu vois -[02:04:34] VOIX: qu'il a eu de l'anesthésie -[02:04:37] VOIX: classant classant -[02:04:38] VOIX: et un acte -[02:04:40] VOIX: qui n'est pas classant -[02:04:41] VOIX: ben le scan -[02:04:42] VOIX: le scan n'est pas classant -[02:04:44] VOIX: ok -[02:04:45] VOIX: il ne va pas jouer -[02:04:45] VOIX: dans le groupeur -[02:04:47] VOIX: et en fait -[02:04:48] VOIX: comme j'ai des actes -[02:04:49] VOIX: chirurgicaux ici -[02:04:51] VOIX: mon séjour -[02:04:52] VOIX: il est chirurgical -[02:04:53] VOIX: c'est pour ça -[02:04:54] VOIX: qu'il a un C -[02:04:56] VOIX: 08C -[02:04:58] VOIX: 08C -[02:04:58] VOIX: c'est l'ortho -[02:04:59] VOIX: le C -[02:05:00] VOIX: c'est la Chir -[02:05:02] VOIX: et ce 2 là -[02:05:03] VOIX: le dernier -[02:05:04] VOIX: il est lié -[02:05:05] VOIX: au niveau de sévérité -[02:05:07] VOIX: des diagnostics -[02:05:08] VOIX: qui sont codés -[02:05:09] VOIX: tu vois -[02:05:10] VOIX: les listes -[02:05:11] VOIX: les diagnostics associés -[02:05:12] VOIX: il y en a -[02:05:13] VOIX: qui n'ont pas -[02:05:14] VOIX: de niveau de sévérité -[02:05:15] VOIX: ils vont moins -[02:05:16] VOIX: nous intéresser -[02:05:17] VOIX: par exemple -[02:05:17] VOIX: dans le contrôle sécu -[02:05:18] VOIX: ils ne vont même pas -[02:05:19] VOIX: les regarder -[02:05:20] VOIX: parce qu'ils n'ont -[02:05:21] VOIX: pas d'impact -[02:05:22] VOIX: ils vont regarder -[02:05:24] VOIX: beaucoup -[02:05:24] VOIX: les diagnostics -[02:05:25] VOIX: ça aussi -[02:05:26] VOIX: c'est des référentiels -[02:05:27] VOIX: les diagnostics -[02:05:28] VOIX: qui ont des -[02:05:29] VOIX: niveau 2 -[02:05:30] VOIX: quand tu parles -[02:05:30] VOIX: de comorbidité -[02:05:31] VOIX: c'est ça -[02:05:31] VOIX: les CMA -[02:05:32] VOIX: c'est les comorbidités -[02:05:33] VOIX: c'est les diagnostics -[02:05:34] VOIX: associés avec CMA -[02:05:37] VOIX: 1, 2, 3 ou 4 -[02:05:38] VOIX: sur cette partie là -[02:05:40] VOIX: je le verrai -[02:05:40] VOIX: un peu plus tard -[02:05:41] VOIX: parce que -[02:05:43] VOIX: avant d'arriver là -[02:05:44] VOIX: il faut que j'intègre -[02:05:46] VOIX: tout ce qu'on a vu là -[02:05:47] VOIX: aujourd'hui -[02:05:48] VOIX: et puis surtout -[02:05:49] VOIX: il faut que je vois -[02:05:50] VOIX: avec Jordan -[02:05:50] VOIX: parce qu'en fait -[02:05:51] VOIX: sur cette partie là -[02:05:52] VOIX: j'ai pris -[02:05:54] VOIX: le groupeur officiel -[02:05:59] VOIX: qui est vendu -[02:06:01] VOIX: et comme je ne l'ai pas -[02:06:02] VOIX: j'ai utilisé -[02:06:03] VOIX: de le faire -[02:06:05] VOIX: et c'était juste -[02:06:06] VOIX: pour te montrer -[02:06:07] VOIX: aussi la fiche -[02:06:08] VOIX: si tu veux -[02:06:08] VOIX: tu m'as parlé -[02:06:09] VOIX: de la fiche -[02:06:10] VOIX: on s'est inspiré -[02:06:11] VOIX: de celle là -[02:06:12] VOIX: pour travailler -[02:06:12] VOIX: celle de l'OGC -[02:06:14] VOIX: on retrouve -[02:06:15] VOIX: tu vois la même logique -[02:06:17] VOIX: avec les diagnostics -[02:06:19] VOIX: leur niveau de sévérité -[02:06:21] VOIX: etc -[02:06:22] VOIX: et les actes -[02:06:23] VOIX: avec les dates -[02:06:24] VOIX: avec qu'est-ce qui sont classants -[02:06:26] VOIX: est-ce qu'on a -[02:06:27] VOIX: de l'anesthésie -[02:06:27] VOIX: etc -[02:06:28] VOIX: et on peut répartir -[02:06:29] VOIX: par -[02:06:30] VOIX: par -[02:06:31] VOIX: unité médicale -[02:06:32] VOIX: et juste un petit mot -[02:06:33] VOIX: pour que tu vois -[02:06:34] VOIX: que ça -[02:06:35] VOIX: c'est du fait -[02:06:37] VOIX: de l'identifiant patient -[02:06:38] VOIX: qu'on peut dire -[02:06:40] VOIX: que c'est le même patient -[02:06:41] VOIX: qui a eu -[02:06:42] VOIX: 3-6 jours -[02:06:44] VOIX: tu sais -[02:06:45] VOIX: quand on va travailler -[02:06:46] VOIX: le parcours -[02:06:49] VOIX: c'est ça en fait -[02:06:50] VOIX: qu'il faut arriver -[02:06:51] VOIX: à reconstituer -[02:06:52] VOIX: pour dire -[02:06:53] VOIX: parcours -[02:06:54] VOIX: de douleur chronique -[02:06:56] VOIX: là typiquement -[02:06:58] VOIX: j'ai un dossier -[02:06:59] VOIX: avec fracture -[02:07:01] VOIX: du tibia -[02:07:02] VOIX: donc une fracture -[02:07:03] VOIX: il est revenu -[02:07:05] VOIX: 4 mois après -[02:07:06] VOIX: il a un problème -[02:07:08] VOIX: du rein -[02:07:08] VOIX: d'accord -[02:07:09] VOIX: et il est revenu -[02:07:11] VOIX: 2 mois après -[02:07:12] VOIX: il a enlevé -[02:07:14] VOIX: la plaque -[02:07:15] VOIX: donc sur les 3 séjours -[02:07:17] VOIX: il y a 2 qui sont -[02:07:18] VOIX: en lien -[02:07:19] VOIX: avec l'ortho -[02:07:20] VOIX: parce que là -[02:07:21] VOIX: il est revenu -[02:07:22] VOIX: pour enlever la plaque -[02:07:23] VOIX: alors qu'il faut -[02:07:24] VOIX: que j'arrive à dire -[02:07:25] VOIX: le rein -[02:07:26] VOIX: quand il est venu -[02:07:27] VOIX: pour le rein -[02:07:29] VOIX: c'était pas -[02:07:31] VOIX: tu vois -[02:07:32] VOIX: il a fait -[02:07:32] VOIX: un acte rénal -[02:07:33] VOIX: ça ne fait pas partir -[02:07:34] VOIX: du parcours ortho -[02:07:36] VOIX: même si -[02:07:37] VOIX: c'est un séjour -[02:07:38] VOIX: du patient -[02:07:38] VOIX: tu vois -[02:07:39] VOIX: c'est ça -[02:07:40] VOIX: que quand on va -[02:07:41] VOIX: travailler les parcours -[02:07:42] VOIX: il faudra aussi -[02:07:43] VOIX: qu'on arrive à raisonner -[02:07:44] VOIX: à relier -[02:07:45] VOIX: etc -[02:07:46] VOIX: ça c'est encore -[02:07:47] VOIX: un autre niveau -[02:07:49] VOIX: ok -[02:07:51] VOIX: tu penses que -[02:07:52] VOIX: l'IA -[02:07:52] VOIX: va nous aider -[02:07:53] VOIX: pour tout ça -[02:07:56] VOIX: disons -[02:07:57] VOIX: oui -[02:07:57] VOIX: mais par contre -[02:07:58] VOIX: en fait -[02:07:59] VOIX: si tu veux -[02:07:59] VOIX: il faut que -[02:08:01] VOIX: qu'on arrive -[02:08:02] VOIX: en fait -[02:08:02] VOIX: que j'arrive -[02:08:03] VOIX: moi -[02:08:04] VOIX: une grosse partie -[02:08:05] VOIX: de tout ce que tu as dit -[02:08:05] VOIX: je l'ai déjà fait -[02:08:07] VOIX: simplement -[02:08:07] VOIX: en fait -[02:08:07] VOIX: moi -[02:08:08] VOIX: il faut que j'arrive -[02:08:08] VOIX: à attraper -[02:08:09] VOIX: en fait -[02:08:09] VOIX: les trucs métiers -[02:08:10] VOIX: et il faut que je les fasse -[02:08:11] VOIX: en fait -[02:08:12] VOIX: de façon séquentielle -[02:08:13] VOIX: c'est en fait -[02:08:13] VOIX: on va travailler -[02:08:14] VOIX: en fait -[02:08:15] VOIX: sur la psychiatrie -[02:08:16] VOIX: par exemple -[02:08:17] VOIX: ou ce genre de choses là -[02:08:18] VOIX: on va -[02:08:19] VOIX: on va étudier également -[02:08:20] VOIX: en fait -[02:08:20] VOIX: les relations -[02:08:21] VOIX: qu'il y a entre les règles -[02:08:22] VOIX: mais c'est -[02:08:23] VOIX: c'est ce que tu m'as expliqué -[02:08:24] VOIX: en fait tout à l'heure -[02:08:25] VOIX: mais se concentrer -[02:08:26] VOIX: en fait -[02:08:27] VOIX: un par un -[02:08:27] VOIX: parce qu'il y en a -[02:08:28] VOIX: beaucoup -[02:08:28] VOIX: en fait -[02:08:29] VOIX: c'est pour ça que je te dis -[02:08:30] VOIX: il y a beaucoup -[02:08:30] VOIX: parce qu'en fonction -[02:08:31] VOIX: de ton mode de réflexion -[02:08:34] VOIX: change tout le temps -[02:08:35] VOIX: en fonction en fait -[02:08:36] VOIX: de certains nombres -[02:08:37] VOIX: de critères -[02:08:38] VOIX: qu'il faut que j'arrive -[02:08:39] VOIX: à choper -[02:08:39] VOIX: il y en a plein -[02:08:40] VOIX: en fait -[02:08:41] VOIX: que j'ai attrapé -[02:08:41] VOIX: mais il y en a -[02:08:42] VOIX: plein en fait -[02:08:43] VOIX: qui m'échappent -[02:08:44] VOIX: en fait -[02:08:44] VOIX: pour le moment -[02:08:44] VOIX: pourquoi -[02:08:45] VOIX: parce que ça me fait -[02:08:46] VOIX: beaucoup d'informations -[02:08:47] VOIX: en fait -[02:08:47] VOIX: en même temps -[02:08:48] VOIX: donc là -[02:08:48] VOIX: ce que je vais faire -[02:08:50] VOIX: en fait -[02:08:50] VOIX: quand on va arrêter -[02:08:51] VOIX: c'est que je vais prendre -[02:08:52] VOIX: en fait -[02:08:53] VOIX: tout ce que tu m'as raconté -[02:08:54] VOIX: je vais le mettre -[02:08:55] VOIX: dans la machine -[02:08:55] VOIX: je vais lui demander -[02:08:56] VOIX: de me faire une synthèse -[02:08:58] VOIX: de tout ce que l'on a vu -[02:08:59] VOIX: et à chaque fois -[02:09:01] VOIX: en fait -[02:09:01] VOIX: que tu m'as montré -[02:09:02] VOIX: en fait des choses -[02:09:03] VOIX: j'ai pris des photos -[02:09:04] VOIX: et donc en fait -[02:09:05] VOIX: je vais demander -[02:09:05] VOIX: en fait -[02:09:06] VOIX: de faire un rapprochement -[02:09:07] VOIX: entre les photos -[02:09:08] VOIX: entre -[02:09:09] VOIX: voilà -[02:09:09] VOIX: mais pas toi -[02:09:10] VOIX: en fait -[02:09:10] VOIX: en photo -[02:09:10] VOIX: la capture d'écran -[02:09:14] VOIX: tu peux -[02:09:14] VOIX: tu peux -[02:09:15] VOIX: attends -[02:09:16] VOIX: j'ai encore -[02:09:17] VOIX: une ou deux questions -[02:09:18] VOIX: mais pas plus -[02:09:19] VOIX: c'est fini -[02:09:19] VOIX: d'accord -[02:09:22] VOIX: oui alors -[02:09:25] VOIX: j'avais noté -[02:09:27] VOIX: ça -[02:09:27] VOIX: on va pas -[02:09:28] VOIX: quelqu'un -[02:09:29] VOIX: oui -[02:09:30] VOIX: quels sont -[02:09:31] VOIX: ouais -[02:09:34] VOIX: il y a des cas -[02:09:35] VOIX: alors -[02:09:36] VOIX: on va pas -[02:09:36] VOIX: le développer -[02:09:37] VOIX: maintenant -[02:09:38] VOIX: mais c'est -[02:09:38] VOIX: quels sont -[02:09:39] VOIX: en fait -[02:09:39] VOIX: les moyens -[02:09:40] VOIX: puisque tu parlais -[02:09:40] VOIX: en fait -[02:09:41] VOIX: d'IA -[02:09:41] VOIX: quels sont -[02:09:42] VOIX: en fait -[02:09:42] VOIX: les cas -[02:09:43] VOIX: où -[02:09:44] VOIX: obligatoirement -[02:09:44] VOIX: en fait -[02:09:45] VOIX: tu veux garder -[02:09:45] VOIX: la main -[02:09:45] VOIX: en fait -[02:09:46] VOIX: sur l'étude -[02:09:46] VOIX: en fait -[02:09:47] VOIX: du dossier -[02:09:48] VOIX: ou toi -[02:09:49] VOIX: en tant que team -[02:09:51] VOIX: aucun -[02:09:53] VOIX: bon -[02:09:53] VOIX: j'en étais -[02:09:53] VOIX: à peu près sûr -[02:09:54] VOIX: mais -[02:09:56] VOIX: le boulot -[02:09:56] VOIX: le boulot -[02:09:57] VOIX: peut être -[02:09:57] VOIX: refait -[02:09:57] VOIX: sans que -[02:09:58] VOIX: je garde -[02:09:59] VOIX: la main -[02:09:59] VOIX: alors là -[02:10:01] VOIX: c'est le top -[02:10:02] VOIX: du top -[02:10:02] VOIX: il y en a -[02:10:03] VOIX: non mais -[02:10:03] VOIX: alors -[02:10:04] VOIX: pourquoi en fait -[02:10:05] VOIX: je te dis ça -[02:10:06] VOIX: c'est parce que -[02:10:06] VOIX: c'est parce que -[02:10:08] VOIX: si tu as en fait -[02:10:09] VOIX: alors attends -[02:10:09] VOIX: parce que ça fait partie -[02:10:10] VOIX: de mes questions -[02:10:11] VOIX: en fait -[02:10:11] VOIX: dernière question -[02:10:15] VOIX: ouais -[02:10:15] VOIX: c'est l'indicateur -[02:10:16] VOIX: parce qu'en fait -[02:10:17] VOIX: ça va me donner -[02:10:17] VOIX: en fait -[02:10:18] VOIX: parce que moi -[02:10:19] VOIX: quand je fais -[02:10:19] VOIX: en fait -[02:10:20] VOIX: le traitement -[02:10:20] VOIX: en fait -[02:10:21] VOIX: des documents -[02:10:22] VOIX: je sors en fait -[02:10:23] VOIX: un indicateur -[02:10:25] VOIX: ou un scoring -[02:10:26] VOIX: tu sais -[02:10:26] VOIX: on avait parlé de ça -[02:10:27] VOIX: c'est à dire que -[02:10:28] VOIX: si je suis pas sûr -[02:10:29] VOIX: en fait -[02:10:29] VOIX: à tant de pourcents -[02:10:30] VOIX: je mets de côté -[02:10:32] VOIX: maintenant -[02:10:32] VOIX: au fur et à mesure -[02:10:33] VOIX: où j'avance moi -[02:10:34] VOIX: j'arrive à avoir -[02:10:35] VOIX: en fait -[02:10:35] VOIX: des comparaisons -[02:10:37] VOIX: elles sont pas -[02:10:38] VOIX: nécessairement -[02:10:38] VOIX: super bonnes -[02:10:39] VOIX: pour le moment -[02:10:39] VOIX: j'attends d'avoir -[02:10:40] VOIX: en fait -[02:10:40] VOIX: la suite logique -[02:10:41] VOIX: dont je t'avais parlé -[02:10:42] VOIX: donc le -[02:10:45] VOIX: le compte rendu -[02:10:46] VOIX: en fait -[02:10:46] VOIX: le compte rendu -[02:10:47] VOIX: est arrivé -[02:10:48] VOIX: en fait -[02:10:48] VOIX: jusqu'au codage -[02:10:49] VOIX: qui a été vérifié -[02:10:51] VOIX: contrôlé par la suite -[02:10:52] VOIX: si j'ai ça -[02:10:53] VOIX: en fait -[02:10:53] VOIX: ça va être top -[02:10:54] VOIX: mais je fais du scoring -[02:10:56] VOIX: et si tu veux -[02:10:57] VOIX: il y en a beaucoup -[02:10:58] VOIX: en fait -[02:10:58] VOIX: beaucoup -[02:10:59] VOIX: ou pas beaucoup -[02:10:59] VOIX: j'en sais -[02:11:00] VOIX: pas grand chose -[02:11:01] VOIX: là pour le moment -[02:11:01] VOIX: mais que je vais écarter -[02:11:03] VOIX: mais au fur et à mesure -[02:11:04] VOIX: en fait -[02:11:04] VOIX: mon scoring -[02:11:04] VOIX: va diminuer -[02:11:05] VOIX: c'est en fait -[02:11:06] VOIX: au lieu d'avoir -[02:11:07] VOIX: en fait -[02:11:07] VOIX: que 80% -[02:11:08] VOIX: de résultats bons -[02:11:09] VOIX: j'en suis à 75% -[02:11:11] VOIX: pour le moment -[02:11:12] VOIX: ben je vais arriver -[02:11:13] VOIX: près de 100% -[02:11:15] VOIX: mais il y a des dossiers -[02:11:16] VOIX: en fait -[02:11:17] VOIX: typiquement -[02:11:18] VOIX: sur lequel -[02:11:18] VOIX: en fait -[02:11:19] VOIX: il va vouloir -[02:11:19] VOIX: prendre la main -[02:11:20] VOIX: pour des tonnes de raisons -[02:11:21] VOIX: pour l'instant -[02:11:22] VOIX: par exemple -[02:11:22] VOIX: quand il va y avoir -[02:11:24] VOIX: le groupeur -[02:11:25] VOIX: et qu'il va y avoir -[02:11:26] VOIX: un dossier -[02:11:27] VOIX: qui va être évalué -[02:11:28] VOIX: alors s'il est évalué -[02:11:29] VOIX: à 50 euros -[02:11:30] VOIX: ça peut aller -[02:11:30] VOIX: mais s'il est évalué -[02:11:32] VOIX: en fait -[02:11:32] VOIX: à 150 000 euros -[02:11:33] VOIX: ou 200 000 euros -[02:11:34] VOIX: je sais pas en fait -[02:11:34] VOIX: je dis des bêtises -[02:11:36] VOIX: peut-être que ce dossier -[02:11:37] VOIX: là en particulier -[02:11:38] VOIX: tu vas préférer -[02:11:39] VOIX: en fait -[02:11:39] VOIX: le travailler -[02:11:40] VOIX: tout seul -[02:11:40] VOIX: pour ne pas avoir -[02:11:41] VOIX: d'embêtement -[02:11:42] VOIX: par la suite -[02:11:43] VOIX: ou le vérifier -[02:11:45] VOIX: ou le contrôler -[02:11:47] VOIX: ou le valider -[02:11:48] VOIX: parce que -[02:11:49] VOIX: nous le progrès -[02:11:51] VOIX: si tout va bien -[02:11:52] VOIX: c'est au delà -[02:11:54] VOIX: de proposer des codes -[02:11:55] VOIX: c'est les intégrer -[02:11:56] VOIX: dans le dossier -[02:11:58] VOIX: et peut-être -[02:11:59] VOIX: valider -[02:12:00] VOIX: ou pas le dossier -[02:12:00] VOIX: voilà -[02:12:01] VOIX: ça c'est -[02:12:03] VOIX: une priorisation -[02:12:05] VOIX: pour dire -[02:12:06] VOIX: certains dossiers -[02:12:07] VOIX: seront obligatoirement -[02:12:08] VOIX: à valider -[02:12:09] VOIX: par l'humain -[02:12:09] VOIX: tu fais toujours -[02:12:10] VOIX: ta proposition initiale -[02:12:11] VOIX: oui -[02:12:12] VOIX: ben c'est -[02:12:12] VOIX: voilà -[02:12:13] VOIX: en fait -[02:12:13] VOIX: l'étape que tu viens -[02:12:14] VOIX: de donner -[02:12:14] VOIX: en fait -[02:12:15] VOIX: j'ai fait un raccourci -[02:12:16] VOIX: en fait -[02:12:17] VOIX: en 100 ans -[02:12:18] VOIX: mais -[02:12:18] VOIX: voilà -[02:12:19] VOIX: il y en a peut-être -[02:12:19] VOIX: certaines -[02:12:20] VOIX: donc ok -[02:12:21] VOIX: merci -[02:12:23] VOIX: attends -[02:12:27] VOIX: ok -[02:12:28] VOIX: ben écoute -[02:12:28] VOIX: ben là -[02:12:29] VOIX: on va arrêter -[02:12:30] VOIX: pour aujourd'hui -[02:12:31] VOIX: en fait -[02:12:32] VOIX: sur ce sujet -[02:12:33] VOIX: parce que -[02:12:33] VOIX: c'est bon -[02:12:37] VOIX: non c'est -[02:12:38] VOIX: c'est dense -[02:12:39] VOIX: et puis en plus -[02:12:39] VOIX: comme je suis -[02:12:40] VOIX: en plein dedans -[02:12:41] VOIX: il y a plein de choses -[02:12:42] VOIX: en fait -[02:12:42] VOIX: que tu m'as dit -[02:12:43] VOIX: où je suis assez content -[02:12:44] VOIX: en fait -[02:12:44] VOIX: parce que je me dis -[02:12:45] VOIX: là j'ai bien chopé le truc -[02:12:46] VOIX: et j'ai bien -[02:12:47] VOIX: en fait -[02:12:48] VOIX: il y a d'autres choses -[02:12:49] VOIX: où je me dis -[02:12:50] VOIX: ah merde -[02:12:52] VOIX: j'oublie -[02:12:53] VOIX: mais bon -[02:12:54] VOIX: après en fait -[02:12:54] VOIX: c'est justement -[02:12:55] VOIX: là -[02:12:55] VOIX: on est -[02:12:56] VOIX: on va être -[02:12:58] VOIX: sur des itérations -[02:12:58] VOIX: c'est à dire -[02:12:59] VOIX: aujourd'hui -[02:12:59] VOIX: on a pris un peu de temps -[02:13:00] VOIX: après en fait -[02:13:01] VOIX: on va s'échanger -[02:13:02] VOIX: en fait des mails -[02:13:03] VOIX: mais tu vas voir -[02:13:04] VOIX: je vais essayer -[02:13:04] VOIX: de faire des trucs -[02:13:05] VOIX: ultra courts -[02:13:08] VOIX: je finis -[02:13:08] VOIX: je vais essayer -[02:13:09] VOIX: en fait de faire -[02:13:10] VOIX: en fait -[02:13:10] VOIX: ce qu'on appelle -[02:13:10] VOIX: des itérations -[02:13:11] VOIX: mais courtes -[02:13:12] VOIX: c'est jette moi -[02:13:13] VOIX: un coup d'oeil -[02:13:13] VOIX: en fait sur ça -[02:13:14] VOIX: et toi en fait -[02:13:15] VOIX: en deux secondes -[02:13:15] VOIX: tu vas le faire -[02:13:16] VOIX: le truc que j'aurais -[02:13:17] VOIX: mis deux heures -[02:13:17] VOIX: à faire -[02:13:19] VOIX: et toi tu vas me dire -[02:13:20] VOIX: c'est bon -[02:13:20] VOIX: ou c'est pas bon -[02:13:21] VOIX: parce que là -[02:13:21] VOIX: je sais que -[02:13:22] VOIX: voilà -[02:13:23] VOIX: est-ce que tu penses -[02:13:27] VOIX: que si je te donne -[02:13:28] VOIX: par exemple -[02:13:29] VOIX: le chapitre -[02:13:30] VOIX: un ou deux chapitres -[02:13:32] VOIX: du guide méthodologique -[02:13:33] VOIX: le 5 -[02:13:35] VOIX: le 6 -[02:13:36] VOIX: tu sais -[02:13:37] VOIX: je vais repartager -[02:13:38] VOIX: juste pour que tu vois -[02:13:39] VOIX: de quoi je parle -[02:13:41] VOIX: et indépendamment -[02:13:42] VOIX: d'où c'est passé -[02:13:44] VOIX: c'est une question -[02:13:45] VOIX: je n'en sais rien -[02:13:46] VOIX: si c'est utile ou pas -[02:13:48] VOIX: parce que -[02:13:50] VOIX: non ça c'est 3 -[02:13:51] VOIX: 4 -[02:13:52] VOIX: à partir de là -[02:13:53] VOIX: le chapitre 5 -[02:13:55] VOIX: parce qu'il définit -[02:13:57] VOIX: non non -[02:13:58] VOIX: même ça -[02:13:58] VOIX: c'est quoi un DP -[02:13:59] VOIX: c'est quoi un DAS -[02:14:02] VOIX: les règles -[02:14:03] VOIX: de certains codages -[02:14:04] VOIX: le chapitre 5 -[02:14:06] VOIX: et le chapitre 6 -[02:14:08] VOIX: c'est celui des guides -[02:14:10] VOIX: des situations cliniques -[02:14:11] VOIX: au moins -[02:14:13] VOIX: le 4 -[02:14:14] VOIX: et le 6 -[02:14:15] VOIX: est-ce que tu penses -[02:14:17] VOIX: que s'il y a réfléchi -[02:14:19] VOIX: sur un raisonnement -[02:14:20] VOIX: sur une synthèse -[02:14:23] VOIX: sur des -[02:14:24] VOIX: une autre manière -[02:14:25] VOIX: de -[02:14:27] VOIX: de -[02:14:29] VOIX: de -[02:14:29] VOIX: de -[02:14:29] VOIX: nommer -[02:14:30] VOIX: ou de -[02:14:31] VOIX: de traiter les règles -[02:14:33] VOIX: c'est quelque chose -[02:14:34] VOIX: qui peut être utile -[02:14:35] VOIX: c'est une question -[02:14:36] VOIX: je n'en sais rien moi -[02:14:37] VOIX: mais -[02:14:37] VOIX: c'est une question -[02:14:39] VOIX: bah écoute -[02:14:39] VOIX: comme moi -[02:14:40] VOIX: mes questions -[02:14:40] VOIX: elles n'étaient pas si bêtes -[02:14:42] VOIX: la tienne -[02:14:42] VOIX: elle n'est pas bête non plus -[02:14:44] VOIX: en fait -[02:14:44] VOIX: oui c'est utile -[02:14:45] VOIX: c'est utile -[02:14:46] VOIX: mais -[02:14:47] VOIX: ce document là -[02:14:48] VOIX: moi je l'ai déjà -[02:14:48] VOIX: enfin -[02:14:49] VOIX: les règles -[02:14:50] VOIX: que tu m'as données -[02:14:50] VOIX: sont déjà intégrées -[02:14:52] VOIX: en fait dans le système -[02:14:53] VOIX: mais -[02:14:53] VOIX: moi ce qui me manque -[02:14:54] VOIX: c'est -[02:14:55] VOIX: je reviens toujours là dessus -[02:14:56] VOIX: mais je vais les avoir -[02:14:57] VOIX: en fait -[02:14:57] VOIX: demain -[02:14:57] VOIX: je ne vais pas embêter -[02:14:58] VOIX: Jordan et Pauline -[02:15:00] VOIX: en fait -[02:15:00] VOIX: parce que je sais -[02:15:00] VOIX: qu'ils ont beaucoup de boulot -[02:15:01] VOIX: mais moi en fait -[02:15:02] VOIX: à partir du moment -[02:15:03] VOIX: où je vais avoir -[02:15:03] VOIX: en fait -[02:15:04] VOIX: je ne sais pas -[02:15:05] VOIX: 300 -[02:15:05] VOIX: 400 en fait -[02:15:06] VOIX: documents -[02:15:06] VOIX: avec la chaîne complète -[02:15:08] VOIX: le compte rendu -[02:15:09] VOIX: le codage -[02:15:10] VOIX: etc -[02:15:10] VOIX: moi dès que j'ai ça -[02:15:12] VOIX: après je ne casse plus -[02:15:13] VOIX: les pieds à personne -[02:15:14] VOIX: pendant un petit moment -[02:15:15] VOIX: et après en fait -[02:15:16] VOIX: on fait des tests -[02:15:19] VOIX: oui mais -[02:15:20] VOIX: Dominique -[02:15:21] VOIX: tu l'as ça -[02:15:22] VOIX: tu l'as -[02:15:24] VOIX: ah bon -[02:15:24] VOIX: et bien oui -[02:15:26] VOIX: on a 700 dossiers -[02:15:28] VOIX: du CHCB -[02:15:29] VOIX: non -[02:15:29] VOIX: j'ai le compte rendu -[02:15:30] VOIX: j'ai le tracker -[02:15:32] VOIX: non -[02:15:33] VOIX: t'as vu ce qu'on a -[02:15:34] VOIX: ouvert tout à l'heure -[02:15:35] VOIX: le 339 -[02:15:36] VOIX: oui -[02:15:37] VOIX: le tracker -[02:15:38] VOIX: t'avais -[02:15:39] VOIX: les résultats de biologie -[02:15:41] VOIX: t'avais les comptes rendus -[02:15:42] VOIX: des examens complémentaires -[02:15:44] VOIX: d'imagerie -[02:15:45] VOIX: t'avais les observations -[02:15:46] VOIX: médicales -[02:15:47] VOIX: t'avais les notes -[02:15:48] VOIX: des IDE -[02:15:49] VOIX: donc t'avais tout le dossier -[02:15:50] VOIX: les trackers -[02:15:51] VOIX: ils sont pas que -[02:15:52] VOIX: les comptes rendus -[02:15:53] VOIX: les trackers -[02:15:54] VOIX: c'est -[02:15:54] VOIX: la majorité -[02:15:55] VOIX: je le sais -[02:15:56] VOIX: je le sais -[02:15:58] VOIX: mais j'ai pas -[02:15:59] VOIX: toute la chaîne -[02:16:00] VOIX: il me manque en fait -[02:16:01] VOIX: ce que t'appelles -[02:16:03] VOIX: le primo codage -[02:16:03] VOIX: ou le codage -[02:16:04] VOIX: en fait -[02:16:04] VOIX: qui a été fait -[02:16:05] VOIX: ça c'est facile -[02:16:06] VOIX: on a le fichier -[02:16:07] VOIX: PM ici -[02:16:07] VOIX: on a tout -[02:16:09] VOIX: on a tout -[02:16:10] VOIX: on a les documents -[02:16:11] VOIX: parce que -[02:16:12] VOIX: si tu le veux -[02:16:13] VOIX: je sais pas -[02:16:14] VOIX: comment tu le veux -[02:16:15] VOIX: parce que -[02:16:16] VOIX: il y a trois manières -[02:16:18] VOIX: de l'avoir -[02:16:19] VOIX: il y a -[02:16:20] VOIX: Jordan par exemple -[02:16:22] VOIX: il a travaillé -[02:16:23] VOIX: sur les fichiers -[02:16:24] VOIX: bruts -[02:16:24] VOIX: il sort -[02:16:26] VOIX: pour Guy -[02:16:27] VOIX: des fichiers -[02:16:29] VOIX: plats -[02:16:29] VOIX: en Excel -[02:16:30] VOIX: que Guy travaille -[02:16:31] VOIX: pour me sortir -[02:16:32] VOIX: les rapports -[02:16:33] VOIX: d'activité -[02:16:33] VOIX: donc on peut -[02:16:35] VOIX: très bien -[02:16:35] VOIX: demander à Jordan -[02:16:36] VOIX: je vais regarder -[02:16:37] VOIX: d'ailleurs -[02:16:38] VOIX: le fichier -[02:16:38] VOIX: avec Guy -[02:16:38] VOIX: si tu as -[02:16:40] VOIX: numéro de dossier -[02:16:41] VOIX: par numéro de dossier -[02:16:42] VOIX: pour les OGC -[02:16:43] VOIX: parce que c'est ça -[02:16:44] VOIX: qu'on va demander -[02:16:45] VOIX: à Jordan -[02:16:45] VOIX: prendre le druide -[02:16:47] VOIX: de CHCB -[02:16:48] VOIX: coller le numéro -[02:16:50] VOIX: OGC -[02:16:50] VOIX: le numéro de dossier -[02:16:51] VOIX: me dire -[02:16:52] VOIX: c'est quoi mon DP -[02:16:53] VOIX: c'est quoi mon DAS -[02:16:54] VOIX: voilà -[02:16:55] VOIX: en fait -[02:16:56] VOIX: j'ai besoin -[02:16:56] VOIX: de ça -[02:16:57] VOIX: peu importe -[02:16:59] VOIX: le format -[02:16:59] VOIX: en fait -[02:16:59] VOIX: j'ai besoin -[02:17:00] VOIX: de ça -[02:17:00] VOIX: comme ça -[02:17:01] VOIX: en fait -[02:17:01] VOIX: j'ai toute la chaîne -[02:17:03] VOIX: parce qu'en fait -[02:17:03] VOIX: si tu veux -[02:17:04] VOIX: pour entraîner -[02:17:04] VOIX: une IA -[02:17:05] VOIX: en fait -[02:17:05] VOIX: ce dont je t'avais -[02:17:06] VOIX: parlé -[02:17:06] VOIX: en fait -[02:17:07] VOIX: on a besoin -[02:17:08] VOIX: d'avoir -[02:17:08] VOIX: toute la chaîne -[02:17:10] VOIX: et en plus -[02:17:11] VOIX: là -[02:17:12] VOIX: parce que -[02:17:12] VOIX: pour moi -[02:17:13] VOIX: en fait -[02:17:13] VOIX: je dis -[02:17:13] VOIX: on a de la chance -[02:17:14] VOIX: parce qu'en fait -[02:17:15] VOIX: vous avez tout -[02:17:17] VOIX: ah oui -[02:17:17] VOIX: donc -[02:17:20] VOIX: c'est bien -[02:17:22] VOIX: attends -[02:17:23] VOIX: je vais te montrer -[02:17:23] VOIX: juste un fichier -[02:17:26] VOIX: si j'arrive -[02:17:27] VOIX: à l'ouvrir -[02:17:30] VOIX: euh -[02:17:33] VOIX: recettes -[02:17:33] VOIX: c'est -[02:17:34] VOIX: jour -[02:17:39] VOIX: non -[02:17:40] VOIX: j'arrive pas -[02:17:40] VOIX: à l'ouvrir -[02:17:41] VOIX: non -[02:17:42] VOIX: j'arrive pas -[02:17:43] VOIX: à l'ouvrir -[02:17:43] VOIX: donc -[02:17:45] VOIX: voilà -[02:17:45] VOIX: j'ai besoin -[02:17:46] VOIX: ça c'est -[02:17:47] VOIX: demain -[02:17:48] VOIX: je laisserai -[02:17:49] VOIX: le message -[02:17:49] VOIX: à Guy -[02:17:50] VOIX: je sais si tu seras -[02:17:51] VOIX: avec eux -[02:17:52] VOIX: pour la réunion -[02:17:53] VOIX: qu'ils ont prévu -[02:17:54] VOIX: Guy et Jordan -[02:17:54] VOIX: oui -[02:17:57] VOIX: donc en tout cas -[02:17:58] VOIX: c'est facile -[02:17:58] VOIX: ça -[02:17:59] VOIX: si tu veux -[02:18:00] VOIX: oui c'est facile -[02:18:01] VOIX: mais je l'ai pas -[02:18:02] VOIX: donc -[02:18:02] VOIX: mais après -[02:18:02] VOIX: en fait -[02:18:03] VOIX: voilà -[02:18:03] VOIX: demain -[02:18:04] VOIX: je vais demander -[02:18:04] VOIX: à Jordan -[02:18:05] VOIX: non -[02:18:05] VOIX: parce que -[02:18:07] VOIX: il n'oublie pas -[02:18:08] VOIX: que dans -[02:18:08] VOIX: les 100 -[02:18:10] VOIX: les 700 dossiers -[02:18:11] VOIX: du CHCB -[02:18:12] VOIX: il y a -[02:18:13] VOIX: au moins -[02:18:14] VOIX: je pense -[02:18:14] VOIX: 150 -[02:18:16] VOIX: où le médecin contrôleur -[02:18:17] VOIX: il est ok -[02:18:17] VOIX: avec l'établissement -[02:18:18] VOIX: donc il n'y a pas -[02:18:19] VOIX: de changement -[02:18:20] VOIX: et c'est -[02:18:21] VOIX: ceux-là -[02:18:21] VOIX: que tu peux dire -[02:18:23] VOIX: j'ai pas -[02:18:24] VOIX: les allers-retours -[02:18:24] VOIX: etc -[02:18:25] VOIX: je peux déjà -[02:18:26] VOIX: commencer par -[02:18:27] VOIX: ceux-là -[02:18:27] VOIX: pour appliquer -[02:18:28] VOIX: les règles -[02:18:29] VOIX: ceux qui n'ont -[02:18:29] VOIX: pas été modifiés -[02:18:31] VOIX: mais je vais réfléchir -[02:18:32] VOIX: à un format -[02:18:34] VOIX: nous aussi -[02:18:35] VOIX: on travaille -[02:18:35] VOIX: sur ça -[02:18:36] VOIX: moi le format -[02:18:37] VOIX: en fait -[02:18:37] VOIX: si c'est du CSV -[02:18:38] VOIX: ou de l'Excel -[02:18:39] VOIX: ou même un PDF -[02:18:40] VOIX: en fait -[02:18:40] VOIX: moi ça me va -[02:18:41] VOIX: je n'ai pas -[02:18:42] VOIX: de soucis -[02:18:43] VOIX: de formatage -[02:18:43] VOIX: d'accord -[02:18:44] VOIX: il n'y a pas -[02:18:44] VOIX: il n'y a pas -[02:18:45] VOIX: de soucis -[02:18:47] VOIX: non je n'ai pas -[02:18:47] VOIX: de soucis -[02:18:48] VOIX: de formatage -[02:18:48] VOIX: d'ailleurs -[02:18:49] VOIX: en fait -[02:18:49] VOIX: en ce qui concerne -[02:18:51] VOIX: les -[02:19:04] VOIX: parce qu'en fait -[02:19:05] VOIX: le document -[02:19:07] VOIX: en fait -[02:19:07] VOIX: que Harid -[02:19:08] VOIX: a fait -[02:19:08] VOIX: avait quelques -[02:19:09] VOIX: petits trous -[02:19:09] VOIX: dans la raquette -[02:19:10] VOIX: c'est les -[02:19:11] VOIX: i -[02:19:11] VOIX: qui confond -[02:19:12] VOIX: avec les -[02:19:13] VOIX: 1 -[02:19:13] VOIX: ou ce genre -[02:19:14] VOIX: de choses -[02:19:14] VOIX: là -[02:19:14] VOIX: donc j'ai -[02:19:15] VOIX: corrigé -[02:19:15] VOIX: le programme -[02:19:16] VOIX: en entier -[02:19:17] VOIX: et le programme -[02:19:17] VOIX: que j'ai fait -[02:19:20] VOIX: prend en fait -[02:19:20] VOIX: n'importe quel -[02:19:21] VOIX: type de document -[02:19:22] VOIX: et le réorganise -[02:19:23] VOIX: complètement -[02:19:23] VOIX: et voilà -[02:19:25] VOIX: Harid -[02:19:25] VOIX: tu sais -[02:19:26] VOIX: qu'il est absent -[02:19:26] VOIX: pendant un mois -[02:19:27] VOIX: et bien oui -[02:19:28] VOIX: je sais -[02:19:28] VOIX: il revient -[02:19:29] VOIX: un mois -[02:19:30] VOIX: et il repart -[02:19:30] VOIX: 3-4 mois -[02:19:32] VOIX: donc c'est vrai -[02:19:32] VOIX: qu'il faut que -[02:19:34] VOIX: bon ça c'est un sujet -[02:19:35] VOIX: à part entière -[02:19:36] VOIX: mais on en reparlera -[02:19:37] VOIX: tranquillement -[02:19:39] VOIX: parce que je pense -[02:19:40] VOIX: que pendant -[02:19:40] VOIX: 5 à 6 mois -[02:19:41] VOIX: je ne sais pas -[02:19:42] VOIX: s'il va continuer -[02:19:43] VOIX: à faire des choses -[02:19:44] VOIX: pour nous -[02:19:45] VOIX: alors je te le dis -[02:19:46] VOIX: pour avoir discuté -[02:19:47] VOIX: avec lui -[02:19:48] VOIX: longuement -[02:19:48] VOIX: en fait il a dit -[02:19:51] VOIX: oui mais bon -[02:19:52] VOIX: là où je vais -[02:19:53] VOIX: je vais travailler -[02:19:54] VOIX: en fait je vais arrêter -[02:19:55] VOIX: de travailler -[02:19:55] VOIX: à 15h ou 16h -[02:19:57] VOIX: et après en fait -[02:19:58] VOIX: parce qu'il aime bien -[02:19:59] VOIX: travailler -[02:19:59] VOIX: parce que moi -[02:20:00] VOIX: je l'ai en fait -[02:20:01] VOIX: on échange beaucoup -[02:20:02] VOIX: je lui donne -[02:20:03] VOIX: en fait des trucs -[02:20:04] VOIX: et astuces -[02:20:04] VOIX: en fait des choses -[02:20:05] VOIX: comme ça -[02:20:06] VOIX: donc il m'a dit -[02:20:08] VOIX: en fait -[02:20:08] VOIX: je veux continuer -[02:20:09] VOIX: je le mets en copie -[02:20:10] VOIX: en fait de tout -[02:20:11] VOIX: ce que je fais -[02:20:11] VOIX: de la manière -[02:20:12] VOIX: à ce qu'il ait -[02:20:13] VOIX: les infos -[02:20:14] VOIX: d'accord -[02:20:15] VOIX: donc il continue -[02:20:16] VOIX: quand même -[02:20:16] VOIX: à avancer -[02:20:17] VOIX: sur 2-3 petits sujets -[02:20:18] VOIX: oui -[02:20:19] VOIX: mais je vais essayer -[02:20:20] VOIX: de l'appeler -[02:20:21] VOIX: peut-être pas là -[02:20:21] VOIX: mais en tout cas -[02:20:23] VOIX: pour caler un peu -[02:20:25] VOIX: là où on peut -[02:20:26] VOIX: encore compter -[02:20:27] VOIX: sur lui -[02:20:27] VOIX: ou pas -[02:20:28] VOIX: tu vois -[02:20:29] VOIX: parce que c'est -[02:20:29] VOIX: important quand même -[02:20:31] VOIX: il a dit -[02:20:31] VOIX: c'est qu'il ne voulait -[02:20:32] VOIX: pas -[02:20:32] VOIX: surtout pas perdre -[02:20:33] VOIX: notre fil -[02:20:33] VOIX: de continuer -[02:20:35] VOIX: dans mon sens -[02:20:36] VOIX: la logique -[02:20:37] VOIX: quand même -[02:20:38] VOIX: il sera ailleurs -[02:20:39] VOIX: ok -[02:20:40] VOIX: bon on peut clôturer -[02:20:42] VOIX: nous clôturons -[02:20:43] VOIX: nous clôturons -[02:20:44] VOIX: bon en tout cas -[02:20:45] VOIX: merci beaucoup -[02:20:46] VOIX: pour toutes ces explications -[02:20:47] VOIX: dans quelques minutes -[02:20:48] VOIX: je envoie en fait -[02:20:49] VOIX: mes questions -[02:20:50] VOIX: d'accord -[02:20:51] VOIX: si tu as envie -[02:20:52] VOIX: en fait de répondre -[02:20:54] VOIX: aux questions -[02:20:55] VOIX: de t'amuser -[02:20:56] VOIX: mais en fait -[02:20:57] VOIX: je ne veux pas -[02:20:57] VOIX: te prendre le temps -[02:20:58] VOIX: je sais qu'en ce moment -[02:20:59] VOIX: en fait c'est la course -[02:20:59] VOIX: et machin -[02:21:00] VOIX: donc on le verra après -[02:21:02] VOIX: moi j'ai noté en fait -[02:21:03] VOIX: tout ce que j'avais à noter -[02:21:05] VOIX: et je reviendrai -[02:21:06] VOIX: de toute façon -[02:21:06] VOIX: en fait obligatoirement -[02:21:07] VOIX: vers toi -[02:21:08] VOIX: parce que -[02:21:08] VOIX: au fur et à mesure -[02:21:09] VOIX: je vais -[02:21:09] VOIX: là -[02:21:10] VOIX: au fur et à mesure -[02:21:11] VOIX: j'avance -[02:21:11] VOIX: et après -[02:21:12] VOIX: on va passer en fait -[02:21:13] VOIX: sur des tests -[02:21:14] VOIX: je vais faire des corrections -[02:21:15] VOIX: tu vas me critiquer -[02:21:16] VOIX: je vais faire des corrections -[02:21:18] VOIX: tu vas me re-re-critiquer -[02:21:21] VOIX: mais c'est normal en fait -[02:21:23] VOIX: c'est voilà -[02:21:23] VOIX: si il n'y a pas de critique -[02:21:25] VOIX: là je vais m'inquiéter -[02:21:28] VOIX: n'inquiète pas -[02:21:29] VOIX: on va -[02:21:31] VOIX: on va faire les choses -[02:21:32] VOIX: les unes après les autres -[02:21:34] VOIX: allez -[02:21:34] VOIX: mais bonne soirée -[02:21:35] VOIX: bonsoir Guy -[02:21:38] VOIX: à demain -[02:21:39] VOIX: à demain -[02:21:42] VOIX: et bonne soirée -[02:21:43] VOIX: à toi aussi -[02:21:44] VOIX: allez ciao ciao -[02:21:44] VOIX: bonne soirée -[02:21:45] VOIX: salut -[02:22:13] VOIX: mais non -[02:22:15] VOIX: bah écoute -[02:22:29] VOIX: ah ma doublure -[02:22:47] VOIX: coucou -[02:22:49] VOIX: ça va ? -[02:22:55] VOIX: tous les jours -[02:22:56] VOIX: dans les plans -[02:22:58] VOIX: ils changent -[02:22:59] VOIX: mon cas aussi -[02:23:00] VOIX: notre emplacement -[02:23:01] VOIX: c'est plus vers moi -[02:23:04] VOIX: c'est pénible -[02:23:09] VOIX: attendez moi -[02:23:09] VOIX: j'ai vu quelque chose -[02:23:11] VOIX: comme vous parlez -[02:23:13] VOIX: il y a quoi -[02:23:13] VOIX: ça va ? -[02:23:13] VOIX: ça va ? -[02:23:13] VOIX: non -[02:23:13] VOIX: pas -[02:23:14] VOIX: pas -[02:23:15] VOIX: comprends -[02:23:19] VOIX: Merci. -[02:23:51] VOIX: Merci. -[02:24:23] VOIX: Merci. -[02:24:48] VOIX: Merci. -[02:25:23] VOIX: Merci. -[02:25:53] VOIX: Merci. -[02:26:01] VOIX: Merci. -[02:26:25] VOIX: Merci. -[02:26:55] VOIX: Merci. -[02:27:18] VOIX: Merci. -[02:27:54] VOIX: Merci. -[02:28:00] VOIX: Merci. -[02:28:19] VOIX: Merci. -[02:28:23] VOIX: Merci. -[02:28:50] VOIX: Merci. -[02:28:53] VOIX: Merci. -[02:28:54] VOIX: Merci. -[02:29:01] VOIX: Merci. -[02:29:05] VOIX: Merci. -[02:29:08] VOIX: Merci. -[02:29:09] VOIX: Merci. -[02:29:09] VOIX: Moi, si je me souviens, si je me dansais dans un truc, ça m'amènerait. -[02:29:12] VOIX: Moi, le truc, je me disais, là, passé, au final, si, je fais quelque chose, mais je ne fais pas travailler. -[02:29:23] VOIX: Et là, si je me fais la même chose, je sentais que ça a su le pire. -[02:29:32] VOIX: Je ne sais pas si je me disais, mais je ne sentais pas le pire. -[02:29:35] VOIX: Je ne sais pas si je me disais, mais je ne sais pas. -[02:29:47] VOIX: ... -[02:29:48] VOIX: ... -[02:29:54] VOIX: ... -[02:29:54] VOIX: ... -[02:29:54] VOIX: ... -[02:29:56] VOIX: ... -[02:29:56] VOIX: ... -[02:29:58] VOIX: ... -[02:30:00] VOIX: ... -[02:30:01] VOIX: ... -[02:30:02] VOIX: Et moi, c'est ce que j'ai été dit. -[02:30:04] VOIX: Et s'ils me prennent, tu crois que tu vas travailler chez eux ? -[02:30:08] VOIX: Bah, s'ils te prennent, ça va être difficile. -[02:30:11] VOIX: Je crois que tu vas travailler chez eux. -[02:30:17] VOIX: Donc... -[02:30:17] VOIX: Oui, c'est pas comme moi que je vais travailler chez eux. -[02:30:23] VOIX: Je comprends, mais quoi ? -[02:30:24] VOIX: Moi, j'ai l'impression que c'est la science. -[02:30:29] VOIX: Oui, c'est longtemps que tu vas travailler chez eux. -[02:30:33] VOIX: Encore là, les mythes, elle est allée samedi, on est appelé, -[02:30:37] VOIX: donc ils font des photos. -[02:30:38] VOIX: Et rien de voir les photos, ça a été tourné. -[02:30:42] VOIX: Parce que... -[02:30:43] VOIX: Non, mais... -[02:30:43] VOIX: C'est bon, en fait, c'est bien aussi. -[02:30:46] VOIX: Donc, si tu veux, c'est déjà possible. -[02:30:49] VOIX: Après, en fait, tu vois, en fait, si c'était prise, c'est pas prise. -[02:30:52] VOIX: Si c'était prise, c'est tant mieux, parce que c'est pas le besoin. -[02:30:55] VOIX: C'est bon. -[02:30:55] VOIX: Il est utile, peut-être que... -[02:30:57] VOIX: Il y a un, peut-être qu'il y a un, peut-être qu'il y a un, -[02:30:59] VOIX: mais il empêche un peu de continuer avec la... -[02:31:01] VOIX: Sur cette partie-là, en même temps. -[02:31:07] VOIX: C'est justement... -[02:31:08] VOIX: C'est justement... -[02:31:09] VOIX: C'est justement... -[02:31:10] VOIX: C'est tout cas, peut-être qu'il faut travailler dans ta vie, -[02:31:13] VOIX: mais c'est pas le plus... -[02:31:15] VOIX: C'est sûrement, ça va. -[02:31:20] VOIX: C'est un peu plus simple. -[02:31:24] VOIX: C'est un peu plus simple. -[02:31:25] VOIX: C'est un peu plus simple. -[02:31:26] VOIX: C'est un peu plus simple. -[02:31:38] VOIX: C'est-à-dire, en fait, -[02:31:48] VOIX: il y a un peu plus simple. -[02:31:57] VOIX: C'est un peu plus simple. -[02:32:09] VOIX: C'est un peu plus simple. -[02:32:11] VOIX: C'est un peu plus simple. -[02:32:13] VOIX: Bon. -[02:32:14] VOIX: Je vais m'amener la cabane, -[02:32:15] VOIX: et... -[02:32:16] VOIX: ... \ No newline at end of file diff --git a/sans titre_summary_v2.md b/sans titre_summary_v2.md deleted file mode 100644 index 18f8990..0000000 --- a/sans titre_summary_v2.md +++ /dev/null @@ -1,73 +0,0 @@ -### 1. Cadre général du « codage CIM‑10 » dans le cadre d’un **contrôle T2A** -*(vision d’un médecin DIM – expert en codage et en gestion des contrôles)* - -| Phase | Objectif | Principales actions | Outils / livrables | -|-------|----------|----------------------|--------------------| -| **Avant le contrôle** | Préparer le dossier afin que toutes les données nécessaires soient disponibles | 1. Réception de la **notification ARS** (lettre de contrôle).
2. Identifier le **patient / séjour** (nom, prénom, NIR, unité médicale, date d’entrée et de sortie, mode d’entrée).
3. Récupérer :
  • le **CRH** (Compte‑Rendu Hospitalier),
  • les **actes** classants (chirurgicaux, anesthésie, etc.) et non‑classants (imagerie, biologie, etc.),
  • les **examens** (scanner, IRM, échographies),
  • les **bilans biologiques** (hémocultures, Bilan rénal, Bilan hépatique, etc.),
  • les **notes infirmières / IDE / kiné**,
  • les **documents d‑administration** (admission, transfert, sortie).
4. Vérifier la présence du **numéro OGC** (fichier RSS‑DEF/DRUIDE) : il permet d’associer chaque unité médicale au dossier OGC.
5. Créer le **fichier de suivi** (tracker / Excel ou CSV) contenant tous les éléments ci‑dessus, classés par unité médicale. | - Fichier RSS‑DEF (ou export DRUIDE) – numéros OGC
- EVA (ou système interne) – tracker
- Guide méthodologique CIM‑10/ T2A (chapitres 5‑6, règles M2, D1‑D3, etc.) | -| **Pendant le contrôle** | Analyser le dossier et déterminer le **diagnostic principal (DP)** et les **diagnostics associés (DAS)** conformément aux règles de la **CIM‑10** et du **tarif T2A** | 1. Lire **chronologiquement** le parcours patient : motif d’entrée → urgences → prise en charge → interventions → évolution.
2. Identifier le **DP** : le problème de santé qui a justifié l’hospitalisation.
- Si le DP est **une pathologie clairement identifiée** (ex. pancréatite, colique hépatique), il devient le DP.
- Si seul un **symptôme** persiste et aucune cause n’a été trouvée (ex. douleur abdominale sans cause retrouvée), le symptôme reste le DP (règle D1).
3. Lister les **DAS** : comorbidités (CMA), maladies chroniques, antécédents pertinents qui ont été **traités ou suivis** pendant le séjour.
- Ne coder comme DAS que les diagnostics **effectivement pris en charge** (ex. diabète de type 2 traité, anémie traitée).
4. Appliquer les **règles de hiérarchisation** (ex. R65 = sepsis lorsqu’une fièvre > 38 °C ou fréquence cardiaque > 90 bpm).
5. Vérifier les **actes classants** : ils orientent le **GHM** (Groupe Homogène de Malades) et le **groupage tarifaire**.
- Ex. une cholécystectomie → GHM 08C (chirurgical).
- Les actes non‑classants (scan, échographie) n’influencent pas le groupage, mais justifient le DP/DAS.
6. Comparer le **DP/DAS** proposé avec le **fichier d’interdiction/acceptation** (Excel référentiel CIM‑10 / T2A).
- Détecter les codes interdits (ex. K802 vs K805) ; choisir le code autorisé.
7. Émettre un **argumentaire** : expliquer le choix du DP/DAS, citer les preuves (CRH, compte‑rendu opératoire, résultats d’imagerie, prescriptions).
8. Soumettre le **fichier d’UCR** (Unité de Contrôle et de Retour) au contrôleur. | - Guide méthodologique (chapitres 5‑6)
- Référentiel CIM‑10 / T2A (Excel)
- Outil IA (lecture automatisée) – facultatif, à valider par le médecin | -| **Après le contrôle** | Répondre aux **observations du médecin contrôleur**, réviser le codage si nécessaire, finaliser la **fiche OGC** (numéro OGC + DP/DAS) | 1. Réception du **rapport du contrôleur** (commentaires, refus de certains codes).
2. Re‑examiner les pièces du dossier pointées (ex. rapport de scanner, prescription).
3. Corriger le DP/DAS ou les actes, en justifiant les changements dans l’**argumentaire**.
4. Mettre à jour la **fiche OGC** (nouveau DP/DAS, nouveau groupe‑tarif).
5. Valider la version finale avec le **MRC** (médecin responsable du contrôle) et le **DIM**. | - Fiche OGC mise à jour
- Rapport final du contrôle (archivé) | - ---- - -## 2. Points d’attention obligatoires (check‑list) - -| # | Point d’attention | Pourquoi c’est crucial | -|---|-------------------|------------------------| -| **1. Identification du patient** | Nom, prénom, NIR, date de naissance, unité(s) médicale(s) et numéro OGC. | Garantit que le codage porte sur le bon séjour. | -| **2. Dates d’entrée et de sortie** | Vérifier la **date d’admission** (souvent le jour où le patient passe aux urgences) et la **date de sortie** exacte. | Influence le calcul de la durée, le groupe‑tarif et la pertinence du DP. | -| **3. Mode d’entrée / sortie** | Urgences, réanimation, domicile, transfert. | Certains contrôles (ex. transfert inter‑établissements) sont soumis à des règles spécifiques. | -| **4. Exhaustivité du dossier** | Tous les documents : CRH, compte‑rendu opératoire, rapports d’imagerie, bilans biologiques, notes IDE/kyné, prescriptions. | Sans la totalité, le DP peut être mal identifié (symptôme vs pathologie). | -| **5. Présence du numéro OGC** | Extraction du fichier RSS‑DEF/DRUIDE → OGC 1, OGC 2, … | Nécessaire pour créer/mettre à jour la **fiche OGC** et lier le séjour au groupe‑tarif. | -| **6. Distinction DP / DAS** | DP = motif d’hospitalisation ; DAS = diagnostics réellement traités. | Le DP entre dans le groupe‑tarif, les DAS sont facturés en supplément seulement s’ils sont pris en charge. | -| **7. Règle “symptôme vs cause”** | Si aucune cause n’est identifiée, le symptôme reste DP (ex. douleur abdominale non expliquée). | Évite de coder un diagnostic absent du dossier (refus du contrôleur). | -| **8. Sévérité (niveau 2, 3, 4)** | Appliquer les règles de gravité (ex. R65 = sepsis, R70 = déshydratation). | Influence le **groupe de gravité** et le montant de la prise en charge. | -| **9. Actes classants** | Identifier les actes qui déterminent le **GHM** (ex. cholécystectomie → 08C). | Si l’acte classant manque, le groupe‑tarif sera erroné. | -| **10. Actes non‑classants** | Conservés pour justifier le DP/DAS mais ne participent pas au groupage. | Useful for argumentation ; must still être mentionnés dans le dossier. | -| **11. Respect du référentiel de codes** | Utiliser le fichier Excel **CIM‑10/T2A** fourni (liste des codes acceptés, interdits). | Empêche le choix de codes “interdits” (ex. K802 vs K805). | -| **12. Comorbidités (CMA)** | Coder les pathologies chroniques **si elles sont effectivement prises en charge** pendant le séjour. | Elles sont prises en compte dans le calcul du **GHS** et du **coefficient de sévérité**. | -| **13. Vérification des transferts** | Si le patient a été transféré d’un autre établissement, en note la prise en charge et l’inscrire dans le OGC. | Le contrôle peut remettre en cause le DP si le transfert n’est pas justifié. | -| **14. Argumentaire écrit** | Chaque code doit être justifié (CRH, rapport d’anatomopathologie, résultats d’imagerie, prescription). | Le contrôleur se base sur cet argumentaire pour valider ou refuser. | -| **15. Validation par le DIM** | Avant l’envoi, le **DIM** vérifie la conformité du DP, des DAS, des actes et du fichier OGC. | Dernière couche de contrôle interne avant la soumission. | -| **16. Gestion des corrections** | Si le contrôleur propose un changement, modifier le DP/DAS/OGC **et** mettre à jour l’argumentaire, puis re‑soumettre. | Permet de clôturer le contrôle sans pénalité financière. | -| **17. Suivi des dossiers “hors‑norme”** | Identifier les dossiers où le DP reste un symptôme non résolu ou où le montant à payer est très élevé ; les faire valider manuellement. | Ces dossiers requièrent souvent l’intervention du médecin contrôleur pour éviter un « refus total ». | -| **18. Historisation** | Conserver chaque version de la fiche OGC, du tracker et du rapport d’argumentation. | Traçabilité exigée lors de contrôles subséquents ou d’audits. | -| **19. Format de remise** | CSV / Excel : numéro dossier, OGC, DP, DAS, actes classants/non‑classants, dates, mode d’entrée, groupe‑tarif, justification. | Le format attendu par le service de contrôle (ex. Blue‑File/ShareFile). | -| **20. Respect des délais** | Soumission du dossier complet **avant la date limite** indiquée dans la notification T2A. | Tout retard entraîne des pénalités automatiques. | - ---- - -## 3. Synthèse pratique – Modèle de fiche de contrôle T2A - -| Colonnes | Contenu attendu | -|----------|----------------| -| **N° dossier** | Identifiant interne de l’établissement. | -| **Numéro OGC** | OGC‑1, OGC‑2 … (extrait de RSS‑DEF/DRUIDE). | -| **Date d’entrée** | JJ/MM/AAAA + mode (urgences, SD, etc.). | -| **Date de sortie** | JJ/MM/AAAA + mode (domicile, transfert). | -| **DP** | Code CIM‑10 **principal** (ex. K85.0 pancréatite aiguë). | -| **Motif d’hospitalisation** | Texte libre (ex. douleur abdominale, colique hépatique). | -| **DAS (CMA)** | Codes CIM‑10 associés, **seuls** s’ils ont été traités (ex. E11.9 diabète, Z92 antécédent). | -| **Actes classants** | Code CCAM + date (ex. HAJFA cholécystectomie). | -| **Actes non‑classants** | Code CCAM + date (ex. ZEFA scanner). | -| **Groupe‑tarif (GHM)** | Code GHM résultant du regroupement (ex. 08C). | -| **Sevérité** | Niveau 1/2/3/4 ou R65 etc. | -| **Justification** | Référence précise (CRH n°, compte‑rendu opératoire, rapport d’imagerie, ordonnance). | -| **Commentaire contrôleur** | (à remplir après réception). | -| **Statut** | Soumis / accepté / corrigé / clôturé. | - -*Cette fiche peut être exportée en CSV/Excel et importée dans le système de **Blue‑File/ShareFile** ou directement dans **EVA**.* - ---- - -## 4. Conclusion - -Le **codage CIM‑10 dans le cadre d’un contrôle T2A** repose sur une méthode rigoureuse : - -1. **Collecter** l’ensemble des pièces du dossier (CRH, OGC, actes, bilans). -2. **Analyser** le parcours clinique pour identifier le **diagnostic principal** (DP) et les **diagnostics associés** (DAS). -3. **Appliquer** les **règles de hiérarchisation**, de **sévérité** et de **compatibilité** du référentiel (Excel). -4. **Justifier** chaque code par des documents probants et rédiger un **argumentaire** clair. -5. **Vérifier** le tout avec le **DIM**, puis le soumettre via la **fiche OGC**. -6. **Traiter** les remarques du contrôleur en révisant le DP/DAS, en documentant les changements et en re‑soumettant. - -Respecter scrupuleusement les **points d’attention obligatoires** listés ci‑dessus garantit une conformité maximale, minimise les rejets et optimise la valorisation financière du séjour. \ No newline at end of file diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py b/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py deleted file mode 100644 index e69de29..0000000 diff --git a/unsloth_compiled_cache/UnslothBCOTrainer.py b/unsloth_compiled_cache/UnslothBCOTrainer.py deleted file mode 100644 index d390393..0000000 --- a/unsloth_compiled_cache/UnslothBCOTrainer.py +++ /dev/null @@ -1,2180 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, BaseTrainer, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, autocast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_joblib_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, joblib, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, warnings, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, TrainerCallback, TrainingArguments, Union, autocast, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_joblib_available, is_peft_available, is_sklearn_available, is_wandb_available, joblib, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothBCOConfig(BCOConfig): - """ - - Configuration class for the [`BCOTrainer`]. - - This class includes only the parameters that are specific to BCO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int`, *optional*): - Maximum length of the completion. This argument is required if you want to use the default data collator - and your model is an encoder-decoder. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int`, *optional*): - Padding value to use. If `None`, the padding value of the tokenizer is used. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. - This argument is required if you want to use the default data collator. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model and reference model. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from both the model and the reference model to W&B or Comet - during evaluation. - is_encoder_decoder (`bool`, *optional*): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): - Whether to precompute reference model log probabilities for training and evaluation datasets. This is - useful when training without the reference model to reduce the total GPU memory needed. - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - ref_model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model - from a string. - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - prompt_sample_size (`int`, *optional*, defaults to `1024`): - Number of prompts that are fed to density ratio classifier. - min_density_ratio (`float`, *optional*, defaults to `0.5`): - Minimum value of the density ratio. The estimated density ratio is clamped to this value. - max_density_ratio (`float`, *optional*, defaults to `10.0`): - Maximum value of the density ratio. The estimated density ratio is clamped to this value. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - max_length = 1024, - max_prompt_length = 512, - max_completion_length = None, - beta = 0.1, - label_pad_token_id = -100, - padding_value = None, - truncation_mode = 'keep_end', - disable_dropout = True, - generate_during_eval = False, - is_encoder_decoder = None, - precompute_ref_log_probs = False, - model_init_kwargs = None, - ref_model_init_kwargs = None, - dataset_num_proc = None, - prompt_sample_size = 1024, - min_density_ratio = 0.5, - max_density_ratio = 10.0, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - max_length = max_length, - max_prompt_length = max_prompt_length, - max_completion_length = max_completion_length, - beta = beta, - label_pad_token_id = label_pad_token_id, - padding_value = padding_value, - truncation_mode = truncation_mode, - disable_dropout = disable_dropout, - generate_during_eval = generate_during_eval, - is_encoder_decoder = is_encoder_decoder, - precompute_ref_log_probs = precompute_ref_log_probs, - model_init_kwargs = model_init_kwargs, - ref_model_init_kwargs = ref_model_init_kwargs, - dataset_num_proc = dataset_num_proc, - prompt_sample_size = prompt_sample_size, - min_density_ratio = min_density_ratio, - max_density_ratio = max_density_ratio,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothBCOTrainer(BaseTrainer): - r"""""" - - _tag_names = ["trl", "bco"] - _name = "BCO" - _paper = { - "title": "Binary Classifier Optimization for Large Language Model Alignment", - "id": "2404.04656", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{jung2024binary, - title = {{Binary Classifier Optimization for Large Language Model Alignment}}, - author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On}, - year = 2024, - eprint = {arXiv:2404.04656} - }"""), - } - - def __init__( - self, - model: Union[PreTrainedModel, nn.Module, str] = None, - ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: BCOConfig = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - data_collator: Optional[DataCollator] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, - compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, - model_adapter_name: Optional[str] = None, - ref_adapter_name: Optional[str] = None, - embedding_func: Optional[Callable] = None, - embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None, - ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()): - raise ImportError( - "BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`." - ) - - if type(args) is TrainingArguments: - raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.") - - if not isinstance(model, str) and model is not None and ref_model is model: - raise ValueError( - "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " - "same as `model`, you must mass a copy of it, or `None` if you use peft." - ) - - if args.model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - dtype = model_init_kwargs.get("dtype") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - model_init_kwargs["dtype"] = dtype - - if args.ref_model_init_kwargs is None: - ref_model_init_kwargs = {} - elif not isinstance(ref_model, str): - raise ValueError( - "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated." - ) - else: - ref_model_init_kwargs = args.ref_model_init_kwargs - dtype = ref_model_init_kwargs.get("dtype") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - ref_model_init_kwargs["dtype"] = dtype - - if isinstance(model, str): - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - if isinstance(ref_model, str): - ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) - - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # get peft model with the given config - model = model - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): - raise ValueError( - "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve." - ) - - if model is not None: - self.is_encoder_decoder = model.config.is_encoder_decoder - elif args.is_encoder_decoder is None: - raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") - else: - self.is_encoder_decoder = args.is_encoder_decoder - - self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) - self.model_adapter_name = model_adapter_name - self.ref_adapter_name = ref_adapter_name - - if ref_model: - self.ref_model = ref_model - elif self.is_peft_model or args.precompute_ref_log_probs: - # The `model` with adapters turned off will be used as the reference model - self.ref_model = None - else: - self.ref_model = create_reference_model(model) - - if processing_class is None: - raise ValueError( - "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" - ) - if args.max_length is None: - logger.warning( - "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. " - "It will be set to `512` by default, but you should do it yourself in the future.", - ) - max_length = 512 - if args.max_length is not None: - max_length = args.max_length - - if args.max_prompt_length is None: - logger.warning( - "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. " - "It will be set to `128` by default, but you should do it yourself in the future.", - ) - max_prompt_length = 128 - if args.max_prompt_length is not None: - max_prompt_length = args.max_prompt_length - - max_completion_length = None - if args.max_completion_length is None and self.is_encoder_decoder: - logger.warning( - "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - ) - max_completion_length = 128 - if args.max_completion_length is not None and self.is_encoder_decoder: - max_completion_length = args.max_completion_length - - if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - if args.remove_unused_columns: - args.remove_unused_columns = False - # warn users - logger.warning( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig" - " we have set it for you, but you should do it yourself in the future.", - ) - - self.use_dpo_data_collator = True - else: - self.use_dpo_data_collator = False - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - - self.max_length = max_length - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.max_completion_length = max_completion_length - self.precompute_ref_log_probs = args.precompute_ref_log_probs - - # Since ref_logs are precomputed on the first call to get_train/eval_dataloader - # keep track of first called to avoid computation of future calls - self._precomputed_train_ref_log_probs = False - self._precomputed_eval_ref_log_probs = False - - # metric - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - - # BCO parameter - self.beta = args.beta - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - logger.warning( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - ) - - # Underlying Distribution Matching argument - self.embedding_func = embedding_func - self.embedding_tokenizer = embedding_tokenizer - - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result, - # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point - # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's - # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been - # issued. - model.warnings_issued["estimate_tokens"] = True - - with PartialState().main_process_first(): - # Extract the prompt if needed - train_dataset = train_dataset.map( - maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" - ) - # Unpair the dataset if needed - train_dataset = maybe_unpair_preference_dataset( - train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" - ) - # Apply the chat template if needed - train_dataset = train_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc - ) - if eval_dataset is not None: - # Extract the prompt if needed - eval_dataset = eval_dataset.map( - maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" - ) - # Unpair the dataset if needed - eval_dataset = maybe_unpair_preference_dataset( - eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" - ) - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - ) - - # Tokenize and prepare the training datasets - train_dataset = train_dataset.map( - _tokenize, - batched=True, - fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer}, - num_proc=args.dataset_num_proc, - desc="Tokenizing train dataset", - ) - - # Prepare the datasets - fn_kwargs = { - "prefix": "", - "is_encoder_decoder": self.is_encoder_decoder, - "tokenizer": processing_class, - "max_length": self.max_length, - "truncation_mode": self.truncation_mode, - "label_pad_token_id": self.label_pad_token_id, - "max_prompt_length": self.max_prompt_length, - "max_completion_length": self.max_completion_length, - } - train_dataset = train_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - desc="Processing tokenized train dataset", - ) - - if eval_dataset is not None: - # Tokenize - eval_dataset = eval_dataset.map( - _tokenize, - fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer}, - batched=True, - num_proc=args.dataset_num_proc, - desc="Tokenizing eval dataset", - ) - - # Process - fn_kwargs = { - "prefix": "", - "is_encoder_decoder": self.is_encoder_decoder, - "tokenizer": processing_class, - "max_length": self.max_length, - "truncation_mode": self.truncation_mode, - "label_pad_token_id": self.label_pad_token_id, - "max_prompt_length": self.max_prompt_length, - "max_completion_length": self.max_completion_length, - } - eval_dataset = eval_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - desc="Processing tokenized eval dataset", - ) - - desirable = train_dataset.filter( - lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples" - ) - undesirable = train_dataset.filter( - lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples" - ) - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - model_init=model_init, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - # Deepspeed Zero-3 does not support precompute_ref_log_probs - if self.is_deepspeed_enabled: - if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." - ) - - if self.ref_model is None: - if not (self.is_peft_model or self.precompute_ref_log_probs): - raise ValueError( - "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" - ) - else: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - self.running = RunningMoments(accelerator=self.accelerator) - - if self.embedding_func is None or args.resume_from_checkpoint: - return - - chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size) - rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size) - - embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0) - labels = torch.cat( - (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0 - ) - - self.clf = LogisticRegression(class_weight="balanced").fit( - embeddings.cpu().float().numpy(), labels.cpu().numpy() - ) - chosen_mean = self.clf.score( - chosen_embeddings.cpu().float().numpy(), torch.ones_like(chosen_embeddings[:, 0]).cpu().numpy() - ) - rejected_mean = self.clf.score( - rejected_embeddings.cpu().float().numpy(), torch.zeros_like(rejected_embeddings[:, 0]).cpu().numpy() - ) - logger.info(f"UDM classifier training scores: chosen: {chosen_mean}, rejected: {rejected_mean}") - - @property - def match_underlying_distribution(self): - return self.embedding_func is not None and self.embedding_tokenizer is not None - - def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor: - """ - Calculates the probability if the given prompt embedding is from desirable dataset. This function calculates - the probability in the process and ensemble across processes. - """ - dtype = prompt_embeddings.dtype - device = prompt_embeddings.device - rank = self.accelerator.process_index - - padded_prompt_embeddings = self.accelerator.pad_across_processes( - prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id - ) - sample_size = padded_prompt_embeddings.shape[0] - nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id - prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings) - - # cannot predict for all empty values - if prompt_embeddings.shape[0] == 0: - return torch.tensor([], device=device, dtype=dtype) - - prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1] - prob = torch.as_tensor(prob, dtype=dtype, device=device) - prob = self.accelerator.reduce(prob, reduction="mean") - - prob = prob[sample_size * rank : sample_size * (rank + 1)] - prob = prob[nonzero] - - return prob - - def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor: - """ - Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id and applies self.embedding_func - """ - input_ids = torch.where( - input_ids == self.processing_class.pad_token_id, - self.embedding_tokenizer.pad_token_id, - input_ids, - ) - - with torch.no_grad(): - embeddings = self.embedding_func( - input_ids=input_ids, - attention_mask=attention_mask, - ) - - return embeddings - - def _get_prompt_embeddings( - self, batch: dict[str, Union[list, torch.LongTensor]] - ) -> tuple[torch.FloatTensor, torch.FloatTensor]: - """Extract embeddings from frozen embedding model""" - - if not self.match_underlying_distribution: - return None, None - - embeddings = self._vectorize_prompt( - input_ids=batch["embedding_input_ids"], - attention_mask=batch["embedding_attention_mask"], - ) - - labels = torch.tensor(batch["label"], dtype=torch.bool, device=embeddings.device) - chosen_idx = torch.where(labels)[0] - rejected_idx = torch.where(~labels)[0] - - chosen_embeddings = embeddings[chosen_idx, ...] - rejected_embeddings = embeddings[rejected_idx, ...] - - return (chosen_embeddings, rejected_embeddings) - - def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor: - """ - Sample instances from dataset and get prompt embeddings. Used for density ratio classifier training. - """ - n_samples = min(len(dataset), sample_size) - rand_indices = np.random.choice(len(dataset), size=(n_samples,)) - - embedding_dataset = dataset.select(rand_indices) - - dataloader_params = { - "batch_size": self.args.per_device_train_batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params)) - - with torch.no_grad(): - all_embeddings = torch.empty(0) - for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"): - embeddings = self._vectorize_prompt( - input_ids=padded_batch["embedding_input_ids"], - attention_mask=padded_batch["embedding_attention_mask"], - ) - embeddings = self.accelerator.gather_for_metrics(embeddings) - all_embeddings = torch.cat((all_embeddings, embeddings.cpu())) - - return all_embeddings - - def _save_optimizer_and_scheduler(self, output_dir): - output_dir = output_dir if output_dir is not None else self.args.output_dir - super()._save_optimizer_and_scheduler(output_dir) - - if self.accelerator.is_main_process: - # When saving optimizer and scheduler to checkpoint, save also the running delta object. - self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME)) - - if self.match_underlying_distribution: - joblib.dump(self.clf, os.path.join(output_dir, CLF_NAME), compress=True) - - def _load_optimizer_and_scheduler(self, checkpoint): - if checkpoint is None: - logger.warning_once(f"Missing Checkpoint {checkpoint}") - return - - super()._load_optimizer_and_scheduler(checkpoint) - - # when loading optimizer and scheduler from checkpoint, also load the running delta object. - running_file = os.path.join(checkpoint, RUNNING_NAME) - if os.path.isfile(running_file): - self.running = RunningMoments.load_from_json(self.accelerator, running_file) - - if self.match_underlying_distribution: - clf_file = os.path.join(checkpoint, CLF_NAME) - if os.path.isfile(clf_file): - self.clf = joblib.load(clf_file) - - @contextmanager - def null_ref_context(self): - """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with ( - self.accelerator.unwrap_model(self.model).disable_adapter() - if self.is_peft_model and not self.ref_adapter_name - else nullcontext() - ): - if self.ref_adapter_name: - self.model.set_adapter(self.ref_adapter_name) - yield - if self.ref_adapter_name: - self.model.set_adapter(self.model_adapter_name or "default") - - def get_train_dataloader(self) -> DataLoader: - """ - Returns the training [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. - """ - - if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: - dataloader_params = { - "batch_size": self.args.per_device_train_batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) - reference_completion_logps = [] - - for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - reference_completion_logp = self.compute_reference_log_probs(padded_batch) - - reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) - reference_completion_logps.append(reference_completion_logp.cpu()) - - self.train_dataset = self.train_dataset.add_column( - name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() - ) - - self._precomputed_train_ref_log_probs = True - - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - """ - Returns the evaluation [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. - - Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: - dataloader_params = { - "batch_size": self.args.per_device_eval_batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - - reference_completion_logps = [] - - for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - reference_completion_logp = self.compute_reference_log_probs(padded_batch) - - reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) - reference_completion_logps.append(reference_completion_logp.cpu()) - - eval_dataset = eval_dataset.add_column( - name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() - ) - - # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs - if self.eval_dataset is not None: - self.eval_dataset = eval_dataset - self._precomputed_eval_ref_log_probs = True - - return super().get_eval_dataloader(eval_dataset=eval_dataset) - - def compute_reference_log_probs(self, padded_batch: dict) -> dict: - """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset.""" - with torch.no_grad(): - if self.ref_model is None: - with self.null_ref_context(): - if self.is_encoder_decoder: - completion_logits = self.model( - padded_batch["prompt_input_ids"], - attention_mask=padded_batch["prompt_attention_mask"], - decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), - labels=padded_batch["completion_labels"], - ).logits - - else: - completion_logits = self.model( - padded_batch["completion_input_ids"], - attention_mask=padded_batch["completion_attention_mask"], - ).logits - - else: - if self.is_encoder_decoder: - completion_logits = self.ref_model( - padded_batch["prompt_input_ids"], - attention_mask=padded_batch["prompt_attention_mask"], - decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), - labels=padded_batch["completion_labels"], - ).logits - - else: - completion_logits = self.ref_model( - padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] - ).logits - - completion_logps = self.get_batch_logps( - completion_logits, - padded_batch["completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - return completion_logps - - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: - Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are - ignored. Shape: (batch_size, sequence_length) - average_log_prob: - If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the - log probabilities of the (non-masked) tokens. - label_pad_token_id: - The label value to ignore when computing log probabilities. - is_encoder_decoder: - Whether the model is an encoder-decoder model. If True, the labels are not shifted, and the logits are - assumed to already be aligned with the labels. If False, the labels are shifted to the right by one - position, and the logits are assumed to be aligned with the shifted labels. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the - given logits. - """ - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - else: - # Fixes end-dec RuntimeError - labels = labels.clone() - - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == label_pad_token_id] = 0 - - per_token_logps = selective_log_softmax(logits, labels) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def forward( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - model_kwargs = ( - { - "labels": batch["completion_labels"], - "decoder_input_ids": batch.get("completion_decoder_input_ids"), - } - if self.is_encoder_decoder - else {} - ) - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - outputs = model( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], - **model_kwargs, - ) - completion_logits = outputs.logits - - completion_logps = self.get_batch_logps( - completion_logits, - batch["completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - if completion_logps.shape[0] != len(batch["label"]): - raise ValueError( - "There is a mismatch between the number of examples in this batch and the number of " - "examples for which an output sequence was predicted." - ) - - chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] - rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] - - chosen_logps = completion_logps[chosen_idx, ...] - rejected_logps = completion_logps[rejected_idx, ...] - - chosen_logits = completion_logits[chosen_idx, ...] - rejected_logits = completion_logits[rejected_idx, ...] - - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss) - else: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) - - def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor: - prob_desirable = self._get_chosen_prob(rejected_embeddings) - min_ratio = self.args.min_density_ratio - max_ratio = self.args.max_density_ratio - - weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio) - - return weight - - def bco_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - reference_chosen_logps: torch.FloatTensor, - reference_rejected_logps: torch.FloatTensor, - chosen_embeddings: Optional[torch.FloatTensor], - rejected_embeddings: Optional[torch.FloatTensor], - do_train: bool = True, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute the BCO loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: - Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) - policy_rejected_logps: - Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) - reference_chosen_logps: - Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) - reference_rejected_logps: - Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in - batch_size,) - chosen_embeddings: embeddings of desirable prompts - rejected_embeddings: embeddings of undesirable prompts - do_train: whether to update the running delta value. Default is True. - - Returns: - A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta). The losses tensor contains the - BCO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards - for the chosen and rejected responses, respectively. The delta value contains the moving average of all - implicit rewards. - """ - - chosen_logratios = policy_chosen_logps - reference_chosen_logps - chosen_rewards = self.beta * chosen_logratios - - rejected_logratios = policy_rejected_logps - reference_rejected_logps - rejected_rewards = self.beta * rejected_logratios - - if do_train: - self.running.update(torch.cat((chosen_rewards, rejected_rewards), 0).detach()) - delta = torch.as_tensor(self.running.mean, device=chosen_rewards.device) - - chosen_losses = -F.logsigmoid(chosen_rewards - delta) - rejected_losses = -F.logsigmoid(-(rejected_rewards - delta)) - - if self.match_underlying_distribution: - chosen_weight = torch.ones_like(chosen_losses) - rejected_weight = self._get_udm_weight(rejected_embeddings) - - losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0) - else: - losses = torch.cat((chosen_losses, rejected_losses), dim=0) - - return losses, chosen_rewards, rejected_rewards, delta - - def get_batch_loss_metrics( - self, - model, - batch: dict[str, Union[list, torch.LongTensor]], - do_train: bool = True, - ): - """Compute the BCO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} - - forward_output = self.forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - ) = forward_output[:4] - if self.aux_loss_enabled: - aux_loss = forward_output[4] - - # if reference_logps in batch use them, otherwise use the reference model - if "reference_logps" in batch: - chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True] - rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False] - - reference_chosen_logps = batch["reference_logps"][chosen_idx, ...] - reference_rejected_logps = batch["reference_logps"][rejected_idx, ...] - else: - with torch.no_grad(): - if self.ref_model is None: - with self.null_ref_context(): - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - ) = self.forward(self.model, batch)[:4] - else: - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - ) = self.forward(self.ref_model, batch)[:4] - - chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch) - - losses, chosen_rewards, rejected_rewards, delta = self.bco_loss( - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - chosen_embeddings, - rejected_embeddings, - do_train=do_train, - ) - metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item() - - num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) - num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) - - all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() - all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() - - if all_num_chosen > 0: - metrics["rewards/chosen_sum"] = ( - self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() - ) - metrics["logps/chosen_sum"] = ( - self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() - ) - metrics["logits/chosen_sum"] = ( - self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() - ) - metrics["count/chosen"] = all_num_chosen - - if all_num_rejected > 0: - metrics["rewards/rejected_sum"] = ( - self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() - ) - metrics["logps/rejected_sum"] = ( - self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() - ) - metrics["logits/rejected_sum"] = ( - self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() - ) - metrics["count/rejected"] = all_num_rejected - - loss = losses.nanmean() - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss - - return loss, metrics - - def compute_loss( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs) - - # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: - loss = loss.to(self.args.device) - # force log the metrics - if self.accelerator.is_main_process: - self.store_metrics(metrics, train_eval="train") - - if return_outputs: - return (loss, metrics) - return loss - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: - if dataset is None: - dataset = self.train_dataset - if dataset is None or not has_length(dataset): - return None - return SequentialSampler(dataset) - - def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - # if reference_output in batch use that otherwise use the reference model - if "reference_output" in batch: - reference_output = batch["reference_output"] - else: - if self.ref_model is None: - with self.null_ref_context(): - reference_output = self.model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - else: - reference_output = self.ref_model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) - reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) - - return policy_output_decoded, reference_output_decoded - - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ): - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, do_train=False) - - # force log the metrics - if self.accelerator.is_main_process: - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return (loss.detach(), None, None) - - # logits for the chosen and rejected samples from model - logits_dict = {} - if "logits/chosen_sum" in metrics: - logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"] - if "logits/rejected_sum" in metrics: - logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"] - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device) - target_indices = torch.where(~target_labels)[0] - target_batch = { - "prompt_input_ids": random_batch["prompt_input_ids"][target_indices], - "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices], - "prompt": itemgetter(*target_indices)(random_batch["prompt"]), - } - policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded) - ], - ) - if "wandb" in self.args.report_to: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float`, *optional*): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # train metrics should have no prefix, eval should have 'eval_' - prefix = "eval_" if train_eval == "eval" else "" - # accumulate average metrics from sums and lengths - for split in ["chosen", "rejected"]: - if f"count/{split}" in self._stored_metrics[train_eval]: - count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() - for metric in ["rewards", "logps", "logits"]: - logs[f"{prefix}{metric}/{split}"] = ( - torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() - / count_sum - ) - # delete obsolete metric - del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] - del self._stored_metrics[train_eval][f"count/{split}"] - # calculate reward margin - if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: - logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothBCOTrainer(_UnslothBCOTrainer): - """ - - Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. - ref_model ([`PreTrainedModelWrapper`]): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation - and loss. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized. - args ([`BCOConfig`]): - The arguments to use for training. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - data_collator ([`~transformers.DataCollator`], *optional*): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - model_adapter_name (`str`, defaults to `None`): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str`, defaults to `None`): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - - """ - def __init__( - self, - model = None, - ref_model = None, - args = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - data_collator = None, - model_init = None, - callbacks = None, - preprocess_logits_for_metrics = None, - peft_config = None, - compute_metrics = None, - model_adapter_name = None, - ref_adapter_name = None, - embedding_func = None, - embedding_tokenizer = None, - **kwargs - ): - if args is None: args = UnslothBCOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('bco_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - ref_model = ref_model, - args = args, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - data_collator = data_collator, - model_init = model_init, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config, - compute_metrics = compute_metrics, - model_adapter_name = model_adapter_name, - ref_adapter_name = ref_adapter_name, - embedding_func = embedding_func, - embedding_tokenizer = embedding_tokenizer,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothCPOTrainer.py b/unsloth_compiled_cache/UnslothCPOTrainer.py deleted file mode 100644 index eee10fe..0000000 --- a/unsloth_compiled_cache/UnslothCPOTrainer.py +++ /dev/null @@ -1,1960 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothCPOConfig(CPOConfig): - """ - - Configuration class for the [`CPOTrainer`]. - - This class includes only the parameters that are specific to CPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int`, *optional*): - Maximum length of the completion. This argument is required if you want to use the default data collator - and your model is an encoder-decoder. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). - label_smoothing (`float`, *optional*, defaults to `0.0`): - Label smoothing factor. This argument is required if you want to use the default data collator. - loss_type (`str`, *optional*, defaults to `"sigmoid"`): - Type of loss to use. Possible values are: - - - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. - - `"hinge"`: hinge loss on the normalized likelihood from the - [SLiC](https://huggingface.co/papers/2305.10425) paper. - - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. - - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper. - - `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This - automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. - - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. - cpo_alpha (`float`, *optional*, defaults to `1.0`): - Weight of the BC regularizer in CPO training. - simpo_gamma (`float`, *optional*, defaults to `0.5`): - Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`. - alpha (`float`, *optional*, defaults to `0.0`): - Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses - standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha)) - / alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all - loss types. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int`, *optional*): - Padding value to use. If `None`, the padding value of the tokenizer is used. - truncation_mode (`str`,*optional*, defaults to `"keep_end"`): - Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. - This argument is required if you want to use the default data collator. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from the model to W&B or Comet during evaluation. - is_encoder_decoder (`bool`, *optional*): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - max_length = 1024, - max_prompt_length = 512, - max_completion_length = None, - beta = 0.1, - label_smoothing = 0.0, - loss_type = 'sigmoid', - disable_dropout = True, - cpo_alpha = 1.0, - simpo_gamma = 0.5, - alpha = 0.0, - label_pad_token_id = -100, - padding_value = None, - truncation_mode = 'keep_end', - generate_during_eval = False, - is_encoder_decoder = None, - model_init_kwargs = None, - dataset_num_proc = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - max_length = max_length, - max_prompt_length = max_prompt_length, - max_completion_length = max_completion_length, - beta = beta, - label_smoothing = label_smoothing, - loss_type = loss_type, - disable_dropout = disable_dropout, - cpo_alpha = cpo_alpha, - simpo_gamma = simpo_gamma, - alpha = alpha, - label_pad_token_id = label_pad_token_id, - padding_value = padding_value, - truncation_mode = truncation_mode, - generate_during_eval = generate_during_eval, - is_encoder_decoder = is_encoder_decoder, - model_init_kwargs = model_init_kwargs, - dataset_num_proc = dataset_num_proc,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothCPOTrainer(BaseTrainer): - r"""""" - - _tag_names = ["trl", "cpo"] - _name = "CPO" - _paper = { - "title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation", - "id": "2401.08417", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @inproceedings{xu2024contrastive, - title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}}, - author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim}, - year = 2024, - booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, - publisher = {OpenReview.net}, - url = {https://openreview.net/forum?id=51iwkioZpn} - }"""), - } - - def __init__( - self, - model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: Optional[CPOConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, - compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, - ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if args.model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - dtype = model_init_kwargs.get("dtype") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - model_init_kwargs["dtype"] = dtype - - if isinstance(model, str): - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # get peft model with the given config - model = model - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): - raise ValueError( - "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve." - ) - - if model is not None: - self.is_encoder_decoder = model.config.is_encoder_decoder - elif args.is_encoder_decoder is None: - raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") - else: - self.is_encoder_decoder = args.is_encoder_decoder - - if self.is_encoder_decoder: - self.decoder_start_token_id = model.config.decoder_start_token_id - self.pad_token_id = model.config.pad_token_id - - if processing_class is None: - raise ValueError("processing_class must be specified to tokenize a CPO dataset.") - if args.max_length is None: - logger.warning( - "`max_length` is not set in the CPOConfig's init" - " it will default to `512` by default, but you should do it yourself in the future.", - ) - max_length = 512 - else: - max_length = args.max_length - if args.max_prompt_length is None: - logger.warning( - "`max_prompt_length` is not set in the CPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - max_prompt_length = 128 - else: - max_prompt_length = args.max_prompt_length - - if not max_prompt_length < max_length: - raise ValueError( - f"max_prompt_length ({max_prompt_length}) should be strictly less than max_length ({max_length})." - ) - - if args.max_completion_length is None and self.is_encoder_decoder: - logger.warning( - "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - max_completion_length = 128 - else: - max_completion_length = args.max_completion_length - - if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - if args.remove_unused_columns: - args.remove_unused_columns = False - # warn users - logger.warning( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" - " we have set it for you, but you should do it yourself in the future.", - ) - - self.use_dpo_data_collator = True - else: - self.use_dpo_data_collator = False - - # Disable dropout in the model - if args.disable_dropout: - disable_dropout_in_model(model) - - self.max_length = max_length - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.max_completion_length = max_completion_length - self.processing_class = processing_class - - if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0: - logger.warning( - f"You are using the {args.loss_type} loss type that does not support label smoothing. The " - "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.", - ) - if args.loss_type == "kto_pair": - raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.") - - self.beta = args.beta - self.label_smoothing = args.label_smoothing - self.loss_type = args.loss_type - self.cpo_alpha = args.cpo_alpha - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - logger.warning( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - ) - - if args.loss_type == "simpo": - self.simpo_gamma = args.simpo_gamma - - # AlphaPO parameter for reward shaping - self.alpha = args.alpha - - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and - # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens - # of the input, floating-point operations will not be computed." To suppress this warning, we set the - # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate - # that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().main_process_first(): - # Extract the prompt if needed, and apply the chat template if needed - train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - train_dataset = train_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc - ) - if eval_dataset is not None: - eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - ) - - # tokenize the dataset - train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - if eval_dataset is not None: - eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - model_init=model_init, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - def build_tokenized_answer(self, prompt, answer): - """ - Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + - b)[len(enc(a)):]`. Reference: - https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - """ - - full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) - prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] - - answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] - answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] - - # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` - full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) - - # Prepare input tokens for token by token comparison - full_input_ids = np.array(full_tokenized["input_ids"]) - - if len(full_input_ids) != len(full_concat_input_ids): - raise ValueError("Prompt input ids and answer input ids should have the same length.") - - # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens - # can be merged together when tokenizing prompt+answer. This could result - # on the last token from the prompt being different when tokenized on its own - # vs when done as prompt+answer. - response_token_ids_start_idx = len(prompt_input_ids) - - # If tokenized prompt is different than both prompt+answer, then it means the - # last token has changed due to merging. - if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: - response_token_ids_start_idx -= 1 - - prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] - prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] - - if len(prompt_input_ids) != len(prompt_attention_mask): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] - answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] - - return dict( - prompt_input_ids=prompt_input_ids, - prompt_attention_mask=prompt_attention_mask, - input_ids=answer_input_ids, - attention_mask=answer_attention_mask, - ) - - def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict: - """Tokenize a single row from a CPO specific dataset. - - At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + - chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, - we truncate the chosen/rejected. - - We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length - of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens. - """ - batch = {} - prompt = feature["prompt"] - chosen = feature["chosen"] - rejected = feature["rejected"] - - if not self.is_encoder_decoder: - # Check issues below for more details - # 1. https://github.com/huggingface/trl/issues/907 - # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - # 3. https://github.com/LianjiaTech/BELLE/issues/337 - - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)}") - prompt_tokens = self.processing_class(prompt, add_special_tokens=False) - prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} - - if not isinstance(chosen, str): - raise ValueError(f"chosen should be an str but got {type(chosen)}") - chosen_tokens = self.build_tokenized_answer(prompt, chosen) - - if not isinstance(rejected, str): - raise ValueError(f"rejected should be an str but got {type(rejected)}") - rejected_tokens = self.build_tokenized_answer(prompt, rejected) - - # Last prompt token might get merged by tokenizer and - # it should not be included for generation if that happens - prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) - - chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) - rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) - prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) - - for k, v in prompt_tokens.items(): - prompt_tokens[k] = v[:prompt_len_input_ids] - - # Make sure prompts only have one different token at most an - # and length only differs by 1 at most - num_diff_tokens = sum( - a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"]) - ) - num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) - if num_diff_tokens > 1 or num_diff_len > 1: - raise ValueError( - "Chosen and rejected prompt_input_ids might only differ on the " - "last token due to tokenizer merge ops." - ) - - # add BOS token to head of prompt. Avoid adding if it's already there - prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( - self.processing_class.bos_token_id, - prompt_len_input_ids, - prompt_tokens, - chosen_prompt_len_input_ids, - chosen_tokens, - rejected_prompt_len_input_ids, - rejected_tokens, - ) - - # add EOS token to end of answer. Avoid adding if it's already there - chosen_tokens, rejected_tokens = add_eos_token_if_needed( - self.processing_class.eos_token_id, chosen_tokens, rejected_tokens - ) - - longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) - - # if combined sequence is too long, truncate the prompt - for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - if self.truncation_mode == "keep_start": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] - elif self.truncation_mode == "keep_end": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] - else: - raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") - - # if that's still too long, truncate the response - for answer_tokens in [chosen_tokens, rejected_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - for k in ["input_ids", "attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] - - # Create labels - chosen_sequence_tokens = { - k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] - } - rejected_sequence_tokens = { - k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] - } - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] - chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(chosen_tokens["prompt_input_ids"]) - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] - rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(rejected_tokens["prompt_input_ids"]) - - for k, toks in { - "chosen_": chosen_sequence_tokens, - "rejected_": rejected_sequence_tokens, - "": prompt_tokens, - }.items(): - for type_key, tokens in toks.items(): - if type_key == "token_type_ids": - continue - batch[f"{k}{type_key}"] = tokens - - else: - chosen_tokens = self.processing_class( - chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - rejected_tokens = self.processing_class( - rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - prompt_tokens = self.processing_class( - prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True - ) - - batch["chosen_labels"] = chosen_tokens["input_ids"] - batch["rejected_labels"] = rejected_tokens["input_ids"] - batch["prompt_input_ids"] = prompt_tokens["input_ids"] - batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] - - if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): - batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["rejected_labels"]) - ) - batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["chosen_labels"]) - ) - - return batch - - @staticmethod - def concatenated_inputs( - batch: dict[str, Union[list, torch.LongTensor]], - is_encoder_decoder: bool = False, - label_pad_token_id: int = -100, - padding_value: int = 0, - device: Optional[torch.device] = None, - ) -> dict[str, torch.LongTensor]: - """Concatenate the chosen and rejected inputs into a single tensor. - - Args: - batch: - A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors - of shape (batch_size, sequence_length). - is_encoder_decoder: - Whether the model is an encoder-decoder model. - label_pad_token_id: - The label pad token id. - padding_value: - The padding value to use for the concatenated inputs_ids. - device: - The device for the concatenated inputs. - - Returns: - A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. - """ - concatenated_batch = {} - - if is_encoder_decoder: - max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) - else: - max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - - for k in batch: - if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("chosen", "concatenated") - concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) - for k in batch: - if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("rejected", "concatenated") - concatenated_batch[concatenated_key] = torch.cat( - ( - concatenated_batch[concatenated_key], - pad_to_length(batch[k], max_length, pad_value=pad_value), - ), - dim=0, - ).to(device=device) - - if is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) - concatenated_batch["concatenated_attention_mask"] = ( - batch["prompt_attention_mask"].repeat(2, 1).to(device=device) - ) - - return concatenated_batch - - def cpo_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute the CPO loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: - Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) - policy_rejected_logps: - Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) - - Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the CPO - loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for - the chosen and rejected responses, respectively. - """ - # Apply AlphaPO reward transformation if alpha != 0 - if self.alpha != 0.0: - # Compute probabilities - chosen_probs = torch.exp(policy_chosen_logps) - rejected_probs = torch.exp(policy_rejected_logps) - - # Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha - policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha - policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha - - logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device) - else: - # Standard log probability rewards when alpha = 0 - logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device) - - # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5. - # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and - # calculates a conservative CPO loss. - - if self.loss_type == "simpo": - gamma_logratios = self.simpo_gamma / self.beta - logits = logits - gamma_logratios - # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. - losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * logits) * self.label_smoothing - ) - elif self.loss_type == "sigmoid": - # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. - losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * logits) * self.label_smoothing - ) - elif self.loss_type == "hinge": - losses = torch.relu(1 - self.beta * logits) - elif self.loss_type == "ipo": - # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. - losses = (logits - 1 / (2 * self.beta)) ** 2 - else: - raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']" - ) - - # Calculate rewards for logging - if self.alpha != 0.0: - # When using AlphaPO transformation, use the transformed rewards - chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach() - rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach() - else: - # Standard log probability rewards - chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() - rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() - - return losses, chosen_rewards, rejected_rewards - - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: - Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are - ignored. Shape: (batch_size, sequence_length) - average_log_prob: - If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the - log probabilities of the (non-masked) tokens. - label_pad_token_id: The label pad token id. - is_encoder_decoder: Whether the model is an encoder-decoder model. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the - given logits. - """ - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == label_pad_token_id] = 0 - - per_token_logps = selective_log_softmax(logits, labels) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def concatenated_forward( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. - """ - concatenated_batch = self.concatenated_inputs( - batch, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - padding_value=self.padding_value, - device=self.accelerator.device, - ) - len_chosen = batch["chosen_labels"].shape[0] - - model_kwargs = ( - { - "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), - } - if self.is_encoder_decoder - else {} - ) - - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - **model_kwargs, - ) - all_logits = outputs.logits - - def cross_entropy_loss(logits, labels): - if not self.is_encoder_decoder: - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - labels = concatenated_batch["concatenated_labels"].clone() - - if self.cpo_alpha == 0: - nll_loss = torch.tensor(0.0).to(self.accelerator.device) - else: - nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) - - all_logps = self.get_batch_logps( - all_logits, - concatenated_batch["concatenated_labels"], - average_log_prob=self.loss_type in ["ipo", "simpo"], - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] - - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss) - - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss) - - def get_batch_loss_metrics( - self, - model, - batch: dict[str, Union[list, torch.LongTensor]], - train_eval: Literal["train", "eval"] = "train", - ): - """Compute the CPO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - - forward_output = self.concatenated_forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] - - losses, chosen_rewards, rejected_rewards = self.cpo_loss( - policy_chosen_logps, - policy_rejected_logps, - ) - - loss = losses.mean() + self.cpo_alpha * policy_nll_loss - reward_accuracies = (chosen_rewards > rejected_rewards).float() - - prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() - metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() - metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() - metrics[f"{prefix}rewards/margins"] = ( - self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() - ) - metrics[f"{prefix}logps/rejected"] = ( - self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item() - ) - metrics[f"{prefix}logps/chosen"] = ( - self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item() - ) - metrics[f"{prefix}logits/rejected"] = ( - self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item() - ) - metrics[f"{prefix}logits/chosen"] = ( - self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item() - ) - metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item() - - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss - - return loss, metrics - - def compute_loss( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") - - # force log the metrics - self.store_metrics(metrics, train_eval="train") - - if return_outputs: - return (loss, metrics) - return loss - - def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - return policy_output_decoded - - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ): - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") - - # force log the metrics - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return (loss.detach(), None, None) - - # logits for the chosen and rejected samples from model - logits_dict = { - "eval_logits/chosen": metrics["eval_logits/chosen"], - "eval_logits/rejected": metrics["eval_logits/rejected"], - } - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded = self.generate_from_model(self.model, random_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy"], - data=[ - [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded) - ], - ) - if "wandb" in self.args.report_to: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float`, *optional*): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) - - def _shift_right(self, input_ids): - if self.decoder_start_token_id is None: - raise ValueError( - "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." - ) - - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) - shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = self.decoder_start_token_id - - if self.pad_token_id is None: - raise ValueError("model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) - - return shifted_input_ids - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothCPOTrainer(_UnslothCPOTrainer): - """ - - Initialize CPOTrainer. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. - args ([`CPOConfig`]): - The CPO config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - - """ - def __init__( - self, - model = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - model_init = None, - callbacks = None, - preprocess_logits_for_metrics = None, - peft_config = None, - compute_metrics = None, - **kwargs - ): - if args is None: args = UnslothCPOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('cpo_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - model_init = model_init, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config, - compute_metrics = compute_metrics,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothDPOTrainer.py b/unsloth_compiled_cache/UnslothDPOTrainer.py deleted file mode 100644 index fc4dd5d..0000000 --- a/unsloth_compiled_cache/UnslothDPOTrainer.py +++ /dev/null @@ -1,2898 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.dpo_trainer import (Any, AutoProcessor, BaseImageProcessor, BaseTrainer, Callable, DPOConfig, DPOTrainer, DataCollator, DataCollatorForPreference, DataLoader, Dataset, EvalLoopOutput, F, FDivergenceConstants, FDivergenceType, FeatureExtractionMixin, IterableDataset, Literal, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, Optional, PartialState, Path, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, SyncRefModelCallback, TrainerCallback, Union, autocast, cap_exp, contextmanager, create_model_from_path, create_reference_model, dataclass, defaultdict, disable_dropout_in_model, empty_cache, flush_left, flush_right, get_peft_model, inspect, is_comet_available, is_liger_kernel_available, is_mlflow_available, is_peft_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, nullcontext, pad, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_fsdp, prepare_model_for_kbit_training, random, selective_log_softmax, shift_tokens_right, textwrap, torch, tqdm, warnings, Any, AutoProcessor, BaseImageProcessor, Callable, DPOConfig, DPOTrainer, DataCollator, DataCollatorForPreference, Dataset, EvalLoopOutput, F, FDivergenceConstants, FeatureExtractionMixin, IterableDataset, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RunningMoments, SyncRefModelCallback, TrainerCallback, Union, create_model_from_path, create_reference_model, defaultdict, disable_dropout_in_model, is_comet_available, is_liger_kernel_available, is_mlflow_available, is_peft_available, is_wandb_available, logger, nn, pad, prepare_deepspeed, prepare_fsdp, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothDPOConfig(DPOConfig): - """ - - Configuration class for the [`DPOTrainer`]. - - This class includes only the parameters that are specific to DPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - > Parameters that control the model and reference model - - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the - [`DPOTrainer`] is provided as a string. - ref_model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the - [`DPOTrainer`] is provided as a string. - model_adapter_name (`str`, *optional*): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str`, *optional*): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - force_use_ref_model (`bool`, *optional*, defaults to `False`): - If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set - this flag to `True`. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model and reference model. - use_logits_to_keep (`bool`, *optional*, defaults to `False`): - If `True`, only a specified number of logits are computed in the forward pass. This can be useful for - saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios - when working with very long prompts where labels are ignored (-100). - - > Parameters that control the data preprocessing - - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - pad_token (`str`, *optional*): - Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, - it falls back to `processing_class.eos_token`. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Padding value to use for labels. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. - max_completion_length (`int`, *optional*): - Maximum length of the completion. - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the full sequence (prompt + completion). - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and - `"keep_start"`. - padding_free (`bool`, *optional*, defaults to `False`): - Whether to perform forward passes without padding by flattening all sequences in the batch into a single - continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only - supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened - batch structure. - precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): - Whether to precompute the log probabilities from the reference model. Setting this to `True` allows - training without needing the reference model during training, which can help reduce GPU memory usage. If - set to `False` (default), the reference model will be used during training to compute log probabilities - on-the-fly. - precompute_ref_batch_size (`int`, *optional*): - Batch size to use when precomputing reference model log probabilities. This can be set higher than the - training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for - training and `per_device_eval_batch_size` for evaluation. - tools (`Optional[list[Union[dict, Callable]]]`, *optional*): - List of tools (callable functions) that will be accessible to the model. If the template does not support - function calling, this argument will have no effect. - - > Parameters that control the training - - loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`): - Type of loss to use. Possible values are: - - - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. - - `"hinge"`: hinge loss on the normalized likelihood from the - [SLiC](https://huggingface.co/papers/2305.10425) paper. - - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. - - `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper. - - `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper. - - `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust - DPO](https://huggingface.co/papers/2403.00409) paper. - - `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper. - - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) - paper. - - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. - - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) - paper. - - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the - [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. - - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss). - - Multiple loss types can be combined using comma separation (e.g., `["sigmoid", "bco_pair", "sft"]` for - [MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify - corresponding weights for each loss type. - - use_liger_loss (`bool`, *optional*, defaults to `False`): - Whether to use Liger loss. - base_model_attribute_name (`str`, *optional*, defaults to `"model"`): - Name of the attribute in the model that contains the base model. This is used to get the base model from - the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). - f_divergence_type ([`FDivergenceType`] or `str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`): - Type of f-divergence regularization function to compute divergence between policy and reference model. - f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): - α coefficient in the α-divergence u^-α regularization function for DPO loss. - reference_free (`bool`, *optional*, defaults to `False`): - Whether to ignore the provided reference model and implicitly use a reference model that assigns equal - probability to all responses. - label_smoothing (`float`, *optional*, defaults to `0.0`): - Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and [Robust - DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. - use_weighting (`bool`, *optional*, defaults to `False`): - Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827). - rpo_alpha (`float`, *optional*): - α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the - weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the - DPO loss. The paper recommends `rpo_alpha=1.0`. - ld_alpha (`float`, *optional*): - α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting - of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose - part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between - `0.0` and `1.0`. - discopop_tau (`float`, *optional*, defaults to `0.05`): - τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls - the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. - loss_weights (`list[float]`, *optional*): - List of loss weights for multi-loss combinations. Used when combining multiple loss types. Example: `[0.8, - 0.2, 1.0]` for [MPO](https://huggingface.co/papers/2411.10442). If not provided, defaults to equal weights - (`1.0`) for all loss types. - sync_ref_model (`bool`, *optional*, defaults to `False`): - Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using - the `ref_model_mixup_alpha` parameter. This synchronization originates from the - [TR-DPO](https://huggingface.co/papers/2404.09656) paper. - ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): - α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix - between the current policy and the previous reference policy during updates. The reference policy is - updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you - must set `sync_ref_model=True`. - ref_model_sync_steps (`int`, *optional*, defaults to `512`): - τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how - frequently the current policy is synchronized with the reference policy. To use this parameter, you must - set `sync_ref_model=True`. - - > Parameters that control the logging - - generate_during_eval (`bool`, *optional*, defaults to `False`): - Whether to generate and log completions from both the model and the reference model to W&B or Comet during - evaluation. - - > Deprecated parameters - - padding_value: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `pad_token` (`str`) instead. - - - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - model_init_kwargs = None, - ref_model_init_kwargs = None, - model_adapter_name = None, - ref_adapter_name = None, - force_use_ref_model = False, - disable_dropout = True, - use_logits_to_keep = False, - dataset_num_proc = None, - pad_token = None, - label_pad_token_id = -100, - max_prompt_length = 512, - max_completion_length = None, - max_length = 1024, - truncation_mode = 'keep_end', - padding_free = False, - precompute_ref_log_probs = False, - precompute_ref_batch_size = None, - tools = None, - use_liger_loss = False, - base_model_attribute_name = 'model', - beta = 0.1, - f_alpha_divergence_coef = 1.0, - reference_free = False, - label_smoothing = 0.0, - use_weighting = False, - rpo_alpha = None, - ld_alpha = None, - discopop_tau = 0.05, - loss_weights = None, - sync_ref_model = False, - ref_model_mixup_alpha = 0.6, - ref_model_sync_steps = 512, - generate_during_eval = False, - padding_value = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - model_init_kwargs = model_init_kwargs, - ref_model_init_kwargs = ref_model_init_kwargs, - model_adapter_name = model_adapter_name, - ref_adapter_name = ref_adapter_name, - force_use_ref_model = force_use_ref_model, - disable_dropout = disable_dropout, - use_logits_to_keep = use_logits_to_keep, - dataset_num_proc = dataset_num_proc, - pad_token = pad_token, - label_pad_token_id = label_pad_token_id, - max_prompt_length = max_prompt_length, - max_completion_length = max_completion_length, - max_length = max_length, - truncation_mode = truncation_mode, - padding_free = padding_free, - precompute_ref_log_probs = precompute_ref_log_probs, - precompute_ref_batch_size = precompute_ref_batch_size, - tools = tools, - use_liger_loss = use_liger_loss, - base_model_attribute_name = base_model_attribute_name, - beta = beta, - f_alpha_divergence_coef = f_alpha_divergence_coef, - reference_free = reference_free, - label_smoothing = label_smoothing, - use_weighting = use_weighting, - rpo_alpha = rpo_alpha, - ld_alpha = ld_alpha, - discopop_tau = discopop_tau, - loss_weights = loss_weights, - sync_ref_model = sync_ref_model, - ref_model_mixup_alpha = ref_model_mixup_alpha, - ref_model_sync_steps = ref_model_sync_steps, - generate_during_eval = generate_during_eval, - padding_value = padding_value,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothDPOTrainer(BaseTrainer): - """""" - - _tag_names = ["trl", "dpo"] - _name = "DPO" - _paper = { - "title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model", - "id": "2305.18290", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @inproceedings{rafailov2023direct, - title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, - author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, - year = 2023, - booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, - url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, - editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, - }"""), - } - - def __init__( - self, - model: Union[str, nn.Module, PreTrainedModel], - ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: Optional[DPOConfig] = None, - data_collator: Optional[DataCollator] = None, # type: ignore - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), - optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional["PeftConfig"] = None, - ): - # Args - if args is None: - model_name = model if isinstance(model, str) else model.config._name_or_path - model_name = model_name.split("/")[-1] - args = DPOConfig(f"{model_name}-DPO") - - # Model and reference model - if isinstance(model, str): - model = create_model_from_path(model, **args.model_init_kwargs or {}) - else: - if args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - model_id = model.config._name_or_path - if isinstance(ref_model, str): - ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {}) - else: - if args.ref_model_init_kwargs is not None: - logger.warning( - "You passed `ref_model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. " - "The `ref_model_init_kwargs` will be ignored." - ) - if ref_model is model: - raise ValueError( - "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " - "same as `model`, you can simply omit the `ref_model` argument and it will be created for you." - ) - - # Processing class - if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model_id) - - # Handle pad token for processors or tokenizers - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - self._is_vlm = True - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - self._is_vlm = False - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - # Get the pad token: if not provided, use the one from the processing class or the eos token - # if the processing class does not have a pad token. - if args.padding_value is not None: # deprecated, will be removed in 0.26.0. - warnings.warn( - "The `padding_value` argument is deprecated and will be removed in version 0.26.0. Please use " - "`pad_token` (str) instead." - ) - self.pad_token_id = args.padding_value - else: - pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token - self.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) - if self.pad_token_id is None: - raise ValueError( - f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " - f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " - "in the vocabulary before using it as a padding token." - ) - - # PEFT configuration and model wrapping - model = self._prepare_peft_model(model, ref_model, peft_config, args) - - if args.generate_during_eval and not (is_wandb_available() or is_comet_available() or is_mlflow_available()): - raise ValueError( - "`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." - " Please install `wandb`, `mlflow` or `comet-ml` to resolve." - ) - - self.is_encoder_decoder = model.config.is_encoder_decoder - self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys() - self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) - self.model_adapter_name = args.model_adapter_name - self.ref_adapter_name = args.ref_adapter_name - self.reference_free = args.reference_free - - if ref_model: - self.ref_model = ref_model - elif self.is_peft_model or args.precompute_ref_log_probs: - # The `model` with adapters turned off will be used as the reference model - self.ref_model = None - else: - self.ref_model = create_reference_model(model) - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - - # Liger kernel - if args.use_liger_loss: - if not is_liger_kernel_available(): - raise ImportError( - "You set `use_liger_loss=True` but the liger kernel is not available. " - "Please install liger-kernel first: `pip install liger-kernel`" - ) - if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]: - raise ValueError( - "You set `use_liger_loss=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. " - "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel." - ) - self.dpo_loss_fn = LigerFusedLinearDPOLoss( - ignore_index=args.label_pad_token_id, - beta=args.beta, - use_ref_model=not args.reference_free, - average_log_prob=False, - loss_type=args.loss_type, - ) - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and - # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens - # of the input, floating-point operations will not be computed." To suppress this warning, we set the - # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate - # that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - # Data collator - if data_collator is None: - data_collator = DataCollatorForPreference(pad_token_id=self.pad_token_id) - - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length - self.max_length = args.max_length - self.truncation_mode = args.truncation_mode - self.precompute_ref_log_probs = args.precompute_ref_log_probs - self.use_logits_to_keep = args.use_logits_to_keep - - if args.padding_free: - if model.config._attn_implementation != "flash_attention_2": - logger.warning( - "Padding-free training is enabled, but the attention implementation is not set to " - "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " - "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " - "other implementations may lead to unexpected behavior. To ensure compatibility, set " - "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " - "attention mechanism can handle flattened sequences." - ) - self.padding_free = args.padding_free - - # Since ref_logs are precomputed on the first call to get_train/eval_dataloader - # keep track of first called to avoid computation of future calls - self._precomputed_train_ref_log_probs = False - self._precomputed_eval_ref_log_probs = False - - self.beta = args.beta - self.label_smoothing = args.label_smoothing - self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type] - self.loss_weights = args.loss_weights - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.use_weighting = args.use_weighting - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - logger.warning( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - ) - for loss_type in self.loss_type: - if ( - loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] - and args.label_smoothing > 0 - ): - logger.warning( - f"You are using the {loss_type} loss type that does not support label smoothing. The " - "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this " - "warning.", - ) - if loss_type == "kto_pair": - raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.") - - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - self.f_divergence_type = args.f_divergence_type - self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} - self.dataset_num_proc = args.dataset_num_proc - - # Dataset preparation - train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") - if eval_dataset is not None: - if isinstance(eval_dataset, dict): - eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, key) - for key, dataset in eval_dataset.items() - } - else: - eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - # Deepspeed Zero-3 does not support precompute_ref_log_probs - if self.is_deepspeed_enabled: - if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." - ) - - if self.ref_model is None: - if not (self.is_peft_model or self.precompute_ref_log_probs): - raise ValueError( - "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" - ) - if args.sync_ref_model: - raise ValueError( - "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." - ) - else: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - elif self.is_fsdp_enabled: - self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - if args.sync_ref_model: - if self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." - ) - - self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) - - if "bco_pair" in self.loss_type: - self.running = RunningMoments(self.accelerator) - - @property - def padding_value(self): - warnings.warn( - "The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use " - "`pad_token_id` instead.", - ) - return self.pad_token_id - - @padding_value.setter - def padding_value(self, value): - warnings.warn( - "The `padding_value` property is deprecated and will be removed in version 0.26.0. Please use " - "`pad_token_id` instead.", - ) - self.pad_token_id = value - - def _prepare_peft_model( - self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig - ) -> PreTrainedModel: - """Prepares a model for PEFT training.""" - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if ref_model is not None and not args.force_use_ref_model: - raise ValueError( - "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference" - " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init." - " if you want to use a different ref_model." - ) - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - - else: - model = self._prepare_gradient_checkpointing(model, args) - - # get peft model with the given config - model = get_peft_model(model, peft_config) - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - else: - model = self._prepare_gradient_checkpointing(model, args) - - return model - - def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig): - """Prepare the gradienting checkpointing for the model.""" - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - if args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - return model - - def _prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], - args: DPOConfig, - dataset_name: str, - ) -> Union[Dataset, IterableDataset]: - # Build the kwargs for the `map` function - map_kwargs = {} - if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size - map_kwargs["num_proc"] = args.dataset_num_proc - map_kwargs["writer_batch_size"] = 10 - - with PartialState().main_process_first(): - # Extract prompt if needed - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" - dataset = dataset.map(maybe_extract_prompt, **map_kwargs) - - # Apply the chat template if needed - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs - ) - - # Tokenize the dataset - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - dataset = dataset.map( - self.tokenize_row if not self.is_vision_model else self.process_row, - remove_columns=["chosen", "rejected"], - fn_kwargs={ - "processing_class": processing_class, - "max_prompt_length": args.max_prompt_length, - "max_completion_length": args.max_completion_length, - # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) - "add_special_tokens": False, - }, - **map_kwargs, - ) - - return dataset - - @staticmethod - def tokenize_row( - features: dict[str, str], - processing_class: PreTrainedTokenizerBase, - max_prompt_length: Optional[int] = None, - max_completion_length: Optional[int] = None, - add_special_tokens: bool = True, - ) -> dict[str, list[int]]: - """ - Tokenize a row of the dataset. - - Args: - features (`dict[str, str]`): - Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. - processing_class ([`~transformers.PreTrainedTokenizerBase`]): - Processing class used to process the data. - max_prompt_length (`int` or `None`): - Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. - max_completion_length (`int` or `None`): - Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. - add_special_tokens (`bool`): - Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, - the prompt sequence will have a bos token prepended and an eos token appended. In any case, the - completion sequences will have an eos token appended. - - Returns: - `dict[str, list[int]]`: - Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and - `"rejected_input_ids". - - Example: - ```python - >>> from transformers import GPT2Tokenizer - - >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} - >>> DPOTrainer.tokenize_row( - ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False - ... ) - {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} - ``` - """ - tokenizer = processing_class # the processing class is a tokenizer - prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] - chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] - rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] - - # Add special tokens (typically for encoder-decoder models) - if add_special_tokens: - if tokenizer.bos_token_id is not None: - prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids - if tokenizer.eos_token_id is not None: - prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] - chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] - rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] - - # Truncate prompt and completion sequences - if max_prompt_length is not None: - prompt_input_ids = prompt_input_ids[-max_prompt_length:] - if max_completion_length is not None: - chosen_input_ids = chosen_input_ids[:max_completion_length] - rejected_input_ids = rejected_input_ids[:max_completion_length] - - return { - "prompt_input_ids": prompt_input_ids, - "chosen_input_ids": chosen_input_ids, - "rejected_input_ids": rejected_input_ids, - } - - @staticmethod - def process_row( - features: dict[str, str], - processing_class: PreTrainedTokenizerBase, - max_prompt_length: Optional[int] = None, - max_completion_length: Optional[int] = None, - add_special_tokens: bool = True, - ) -> dict[str, list[int]]: - """ - Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. - """ - processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor - processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) - - prompt_input_ids = processed_features["input_ids"][0] - pixel_values = processed_features["pixel_values"][0] - chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] - rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] - - # Add special tokens (typically for encoder-decoder models) - if add_special_tokens: - if tokenizer.bos_token_id is not None: - prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids - if tokenizer.eos_token_id is not None: - prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] - chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] - rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] - - # Truncate prompt and completion sequences - if max_prompt_length is not None: - prompt_input_ids = prompt_input_ids[-max_prompt_length:] - if max_completion_length is not None: - chosen_input_ids = chosen_input_ids[:max_completion_length] - rejected_input_ids = rejected_input_ids[:max_completion_length] - - output = { - "prompt_input_ids": prompt_input_ids, - "pixel_values": pixel_values, - "chosen_input_ids": chosen_input_ids, - "rejected_input_ids": rejected_input_ids, - } - - if "pixel_attention_mask" in processed_features: - output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] - if "image_sizes" in processed_features: - output["image_sizes"] = processed_features["image_sizes"][0] - if "token_type_ids" in processed_features: - output["token_type_ids"] = processed_features["token_type_ids"][0] - - return output - - def _set_signature_columns_if_needed(self): - # If `self.args.remove_unused_columns` is True, non-signature columns are removed. - # By default, this method sets `self._signature_columns` to the model's expected inputs. - # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. - # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override. - if self._signature_columns is None: - self._signature_columns = [ - "prompt_input_ids", - "chosen_input_ids", - "rejected_input_ids", - "image_sizes", - "token_type_ids", - "ref_chosen_logps", - "ref_rejected_logps", - ] - - def get_train_dataloader(self) -> DataLoader: - """ - Returns the training [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. - """ - - if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: - batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size - dataloader_params = { - "batch_size": batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) - - ref_chosen_logps = [] - ref_rejected_logps = [] - for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) - ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( - (ref_chosen_logp, ref_rejected_logp) - ) - ref_chosen_logps.append(ref_chosen_logp.cpu()) - ref_rejected_logps.append(ref_rejected_logp.cpu()) - - # Unnecessary cache clearing to avoid OOM - empty_cache() - self.accelerator.free_memory() - - all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() - all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() - - self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) - self.train_dataset = self.train_dataset.add_column( - name="ref_rejected_logps", column=all_ref_rejected_logps - ) - - self._precomputed_train_ref_log_probs = True - - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - """ - Returns the evaluation [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. - - Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: - batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size - dataloader_params = { - "batch_size": batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - - ref_chosen_logps = [] - ref_rejected_logps = [] - for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) - ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( - (ref_chosen_logp, ref_rejected_logp) - ) - ref_chosen_logps.append(ref_chosen_logp.cpu()) - ref_rejected_logps.append(ref_rejected_logp.cpu()) - - all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() - all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() - - eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) - eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) - - # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs - if self.eval_dataset is not None: - self.eval_dataset = eval_dataset - self._precomputed_eval_ref_log_probs = True - - return super().get_eval_dataloader(eval_dataset=eval_dataset) - - @contextmanager - def null_ref_context(self): - """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with ( - self.accelerator.unwrap_model(self.model).disable_adapter() - if self.is_peft_model and not self.ref_adapter_name - else nullcontext() - ): - if self.ref_adapter_name: - self.model.set_adapter(self.ref_adapter_name) - yield - if self.ref_adapter_name: - self.model.set_adapter(self.model_adapter_name or "default") - - def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]: - """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" - compte_ref_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with torch.no_grad(), compte_ref_context_manager: - if self.ref_model is None: - with self.null_ref_context(): - ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) - else: - ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) - return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] - - @staticmethod - def concatenated_inputs( - batch: dict[str, Union[list, torch.LongTensor]], padding_value: int - ) -> dict[str, torch.LongTensor]: - """ - Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and - completion sequences. - - Args: - batch (`dict[str, Union[list, torch.LongTensor]]`): - A batch of input data. The batch must contain the following keys: - - - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input - IDs. - - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen - completion input IDs. - - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected - completion input IDs. - - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. - - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. - - padding_value (`int`): - The padding value to use for the concatenated completion sequences (`chosen_input_ids` and - `rejected_input_ids`). - - Returns: - `dict[str, torch.LongTensor]`: A dictionary containing: - - - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. - - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * - batch_size, max_completion_length)`. - - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, - prompt_length)`. - - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * - batch_size, max_completion_length)`. - - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. - - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if - `"prompt_pixel_attention_mask"` are present. - - Notes: - The completion input IDs and attention masks are padded to the maximum completion length of the chosen or - rejected sequences. - """ - output = {} - - # For the prompt, the input_ids are the same for both the chosen and rejected responses - output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) - output["prompt_attention_mask"] = torch.cat( - [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 - ) - if "pixel_values" in batch: - output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) - - if "pixel_attention_mask" in batch: - output["pixel_attention_mask"] = torch.cat( - [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 - ) - if "image_sizes" in batch: - output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) - if "token_type_ids" in batch: - output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"])) - - # Concatenate the chosen and rejected completions - max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - output["completion_input_ids"] = torch.cat( - ( - pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), - pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), - ), - ) - output["completion_attention_mask"] = torch.cat( - ( - pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), - pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), - ), - ) - - return output - - def dpo_loss( - self, - chosen_logps: torch.FloatTensor, - rejected_logps: torch.FloatTensor, - ref_chosen_logps: torch.FloatTensor, - ref_rejected_logps: torch.FloatTensor, - loss_type: str = "sigmoid", - model_output: dict[str, torch.FloatTensor] = None, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """ - Compute the DPO loss for a batch of policy and reference model log probabilities. - - Args: - chosen_logps (`torch.FloatTensor`): - Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. - rejected_logps (`torch.FloatTensor`): - Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. - ref_chosen_logps (`torch.FloatTensor`): - Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. - ref_rejected_logps (`torch.FloatTensor`): - Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. - loss_type (`str`, defaults to `"sigmoid"`): - The type of loss to compute. One of: - - `"sigmoid"`: Sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. - - `"hinge"`: Hinge loss on the normalized likelihood from the - [SLiC](https://huggingface.co/papers/2305.10425) paper. - - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. - - `"exo_pair"`: Pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper. - - `"nca_pair"`: Pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper. - - `"robust"`: Unbiased estimate of the DPO loss that is robust to preference noise from the [Robust - DPO](https://huggingface.co/papers/2403.00409) paper. - - `"bco_pair"`: Pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper. - - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) - paper. - - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. - - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) - paper. - - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the - [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. - - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss). - model_output (`dict[str, torch.FloatTensor]`, *optional*): - The output of the model's forward pass. This is used to compute auxiliary losses if enabled. - - Returns: - A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO - loss for each example in the batch. The `chosen_rewards` and `rejected_rewards` tensors contain the rewards - for the chosen and rejected responses, respectively. - """ - device = self.accelerator.device - - # Get the log ratios for the chosen and rejected responses - chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) - rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) - - if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE: - # The alpha-divergence formula: (1 - u^-alpha) / alpha - # The divergence difference between the chosen and rejected sample is: - # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha - # = (u[l]^-alpha - u[w]^-alpha) / alpha - # where u[w] and u[l] are the policy/reference probability ratios - # for the chosen and rejected samples, respectively. - alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT - if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: - alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) - logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef - else: - logratios = chosen_logps - rejected_logps - if self.reference_free: - ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) - else: - ref_logratios = ref_chosen_logps - ref_rejected_logps - - logratios = logratios.to(self.accelerator.device) - ref_logratios = ref_logratios.to(self.accelerator.device) - logits = logratios - ref_logratios - - if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE: - # The js-divergence formula: log(2 * u / (1 + u)) - # The divergence difference between the chosen and rejected sample is: - # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l])) - # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l])) - # where u[w] and u[l] are the policy/reference probability ratios - # for the chosen and rejected samples, respectively. - logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) - - # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. - # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the - # labels and calculates a conservative DPO loss. - if loss_type == "sigmoid": - losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * logits) * self.label_smoothing - ) - - elif loss_type == "robust": - losses = ( - -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - + F.logsigmoid(-self.beta * logits) * self.label_smoothing - ) / (1 - 2 * self.label_smoothing) - - elif loss_type == "exo_pair": - # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856 - import math - - if self.label_smoothing == 0: - self.label_smoothing = 1e-3 - losses = (self.beta * logits).sigmoid() * ( - F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) - ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) - - elif loss_type == "hinge": - losses = torch.relu(1 - self.beta * logits) - - elif loss_type == "ipo": - # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. - losses = (logits - 1 / (2 * self.beta)) ** 2 - - elif loss_type == "bco_pair": - chosen_logratios = chosen_logps - ref_chosen_logps - rejected_logratios = rejected_logps - ref_rejected_logps - chosen_rewards = self.beta * chosen_logratios - rejected_rewards = self.beta * rejected_logratios - rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() - self.running.update(rewards) - delta = self.running.mean - losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( - -(self.beta * rejected_logratios - delta) - ) - - elif loss_type == "sppo_hard": - # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, - # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. - # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is - # set to 1 for the winner and 0 for the loser. - a = chosen_logps - ref_chosen_logps - b = rejected_logps - ref_rejected_logps - losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 - - elif loss_type == "nca_pair": - chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta - rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta - losses = ( - -F.logsigmoid(chosen_rewards) - - 0.5 * F.logsigmoid(-chosen_rewards) - - 0.5 * F.logsigmoid(-rejected_rewards) - ) - - elif loss_type == "aot_pair": - chosen_logratios = chosen_logps - ref_chosen_logps - rejected_logratios = rejected_logps - ref_rejected_logps - chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) - rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) - delta = chosen_logratios_sorted - rejected_logratios_sorted - losses = ( - -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * delta) * self.label_smoothing - ) - - elif loss_type == "aot": - logratios = chosen_logps - rejected_logps - ref_logratios = ref_chosen_logps - ref_rejected_logps - logratios_sorted, _ = torch.sort(logratios, dim=0) - ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) - delta = logratios_sorted - ref_logratios_sorted - losses = ( - -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) - - F.logsigmoid(-self.beta * delta) * self.label_smoothing - ) - - elif loss_type == "apo_zero": - # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) - # Use this loss when you believe the chosen outputs are better than your model's default output - losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood - losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood - losses = losses_chosen + losses_rejected - - elif loss_type == "apo_down": - # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) - # Use this loss when you believe the chosen outputs are worse than your model's default output. - # Decrease chosen likelihood and decrease rejected likelihood more - losses_chosen = F.sigmoid(self.beta * chosen_logratios) - losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) - losses = losses_chosen + losses_rejected - - elif loss_type == "discopop": - # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) - # This loss was discovered with LLM discovery - logratios = chosen_logps - rejected_logps - ref_logratios = ref_chosen_logps - ref_rejected_logps - logits = logratios - ref_logratios - logits = logits * self.beta - # Modulate the mixing coefficient based on the log ratio magnitudes - log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) - logistic_component = -F.logsigmoid(logits) - exp_component = torch.exp(-logits) - # Blend between logistic and exponential component based on log ratio modulation - losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation - - elif loss_type == "sft": - # SFT loss is the negative log likelihood loss on chosen responses - # This acts as the generation loss component in MPO - sft_loss = model_output["nll_loss"] - # Create losses tensor with same shape as other losses (per-sample) - batch_size = chosen_logps.shape[0] - losses = sft_loss.expand(batch_size) - # For SFT, we don't have preference rewards, so use zeros - chosen_rewards = torch.zeros_like(chosen_logps) - rejected_rewards = torch.zeros_like(rejected_logps) - - else: - raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " - "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', " - "'apo_down', 'sft']" - ) - - chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() - rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() - - return losses, chosen_rewards, rejected_rewards - - def _compute_loss_liger( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] - ) -> dict[str, torch.Tensor]: - unwrapped_model = self.accelerator.unwrap_model(model) - concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id) - - model_kwargs = {} - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - # Add the pixel values and attention masks for vision models - if "pixel_values" in concatenated_batch: - model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] - if "pixel_attention_mask" in concatenated_batch: - model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] - if "image_sizes" in concatenated_batch: - model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] - - prompt_attention_mask = concatenated_batch["prompt_attention_mask"] - completion_attention_mask = concatenated_batch["completion_attention_mask"] - - if self.is_encoder_decoder: - # 1. Get encoder outputs - encoder_outputs = unwrapped_model.get_encoder()( - concatenated_batch["prompt_input_ids"], - attention_mask=concatenated_batch["prompt_attention_mask"], - return_dict=True, - ) - # 2. Prepare decoder inputs - decoder_input_ids = shift_tokens_right( - concatenated_batch["completion_input_ids"], - unwrapped_model.config.decoder_start_token_id, - ) - # 3. Get decoder outputs - decoder_outputs = unwrapped_model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - hidden_states = decoder_outputs.last_hidden_state - - ref_hidden_states = None - if not self.reference_free and self.ref_model is not None: - unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) - ref_encoder_outputs = unwrapped_ref_model.get_encoder()( - concatenated_batch["prompt_input_ids"], - attention_mask=concatenated_batch["prompt_attention_mask"], - return_dict=True, - ) - ref_decoder_outputs = unwrapped_ref_model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - ref_hidden_states = ref_decoder_outputs.last_hidden_state - elif not self.reference_free: - with self.null_ref_context(): - ref_encoder_outputs = unwrapped_model.get_encoder()( - concatenated_batch["prompt_input_ids"], - attention_mask=concatenated_batch["prompt_attention_mask"], - return_dict=True, - ) - ref_decoder_outputs = unwrapped_model.get_decoder()( - input_ids=decoder_input_ids, - attention_mask=concatenated_batch["completion_attention_mask"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - encoder_attention_mask=concatenated_batch["prompt_attention_mask"], - use_cache=False, - ) - ref_hidden_states = ref_decoder_outputs.last_hidden_state - - labels = concatenated_batch["completion_input_ids"] - loss_mask = completion_attention_mask.bool() - else: - # For decoder-only models - input_ids = torch.cat( - (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 - ) - attention_mask = torch.cat( - (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), - dim=1, - ) - # Mask the prompt but not the completion for the loss - loss_mask = torch.cat( - (torch.zeros_like(prompt_attention_mask), completion_attention_mask), - dim=1, - ) - - # Flush and truncate - if self.max_length is not None and self.max_length < attention_mask.size(1): - if self.truncation_mode == "keep_start": - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - attention_mask = attention_mask[:, : self.max_length] - input_ids = input_ids[:, : self.max_length] - loss_mask = loss_mask[:, : self.max_length] - elif self.truncation_mode == "keep_end": - # Flush right before truncating left, then flush left - # [[0, 0, x, x, x, x], -> [[0, 0, x, x], - # [0, x, x, x, 0, 0]] [0, x, x, x]] - attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) - input_ids = input_ids[:, -self.max_length :] - attention_mask = attention_mask[:, -self.max_length :] - loss_mask = loss_mask[:, -self.max_length :] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - else: - raise ValueError( - f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " - "'keep_start']." - ) - else: - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - - # Add logits_to_keep optimization - if self.use_logits_to_keep: - first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 - model_kwargs["logits_to_keep"] = logits_to_keep - - model_kwargs["output_hidden_states"] = True - - # Add padding-free training support - if self.padding_free: - input_ids = input_ids[attention_mask.bool()].unsqueeze(0) - loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) - position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 - model_kwargs["position_ids"] = position_ids - else: - model_kwargs["attention_mask"] = attention_mask - - # Get the base model outputs (before LM head) - if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None: - base_model = unwrapped_model.get_decoder() - else: - base_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name) - base_model = getattr(unwrapped_model, base_attr, unwrapped_model) - - outputs = base_model( - input_ids, - use_cache=False, - **model_kwargs, - ) - hidden_states = outputs.last_hidden_state[:, :-1] - - # Get reference hidden states if needed - ref_hidden_states = None - if not self.reference_free and self.ref_model is not None: - unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) - if hasattr(unwrapped_ref_model, "get_decoder") and unwrapped_ref_model.get_decoder() is not None: - ref_base_model = unwrapped_ref_model.get_decoder() - else: - ref_attr = getattr(unwrapped_ref_model, "base_model_prefix", self.args.base_model_attribute_name) - ref_base_model = getattr(unwrapped_ref_model, ref_attr, unwrapped_ref_model) - - ref_outputs = ref_base_model( - input_ids, - use_cache=False, - **model_kwargs, - ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - elif not self.reference_free: - if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None: - ref_base_model = unwrapped_model.get_decoder() - else: - ref_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name) - ref_base_model = getattr(unwrapped_model, ref_attr, unwrapped_model) - with self.null_ref_context(): - ref_outputs = ref_base_model( - input_ids, - use_cache=False, - **model_kwargs, - ) - ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] - - masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id) - labels = masked_input_ids[:, 1:] # Shift right for casual LM - - # Get the LM head - lm_head = unwrapped_model.get_output_embeddings() - - # Get reference model weights if needed - ref_weight = None - ref_bias = None - if not self.reference_free: - if self.ref_model is not None: - unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) - ref_lm_head = unwrapped_ref_model.get_output_embeddings() - else: - with self.null_ref_context(): - ref_lm_head = unwrapped_model.get_output_embeddings() - ref_weight = ref_lm_head.weight - ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None - - # Compute loss using Liger kernel - loss_output = self.dpo_loss_fn( - lm_head.weight, - hidden_states, - labels, - bias=lm_head.bias if hasattr(lm_head, "bias") else None, - ref_input=ref_hidden_states if not self.reference_free else None, - ref_weight=ref_weight if not self.reference_free else None, - ref_bias=ref_bias if not self.reference_free else None, - ) - ( - loss, - (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), - ) = loss_output - - output = { - "loss": loss, - "chosen_logps": chosen_logps, - "rejected_logps": rejected_logps, - "mean_chosen_logits": chosen_logits_mean, - "mean_rejected_logits": rejected_logits_mean, - "nll_loss": nll_loss, - "chosen_rewards": aux_outputs[0], - "rejected_rewards": aux_outputs[1], - } - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss - - return output - - def concatenated_forward( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False - ) -> dict[str, torch.Tensor]: - """ - Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. - - Args: - model: - Model to run the forward pass on. - batch: - Batch of input data. - is_ref_model: - Whether this method is being called for the reference model. If `True`, length desensitization is not - applied. - """ - num_examples = batch["prompt_input_ids"].shape[0] - - concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id) - - model_kwargs = {"use_cache": False} - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - # Add the pixel values and attention masks for vision models - if "pixel_values" in concatenated_batch: - model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] - if "pixel_attention_mask" in concatenated_batch: - model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] - if "image_sizes" in concatenated_batch: - model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] - - prompt_input_ids = concatenated_batch["prompt_input_ids"] - prompt_attention_mask = concatenated_batch["prompt_attention_mask"] - completion_input_ids = concatenated_batch["completion_input_ids"] - completion_attention_mask = concatenated_batch["completion_attention_mask"] - if self.is_encoder_decoder: - labels = completion_input_ids - labels[completion_attention_mask == 0] = self.label_pad_token_id - outputs = model( - input_ids=prompt_input_ids, - attention_mask=prompt_attention_mask, - labels=labels, # we need the labels for the logits to be returned - **model_kwargs, - ) - logits = outputs.logits - loss_mask = completion_attention_mask.bool() - else: - # Concatenate the prompt and completion inputs - input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) - attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) - if "token_type_ids" in concatenated_batch: - prompt_token_type_ids = concatenated_batch["token_type_ids"] - token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0) - # Mask the prompt but not the completion for the loss - loss_mask = torch.cat( - (torch.zeros_like(prompt_attention_mask), completion_attention_mask), - dim=1, - ) - - # Flush and truncate - if self.max_length is not None and self.max_length < attention_mask.size(1): - if self.truncation_mode == "keep_start": - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - if "token_type_ids" in concatenated_batch: - attention_mask, input_ids, loss_mask, token_type_ids = flush_left( - attention_mask, input_ids, loss_mask, token_type_ids - ) - else: - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - attention_mask = attention_mask[:, : self.max_length] - input_ids = input_ids[:, : self.max_length] - loss_mask = loss_mask[:, : self.max_length] - elif self.truncation_mode == "keep_end": - # Flush right before truncating left, then flush left - # [[0, 0, x, x, x, x], -> [[0, 0, x, x], - # [0, x, x, x, 0, 0]] [0, x, x, x]] - if "token_type_ids" in concatenated_batch: - attention_mask, input_ids, loss_mask, token_type_ids = flush_left( - attention_mask, input_ids, loss_mask, token_type_ids - ) - token_type_ids = token_type_ids[:, -self.max_length :] - else: - attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) - input_ids = input_ids[:, -self.max_length :] - attention_mask = attention_mask[:, -self.max_length :] - loss_mask = loss_mask[:, -self.max_length :] - if "token_type_ids" in concatenated_batch: - attention_mask, input_ids, loss_mask, token_type_ids = flush_left( - attention_mask, input_ids, loss_mask, token_type_ids - ) - else: - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - else: - raise ValueError( - f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " - "'keep_start']." - ) - else: - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - if "token_type_ids" in concatenated_batch: - attention_mask, input_ids, loss_mask, token_type_ids = flush_left( - attention_mask, input_ids, loss_mask, token_type_ids - ) - else: - attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) - - if "token_type_ids" in concatenated_batch: - model_kwargs["token_type_ids"] = token_type_ids - - if self.use_logits_to_keep: - # Compute logits_to_keep based on loss_mask pattern: - # [[0, 0, 0, x, x, x, x], - # [0, 0, 0, x, x, x, 0]] - # ^ start computing logits from here ([:, -(7-3+1):]) - first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() - logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label - model_kwargs["logits_to_keep"] = logits_to_keep - - model_kwargs["output_hidden_states"] = True - - if self.padding_free: - # Flatten the input_ids, position_ids, and loss_mask - # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] - # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] - input_ids = input_ids[attention_mask.bool()].unsqueeze(0) - loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) - position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 - model_kwargs["position_ids"] = position_ids - else: - model_kwargs["attention_mask"] = attention_mask - - outputs = model(input_ids, **model_kwargs) - logits = outputs.logits - - # Offset the logits by one to align with the labels - labels = torch.roll(input_ids, shifts=-1, dims=1) - loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() - - if self.use_logits_to_keep: - # Align labels with logits - # logits: -, -, [x2, x3, x4, x5, x6] - # ^ --------- ^ after logits[:, :-1, :] - # labels: [y0, y1, y2, y3, y4, y5, y6] - # ^ --------- ^ with logits_to_keep=4, [:, -4:] - # loss_mask: [0, 0, 0, 1, 1, 1, 1] - labels = labels[:, -logits_to_keep:] - loss_mask = loss_mask[:, -logits_to_keep:] - - if logits.shape[:2] != labels.shape[:2]: - # for LLaVA, the returned logits include the image tokens (placed before the text tokens) - seq_len = labels.shape[1] - logits = logits[:, -seq_len:] - - # Compute the log probabilities of the labels - labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = selective_log_softmax(logits, labels) - per_token_logps[~loss_mask] = 0 - per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) - - if self.padding_free: - # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) - batch_size, seq_len = attention_mask.shape - per_token_logps_ = torch.zeros( - batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype - ) - per_token_logps_[attention_mask.bool()] = per_token_logps - per_token_logps = per_token_logps_ - - all_logps = per_token_logps[:, 1:].sum(-1) - - output = {} - - if self.use_weighting: - with torch.no_grad(): - # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 - logprobs = F.log_softmax(logits, dim=-1) - weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space - per_token_logps_adjusted = per_token_logps - weights_adjustment_factor - all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) - chosen_weights = all_weights[:num_examples] - rejected_weights = all_weights[num_examples:] - output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) - - if self.args.rpo_alpha is not None or "sft" in self.loss_type: - # Only use the chosen logits for the RPO loss or SFT loss - chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] - chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples] - - # Compute the log probabilities of the labels - output["nll_loss"] = F.cross_entropy( - torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 - ) - - if "ipo" in self.loss_type: - all_logps = all_logps / loss_mask.sum(-1) - - if self.args.ld_alpha is not None and not is_ref_model: - # Compute response lengths based on loss_mask - completion_lengths = loss_mask.sum(dim=1) - - chosen_lengths = completion_lengths[:num_examples] - rejected_lengths = completion_lengths[num_examples:] - public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper - public_lengths = torch.cat([public_lengths, public_lengths], dim=0) - - seq_len = per_token_logps.size(1) - position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) - - ld_mask = position_ids < public_lengths.unsqueeze(1) - mask = position_ids < completion_lengths.unsqueeze(1) - - front_mask = (ld_mask & mask).float() - rear_mask = (~ld_mask & mask).float() - front_logps = (per_token_logps * front_mask).sum(dim=1) - rear_logps = (per_token_logps * rear_mask).sum(dim=1) - - all_logps = front_logps + self.args.ld_alpha * rear_logps - - output["chosen_logps"] = all_logps[:num_examples] - output["rejected_logps"] = all_logps[num_examples:] - - # Compute the mean logits - if self.padding_free: - # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). - # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, - # and the second half to the rejected tokens. - # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. - split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] - mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() - mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() - else: - mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() - mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() - - output["mean_chosen_logits"] = mean_chosen_logits - output["mean_rejected_logits"] = mean_rejected_logits - - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss - - return output - - def get_batch_loss_metrics( - self, - model: Union[PreTrainedModel, nn.Module], - batch: dict[str, Union[list, torch.LongTensor]], - train_eval: Literal["train", "eval"] = "train", - ) -> tuple[torch.Tensor, dict[str, float]]: - """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - - if self.args.use_liger_loss: - model_output = self._compute_loss_liger(model, batch) - losses = model_output["loss"] - chosen_rewards = model_output["chosen_rewards"] - rejected_rewards = model_output["rejected_rewards"] - else: - model_output = self.concatenated_forward(model, batch) - - # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model - if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: - ref_chosen_logps = batch["ref_chosen_logps"] - ref_rejected_logps = batch["ref_rejected_logps"] - else: - ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) - - # Initialize combined losses - losses = 0 - chosen_rewards = 0 - rejected_rewards = 0 - - # Compute losses for each loss type - for idx, loss_type in enumerate(self.loss_type): - # Compute individual loss using standard DPO loss function - _losses, _chosen_rewards, _rejected_rewards = self.dpo_loss( - model_output["chosen_logps"], - model_output["rejected_logps"], - ref_chosen_logps, - ref_rejected_logps, - loss_type, - model_output, - ) - - # Add weighted contributions - weight = self.loss_weights[idx] if self.loss_weights else 1.0 - losses = losses + _losses * weight - chosen_rewards = chosen_rewards + _chosen_rewards * weight - rejected_rewards = rejected_rewards + _rejected_rewards * weight - - reward_accuracies = (chosen_rewards > rejected_rewards).float() - - if self.args.rpo_alpha is not None: - losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper - - if self.use_weighting: - losses = losses * model_output["policy_weights"] - - if self.aux_loss_enabled: - losses = losses + self.aux_loss_coef * model_output["aux_loss"] - - prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() - metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() - metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() - metrics[f"{prefix}rewards/margins"] = ( - self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() - ) - metrics[f"{prefix}logps/chosen"] = ( - self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() - ) - metrics[f"{prefix}logps/rejected"] = ( - self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() - ) - metrics[f"{prefix}logits/chosen"] = ( - self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() - ) - metrics[f"{prefix}logits/rejected"] = ( - self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() - ) - if self.args.rpo_alpha is not None or "sft" in self.loss_type: - metrics[f"{prefix}nll_loss"] = ( - self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() - ) - if self.aux_loss_enabled: - metrics[f"{prefix}aux_loss"] = ( - self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() - ) - - return losses.mean(), metrics - - def compute_loss( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") - - # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: - loss = loss.to(self.args.device) - # force log the metrics - self.store_metrics(metrics, train_eval="train") - - if return_outputs: - return loss, metrics - - return loss - - def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.pad_token_id, - ) - - # if ref_output in batch use that otherwise use the reference model - if "ref_output" in batch: - ref_output = batch["ref_output"] - else: - if self.ref_model is None: - with self.null_ref_context(): - ref_output = self.model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.pad_token_id, - ) - else: - ref_output = self.ref_model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.pad_token_id, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.pad_token_id) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - ref_output = pad_to_length(ref_output, self.max_length, self.pad_token_id) - ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) - - return policy_output_decoded, ref_output_decoded - - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") - - # force log the metrics - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return loss.detach(), None, None - - # logits for the chosen and rejected samples from model - logits_dict = { - "eval_logits/chosen": metrics["eval_logits/chosen"], - "eval_logits/rejected": metrics["eval_logits/rejected"], - } - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip( - random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded - ) - ], - ) - if "wandb" in self.args.report_to and self.accelerator.is_main_process: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - if "mlflow" in self.args.report_to and self.accelerator.is_main_process: - mlflow.log_table(data=table, artifact_file="game_log.json") - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float`, *optional*): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothDPOTrainer(_UnslothDPOTrainer): - """ - - Trainer for Direct Preference Optimization (DPO) method. - - This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. - - Args: - model (`Union[str, PreTrainedModel]`): - Model to be trained. Can be either: - - - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in - `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. - ref_model ([`PreTrainedModelWrapper`]): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation - and loss. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized. - args ([`DPOConfig`], *optional*): - Configuration for this trainer. If `None`, a default configuration is used. - data_collator ([`~transformers.DataCollator`], *optional*): - Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. - Will default to [`DataCollatorForPreference`]. - train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. DPO supports [preference](#preference) type and. The format of the samples can - be either: - - - [Standard](dataset_formats#standard): Each sample contains plain text. - - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role - and content). - eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): - Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If `None`, the processing class is loaded from the model's name - with [`~transformers.AutoTokenizer.from_pretrained`]. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return - a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to - `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered - after the last eval batch to signal that the function needs to calculate and return the global summary - statistics rather than accumulating the batch-level statistics. - callbacks (list of [`~transformers.TrainerCallback`], *optional*): - List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed - in [here](https://huggingface.co/docs/transformers/main_classes/callback). - - If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] - method. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): - A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your - model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. - optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): - A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in - `args`. Incompatible with the `optimizers` argument. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): - A function that preprocess the logits right before caching them at each evaluation step. Must take two - tensors, the logits and the labels, and return the logits once processed as desired. The modifications made - by this function will be reflected in the predictions received by `compute_metrics`. - - Note that the labels (second parameter) will be `None` if the dataset does not have them. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration used to wrap the model. If `None`, the model is not wrapped. - - """ - def __init__( - self, - model, - ref_model = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - compute_metrics = None, - callbacks = None, - optimizer_cls_and_kwargs = None, - preprocess_logits_for_metrics = None, - peft_config = None, - **kwargs - ): - if args is None: args = UnslothDPOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('dpo_trainer', other_metrics) - if hasattr(train_dataset, 'column_names'): - column_names = set(train_dataset.column_names) - check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask', - 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels', - 'prompt_input_ids', 'prompt_attention_mask'] - if all(x in column_names for x in check): - train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt']) - del check, column_names - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - ref_model = ref_model, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - compute_metrics = compute_metrics, - callbacks = callbacks, - optimizer_cls_and_kwargs = optimizer_cls_and_kwargs, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothGKDTrainer.py b/unsloth_compiled_cache/UnslothGKDTrainer.py deleted file mode 100644 index 07d22ac..0000000 --- a/unsloth_compiled_cache/UnslothGKDTrainer.py +++ /dev/null @@ -1,1311 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, nn, os, prepare_deepspeed, torch, warnings) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothGKDConfig(GKDConfig): - """ - - Configuration class for [`GKDTrainer`]. - - This class includes only the parameters that are specific to GKD training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation. - - Args: - temperature (`float`, *optional*, defaults to `0.9`): - Temperature for sampling. The higher the temperature, the more random the completions. - lmbda (`float`, *optional*, defaults to `0.5`): - Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy - student-generated outputs). - beta (`float`, *optional*, defaults to `0.5`): - Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When - beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence. - max_new_tokens (`int`, *optional*, defaults to `128`): - Maximum number of tokens to generate per completion. - teacher_model_name_or_path (`str`, *optional*): - Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being - trained. - teacher_model_init_kwargs (`dict[str, Any]]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model - from a string. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. - seq_kd (`bool`, *optional*, defaults to `False`): - Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on - teacher-generated output). - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - model_init_kwargs = None, - chat_template_path = None, - dataset_text_field = 'text', - dataset_kwargs = None, - dataset_num_proc = None, - eos_token = None, - pad_token = None, - max_length = 1024, - packing = False, - packing_strategy = 'bfd', - padding_free = False, - pad_to_multiple_of = None, - eval_packing = None, - completion_only_loss = None, - assistant_only_loss = False, - loss_type = 'nll', - activation_offloading = False, - temperature = 0.9, - lmbda = 0.5, - beta = 0.5, - max_new_tokens = 128, - teacher_model_name_or_path = None, - teacher_model_init_kwargs = None, - disable_dropout = True, - seq_kd = False, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1': - from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION - if HAS_FLEX_ATTENTION and pad_to_multiple_of is None: - from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE - pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE - - if temperature <= 0: - raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') - elif temperature >= 10: - raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - model_init_kwargs = model_init_kwargs, - chat_template_path = chat_template_path, - dataset_text_field = dataset_text_field, - dataset_kwargs = dataset_kwargs, - dataset_num_proc = dataset_num_proc, - eos_token = eos_token, - pad_token = pad_token, - max_length = max_length, - packing = packing, - packing_strategy = packing_strategy, - padding_free = padding_free, - pad_to_multiple_of = pad_to_multiple_of, - eval_packing = eval_packing, - completion_only_loss = completion_only_loss, - assistant_only_loss = assistant_only_loss, - loss_type = loss_type, - activation_offloading = activation_offloading, - temperature = temperature, - lmbda = lmbda, - beta = beta, - max_new_tokens = max_new_tokens, - teacher_model_name_or_path = teacher_model_name_or_path, - teacher_model_init_kwargs = teacher_model_init_kwargs, - disable_dropout = disable_dropout, - seq_kd = seq_kd,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothGKDTrainer(SFTTrainer): - """""" - - _tag_names = ["trl", "gkd"] - _name = "GKD" - _paper = { - "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", - "id": "2306.13649", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @inproceedings{agarwal2024on-policy, - title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}}, - author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem}, - year = 2024, - booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, - publisher = {OpenReview.net}, - url = {https://openreview.net/forum?id=3zKtaqxLhW}, - }"""), - } - - def __init__( - self, - model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - teacher_model: Union[PreTrainedModel, nn.Module, str] = None, - args: Optional[GKDConfig] = None, - data_collator: Optional[DataCollator] = None, # type: ignore - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional["PeftConfig"] = None, - formatting_func: Optional[Callable] = None, - ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - # Ensure Trainer does not drop non-signature columns used by the collator [e.g., "prompts"] - args.remove_unused_columns = False - # Respect a user-provided data_collator; otherwise, provide a ChatML collator that - if data_collator is None: - data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) - - # Ensure SFTTrainer does not pre-process the dataset when using a ChatML collator, - # so that raw conversational fields [e.g., "messages"] remain available to the collator. - if args.dataset_kwargs is None: - args.dataset_kwargs = {"skip_prepare_dataset": True} - else: - args.dataset_kwargs["skip_prepare_dataset"] = True - - # Liger fused GKD loss [JSD] - self.use_liger_gkd_loss = False - if args.use_liger_kernel: - self.liger_jsd_loss = LigerFusedLinearJSDLoss( - beta=args.beta, - ignore_index=-100, - temperature=args.temperature, - compiled=False, - ) - self.use_liger_gkd_loss = True - - super().__init__( - model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - peft_config=peft_config, - formatting_func=formatting_func, - ) - - if args.teacher_model_init_kwargs is None: - teacher_model_init_kwargs = {} - elif not isinstance(teacher_model, str): - raise ValueError( - "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated." - ) - else: - teacher_model_init_kwargs = args.teacher_model_init_kwargs - teacher_model_init_kwargs["dtype"] = ( - teacher_model_init_kwargs["dtype"] - if teacher_model_init_kwargs["dtype"] in ["auto", None] - else getattr(torch, teacher_model_init_kwargs["dtype"]) - ) - - if isinstance(teacher_model, str): - teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs) - - # Disable dropout in the model - if args.disable_dropout: - disable_dropout_in_model(self.model) - - if self.is_deepspeed_enabled: - self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) - else: - self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) - - self.lmbda = args.lmbda - self.beta = args.beta - self.temperature = args.temperature - self.seq_kd = args.seq_kd - - self.generation_config = GenerationConfig( - max_new_tokens=args.max_new_tokens, - temperature=args.temperature, - do_sample=True, - top_k=0, - use_cache=False if args.gradient_checkpointing else True, - pad_token_id=self.processing_class.pad_token_id, - ) - # Set custom EOS tokens if they are specified by the model's generation - # config. This is important for models with the Llama 3 chat template, - # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of - # turns or messages. - if ( - hasattr(self.model.generation_config, "eos_token_id") - and self.model.generation_config.eos_token_id is not None - ): - self.generation_config.eos_token_id = self.model.generation_config.eos_token_id - - @staticmethod - def generalized_jsd_loss( - student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean" - ): - """ - Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1) - of https://huggingface.co/papers/2306.13649 for the definition. - - Args: - student_logits: - Tensor of shape (batch_size, sequence_length, vocab_size) - teacher_logits: - Tensor of shape (batch_size, sequence_length, vocab_size) - labels: - Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing - loss - beta: - Interpolation coefficient between 0 and 1 (default: 0.5) - temperature: - Softmax temperature (default: 1.0) - reduction: - Specifies the reduction to apply to the output (default: 'batchmean') - - Returns: - loss: Scalar tensor with the generalized JSD loss - """ - - # Apply temperature scaling - student_logits = student_logits / temperature - teacher_logits = teacher_logits / temperature - - # Compute log probabilities for student and probabilities for teacher - student_log_probs = F.log_softmax(student_logits, dim=-1) - teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) - - if beta == 0: - jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) - elif beta == 1: - jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) - else: - # Compute the log of the mixture distribution - # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture - beta = torch.tensor(beta, dtype=student_log_probs.dtype) - mixture_log_probs = torch.logsumexp( - torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]), - dim=0, - ) - - # Compute KL divergences using F.kl_div - # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. - kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) - kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) - - # Compute the Generalized Jensen-Shannon Divergence - jsd = beta * kl_teacher + (1 - beta) * kl_student - - # Masking - if labels is not None: - mask = labels != -100 - jsd = jsd[mask] - - # Apply reduction - if reduction == "batchmean": - return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0) - elif reduction == "sum": - return jsd.sum() - elif reduction == "mean": - return jsd.mean() - else: - return jsd - - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - if self.use_liger_gkd_loss: - # Forward only through the base models (avoid lm_head to save memory) - unwrapped_student = self.accelerator.unwrap_model(model) - if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None: - base_student = unwrapped_student.get_decoder() - else: - base_student = getattr( - unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student - ) - - student_outputs = base_student( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - output_hidden_states=True, - use_cache=False, - ) - - self.teacher_model.eval() - unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) - if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None: - base_teacher = unwrapped_teacher.get_decoder() - else: - base_teacher = getattr( - unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher - ) - with torch.no_grad(): - teacher_outputs = base_teacher( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - output_hidden_states=True, - use_cache=False, - ) - - # hidden states (shifted) - student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous() - teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous() - - # labels mask and labels (shifted) - labels_mask = inputs["labels"] != -100 - masked_input_ids = torch.where( - labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) - ) - true_labels = masked_input_ids[:, 1:].contiguous() - - # heads - student_head = unwrapped_student.get_output_embeddings() - teacher_head = unwrapped_teacher.get_output_embeddings() - - # liger fused jsd loss - loss = self.liger_jsd_loss( - student_input=student_hidden, - student_weight=student_head.weight, - teacher_input=teacher_hidden, - teacher_weight=teacher_head.weight, - true_labels=true_labels, - student_bias=getattr(student_head, "bias", None), - teacher_bias=getattr(teacher_head, "bias", None), - ) - else: - # compute student output - student_outputs = model( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - ) - - # compute teacher output in eval mode - self.teacher_model.eval() - with torch.no_grad(): - teacher_outputs = self.teacher_model( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - ) - - # slice the logits for the generated tokens using the inputs["prompts"] lengths - prompt_lengths = inputs["prompts"].shape[1] - shifted_student_logits = student_outputs.logits[:, prompt_lengths - 1 : -1, :] - shifted_teacher_logits = teacher_outputs.logits[:, prompt_lengths - 1 : -1, :] - shifted_labels = inputs["labels"][:, prompt_lengths:] - - # compute loss - loss = self.generalized_jsd_loss( - student_logits=shifted_student_logits, - teacher_logits=shifted_teacher_logits, - labels=shifted_labels, - beta=self.beta, - ) - - # empty cache - empty_cache() - - # Return loss - return (loss, student_outputs) if return_outputs else loss - - @staticmethod - def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None): - # Generate output with respect to the prompt-only - generated_outputs = model.generate( - input_ids=inputs["prompts"], - attention_mask=inputs.get("prompt_attention_mask", None), - generation_config=generation_config, - return_dict_in_generate=True, - ) - - # Get the generated token IDs - generated_tokens = generated_outputs.sequences - # Calculate new attention mask - new_attention_mask = torch.ones_like(generated_tokens) - new_labels = generated_tokens.clone() - - # If there's pad_token_id, set attention mask to 0 for padding tokens - if pad_token_id is not None: - new_labels[new_labels == pad_token_id] = -100 - new_attention_mask[generated_tokens == pad_token_id] = 0 - - return generated_tokens, new_attention_mask, new_labels - - def training_step( - self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None - ) -> torch.Tensor: - """ - Perform a training step for the Generalized Knowledge Distillation (GKD) model. - - This method implements the on-policy learning approach described in the GKD paper. With probability - `self.lmbda`, it generates new responses using the student model, which are then used for training instead of - the original inputs. - """ - if self.seq_kd: - with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model: - new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( - unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id - ) - inputs["input_ids"] = new_input_ids - inputs["attention_mask"] = new_attention_mask - inputs["labels"] = new_labels - if random.random() <= self.lmbda: - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( - unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id - ) - inputs["input_ids"] = new_input_ids - inputs["attention_mask"] = new_attention_mask - inputs["labels"] = new_labels - - loss = super().training_step(model, inputs, num_items_in_batch) - return loss -class UnslothGKDTrainer(_UnslothGKDTrainer): - """ - Trainer for Generalized Knowledge Distillation (GKD) of language models. - - For details on GKD, see the paper: [On-Policy Distillation of Language Models: Learning from Self-Generated - Mistakes](https://huggingface.co/papers/2306.13649). - - Args: - model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*): - Model to be trained, or the string identifier of the model to be instantiated from a pretrained model. - teacher_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*): - Teacher model for knowledge distillation, or the string identifier of the model to be instantiated from a - pretrained model. - args ([`GKDConfig`], *optional*): - Training arguments. - data_collator ([`~transformers.DataCollator`], *optional*): - Data collator to batch samples from the dataset. It defaults to a [`DataCollatorForChatML`] using the - `processing_class`. - train_dataset ([`~datasets.Dataset`], *optional*): - Dataset for training. - eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*): - Dataset for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Class to process the data. - compute_metrics (`Callable`, *optional*): - Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a - dictionary string to float. - callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): - Callbacks to use during training. - optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): - Tuple containing the optimizer and the learning rate scheduler to use for training. - preprocess_logits_for_metrics (`Callable`, *optional*): - Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and - return the logits to be used for metrics computation. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be - wrapped with the specified PEFT adapter. - formatting_func (`Callable`, *optional*): - Function to format the dataset. Must take in an example and return an example. - - """ - def __init__( - self, - model = None, - teacher_model = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - compute_metrics = None, - callbacks = None, - preprocess_logits_for_metrics = None, - peft_config = None, - formatting_func = None, - **kwargs - ): - if args is None: args = UnslothGKDConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('gkd_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - teacher_model = teacher_model, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - compute_metrics = compute_metrics, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config, - formatting_func = formatting_func,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass diff --git a/unsloth_compiled_cache/UnslothGRPOTrainer.py b/unsloth_compiled_cache/UnslothGRPOTrainer.py deleted file mode 100644 index 380d394..0000000 --- a/unsloth_compiled_cache/UnslothGRPOTrainer.py +++ /dev/null @@ -1,4196 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.grpo_trainer import (Any, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, BaseTrainer, DataLoader, Dataset, FSDP, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RepeatSampler, RewardFunc, Sampler, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, _ForwardRedirection, apply_chat_template, broadcast_object_list, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, gather, gather_object, identity, inspect, is_conversational, is_datasets_available, is_flash_attn_2_available, is_liger_kernel_available, is_peft_model, is_rich_available, is_vllm_available, logger, logging, maybe_apply_chat_template, nanmax, nanmin, nanstd, nn, nullcontext, os, pad, partial, prepare_deepspeed, prepare_fsdp, prepare_multimodal_messages, print_prompt_completions_sample, profiling_context, profiling_decorator, seed_worker, selective_log_softmax, set_seed, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, textwrap, torch, transformers, unsplit_pixel_values_by_grid, unwrap_model_for_generation, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardFunc, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, identity, inspect, is_liger_kernel_available, is_peft_model, is_vllm_available, logger, nn, os, pad, prepare_deepspeed, prepare_fsdp, set_seed, torch, transformers, Any, Union, gather, gather_object, is_conversational, logging, nanmax, nanmin, nanstd, os, pad, torch, FSDP, Optional, apply_chat_template, broadcast_object_list, gather, gather_object, is_flash_attn_2_available, maybe_apply_chat_template, nullcontext, os, pad, prepare_multimodal_messages, profiling_context, torch, transformers, unwrap_model_for_generation, os, pad, selective_log_softmax, torch, transformers, Any, Union, profiling_decorator, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, torch, unsplit_pixel_values_by_grid, Optional, PreTrainedModel, logger, os, torch, FSDP, nn, os, FSDP, nn, torch, GRPOTrainer, gather, nanmax, nanmin, os, pad, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.enable_persistent_tma_matmul": torch.cuda.get_device_capability()[0] >= 9, - "cuda.cutlass_epilogue_fusion_enabled": torch.cuda.get_device_capability()[0] >= 9, - "cuda.cutlass_tma_only": torch.cuda.get_device_capability()[0] >= 9, - "cuda.compile_opt_level" : "-O2", - "cuda.enable_cuda_lto" : True, - } - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -def grpo_compute_loss( - ref, - new, - old, - sampling_per_token_logps, - input_ids, - mask, - beta, - advantages, - **kwargs -): - # All Unsloth Zoo code licensed under AGPL3 - # Set defaults for optional arguments - loss_type = kwargs.get("loss_type", "grpo") - epsilon_low = kwargs.get("epsilon_low", 0.2) - epsilon_high = kwargs.get("epsilon_high", 0.2) - max_completion_length = kwargs.get("max_completion_length", 8192) - delta = kwargs.get("delta", None) - importance_sampling_level = kwargs.get("importance_sampling_level", "token") - num_items_in_batch = kwargs.get("num_items_in_batch", None) - current_gradient_accumulation_steps = kwargs.get("current_gradient_accumulation_steps", 1) - num_processes = kwargs.get("num_processes", 1) - use_vllm = kwargs.get("use_vllm", False) - vllm_importance_sampling_cap = kwargs.get("vllm_importance_sampling_cap", 2.0) - get_sapo_token_loss = kwargs.get("get_sapo_token_loss", None) - sapo_temperature_pos = kwargs.get("sapo_temperature_pos", 1.0) - sapo_temperature_neg = kwargs.get("sapo_temperature_neg", 1.05) - get_off_policy_mask = kwargs.get("get_off_policy_mask", None) - off_policy_mask_threshold = kwargs.get("off_policy_mask_threshold", None) - input_ids = input_ids.unsqueeze(-1) - - if advantages.dim() == 1: - advantages = advantages.unsqueeze(1) - - if off_policy_mask_threshold is not None: - off_policy_mask = get_off_policy_mask( - advantages=advantages, - per_token_logps=new, - old_per_token_logps=old, - mask=mask, - off_policy_threshold=off_policy_mask_threshold, - ) - - with torch.no_grad(): - if use_vllm and sampling_per_token_logps is not None: - #must filter out extra prompt tokens in begining after making input_ids left padded - importance_sampling_ratio = torch.exp((old * mask) - sampling_per_token_logps) - importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=vllm_importance_sampling_cap - ) - pass - - # Must detach - otherwise gradients are not propagated correctly! - # exp(x - x) == 1 - # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - if old is not None: - log_ratio = new - old - else: - log_ratio = new - new.detach() - - if importance_sampling_level == "token": - log_importance_weights = log_ratio - elif importance_sampling_level == "sequence": - log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) - log_importance_weights = log_importance_weights.unsqueeze(-1) - else: - raise ValueError( - f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' " - "and 'sequence'." - ) - - coef_1 = torch.exp(log_importance_weights) - - # Reverse KL - # Note that this is a low variance low bias estimator for the KL divergence as used in GRPO paper - if beta != 0.0: - kl_i = torch.exp(ref - new) - (ref - new) - 1.0 - - else: - # set kl_i to a tensor of zeros with the correct shape - if importance_sampling_level == "sequence": - kl_i = new.new_zeros(new.size(0), 1) - else: - kl_i = torch.zeros_like(new) - # Full correct reverse KL divergence?? Missing term maybe? - # kl_i = torch.exp(new) * kl_i - - # Below is forward KL (normal KL) - # kl_i = torch.exp(old) * (old - new) - if loss_type == "cispo": - clamped_ratios = torch.clamp(coef_1, max=epsilon_high).detach() - loss_i = -clamped_ratios * advantages * new - #breakpoint() - elif loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: - coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high) - - if delta is not None: - loss_1 = torch.clamp(coef_1, max=delta) * advantages - else: - loss_1 = coef_1 * advantages - pass - loss_2 = coef_2 * advantages - loss_i = -torch.min(loss_1, loss_2) - elif loss_type == "sapo": - if get_sapo_token_loss is None: - raise Exception(f"sapo is only available in TRL 0.26.0+") - loss_i = torch.empty_like(coef_1) - positive_advantages_mask = advantages.repeat([1, coef_1.shape[1]]) > 0 - #since we have n_chunks some tensors may error if they dont have elements in them - if coef_1[positive_advantages_mask].numel() != 0: - loss_i[positive_advantages_mask] = get_sapo_token_loss( - coef_1[positive_advantages_mask], sapo_temperature_pos - ) - if coef_1[~positive_advantages_mask].numel() != 0: - loss_i[~positive_advantages_mask] = get_sapo_token_loss( - coef_1[~positive_advantages_mask], sapo_temperature_neg - ) - loss_i = -loss_i * advantages - else: - raise ValueError(f"Unknown loss type: {loss_type}") - - if off_policy_mask_threshold is not None: - loss_i = loss_i * off_policy_mask - - if use_vllm and sampling_per_token_logps is not None: - loss_i = loss_i * importance_sampling_ratio - #delta for metric - with torch.no_grad(): - delta = torch.abs(old - sampling_per_token_logps) - delta = delta * mask - flat_is_ratio = importance_sampling_ratio * mask - else: - delta = torch.tensor([]).detach() - flat_is_ratio = torch.tensor([]).detach() - if beta != 0.0: - loss_i = loss_i + beta * kl_i - - mask = mask.to(torch.float32) - n_mask_per_reward = mask.sum(1) - - # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624 - if loss_type in ["grpo", "sapo"]: - loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() - loss = loss / current_gradient_accumulation_steps - elif loss_type == "bnpo": - loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0) - loss = loss / current_gradient_accumulation_steps - elif loss_type == "dr_grpo": - loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length) - loss = loss / current_gradient_accumulation_steps - elif loss_type in ["cispo", "dapo"]: - normalizer = num_items_in_batch/ num_processes - loss = (loss_i * mask).sum() / normalizer - else: - raise ValueError(f"Unknown loss type: {loss_type}") - - # loss = (loss_i * mask).sum() / mask.sum() - - # Get metrics as well which are folded - def masked_batch_mean(x): - with torch.inference_mode(): - completion_length = n_mask_per_reward.mean() - if x.shape[1] == 1: # when importance_sampling_level == "sequence" - return completion_length, x.mean() - else: - mean_kl_per_reward = (x * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.mean() - return completion_length, mean_kl - completion_length, mean_kl = masked_batch_mean(kl_i) - return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 - -class UnslothEfficientGRPO(torch.autograd.Function): - # All Unsloth Zoo code licensed under AGPL3 - @staticmethod - def forward(ctx, _new_logps, _old_logps, _ref_logps, _sampling_per_token_logps, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, extra_kwargs=None): - if extra_kwargs is None: - extra_kwargs = {} - def compute_loss(new_logps, old_logps, ref_logps, sampling_per_token_logps, input_ids, mask, advantages, scaling): - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss( - ref_logps, - new_logps, - old_logps, - sampling_per_token_logps, - input_ids, - mask, - beta, - advantages, - **extra_kwargs, - ) - - # Scale loss if needed for mixed precision training - scaled_loss = loss * scaling - # Must add .loss.detach otherwise autograd uses 2x VRAM - return scaled_loss, (loss.detach(), completion_length, mean_kl, delta, flat_is_ratio, coef_1) - pass - - device =_new_logps.device - grad_inputs = torch.empty_like(_new_logps) - accumulated_loss = torch.zeros(1, device = device) - accumulated_completion_length = torch.zeros(1, device = device) - accumulated_mean_kl = torch.zeros(1, device = device) - accumulated_delta = [] - accumulated_flat_is_ratio = [] - accumulated_coef_1 = [] - - def accumulate_chunk( - new_logps_j, - old_logps_j, - ref_logps_j, - sampling_per_token_logps_j, - input_ids_j, - mask_j, - advantages_j, - scaling, - grad_inputs_j, - ): - (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl, chunk_delta, chunk_flat_is_ratio, chunk_coef_1)) = torch.func.grad_and_value( - compute_loss, - argnums = (0,), - has_aux = True, - )(new_logps_j, old_logps_j, ref_logps_j, sampling_per_token_logps_j, input_ids_j, mask_j, advantages_j, scaling) - accumulated_loss .add_(unscaled_loss) - accumulated_completion_length.add_(chunk_completion_length) - accumulated_mean_kl .add_(chunk_mean_kl) - accumulated_delta .append(chunk_delta) - accumulated_flat_is_ratio .append(chunk_flat_is_ratio) - accumulated_coef_1 .append(chunk_coef_1) - grad_inputs_j[:] = chunk_grad_input - pass - - accumulate_chunk = torch.compile( - accumulate_chunk, - fullgraph = True, - # [TODO] Dynamic marking causes torch.compile errors if sequence length is long - dynamic = True, - options = torch_compile_options, - ) - - grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0) - new_logps = torch.chunk(_new_logps, chunks = n_chunks, dim = 0) - if _old_logps is not None: - old_logps = torch.chunk(_old_logps, chunks = n_chunks, dim = 0) - else: - old_logps = [None] * n_chunks - if _ref_logps is not None: - ref_logps = torch.chunk(_ref_logps, chunks = n_chunks, dim = 0) - else: - ref_logps = [None] * n_chunks - if _sampling_per_token_logps is not None: - sampling_per_token_logps = torch.chunk(_sampling_per_token_logps, chunks = n_chunks, dim = 0) - else: - sampling_per_token_logps = [None] * n_chunks - input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0) - mask = torch.chunk(_mask, chunks = n_chunks, dim = 0) - advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0) - - # Get mixed precision scaling if seen - scaling = scaler.get_scale() if scaler is not None else 1.0 - - # Force torch.compile to use dynamic shapes for seqlen dim - # mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1) - - for (grad_inputs_j, new_logps_j, old_logps_j, ref_logps_j, sampling_per_token_logps_j, input_ids_j, mask_j, advantages_j, ) in\ - zip(grad_inputs_chunks, new_logps, old_logps, ref_logps, sampling_per_token_logps, input_ids, mask, advantages): - - # [TODO] Dynamic marking causes torch.compile errors if sequence length is long - - # mark_dynamic(new_hidden_states_j) - # mark_dynamic(ref_hidden_states_j) - # if old_hidden_states_j is not None: - # mark_dynamic(old_hidden_states_j) - # mark_dynamic(input_ids_j) - # mark_dynamic(mask_j) - accumulate_chunk( - new_logps_j, - old_logps_j, - ref_logps_j, - sampling_per_token_logps_j, - input_ids_j, - mask_j, - advantages_j, - scaling, - grad_inputs_j, - ) - pass - - grad_inputs .div_(n_chunks) - accumulated_loss .div_(n_chunks) - accumulated_completion_length.div_(n_chunks) - accumulated_mean_kl .div_(n_chunks) - - if _sampling_per_token_logps is not None: - accumulated_delta = torch.cat(accumulated_delta, dim=0) - accumulated_flat_is_ratio = torch.cat(accumulated_flat_is_ratio, dim=0) - else: - accumulated_delta = None - accumulated_flat_is_ratio = None - accumulated_coef_1 = torch.cat(accumulated_coef_1, dim=0) - ctx.save_for_backward(grad_inputs) - return ( - accumulated_loss, - accumulated_completion_length, - accumulated_mean_kl, - accumulated_delta, - accumulated_flat_is_ratio, - accumulated_coef_1 - ) - pass - - @staticmethod - def backward(ctx, grad_output, dcompletion_length, dmean_kl, ddelta, ddflat_is_ratio, dcoef_1): - (grad_input,) = ctx.saved_tensors - return (grad_input, None, None, None, None, None, None, None, None, None, None, None) - pass - -def grpo_accumulated_loss( - trainer, - input_ids, - attention_mask, - logits_to_keep, - completion_mask, - advantages, - old_logps, - ref_logps, - n_chunks = -1, - **kwargs, -): - # All Unsloth Zoo code licensed under AGPL3 - bsz, qlen = input_ids.shape - - pixel_values = kwargs.get('pixel_values',None) - image_grid_thw = kwargs.get('image_grid_thw',None) - pixel_attention_mask = kwargs.get('pixel_attention_mask',None) - image_sizes = kwargs.get('image_sizes',None) - sampling_per_token_logps = kwargs.get("sampling_per_token_logps", None) if getattr(trainer, "vllm_importance_sampling_correction", False) else None - temperature = kwargs.get("temperature", 1.0) - logit_scale_multiply = kwargs.get("logit_scale_multiply", 0.0) - logit_scale_divide = kwargs.get("logit_scale_divide", 0.0) - logit_softcapping = kwargs.get("logit_softcapping", 0.0) - prev_max_left_pad = kwargs.get("max_left_pad", 0) #Always get max_left_pad for when training LLMs, enabled by deafult. - - #Delete this from kwargs so less issues - _ = kwargs.pop("sampling_per_token_logps", None) - kwargs["vllm_importance_sampling_cap"] = trainer.vllm_importance_sampling_cap if sampling_per_token_logps is not None else None - kwargs["get_sapo_token_loss"] = trainer.get_sapo_token_loss if hasattr(trainer, "get_sapo_token_loss") else None - kwargs["sapo_temperature_pos"] = trainer.args.sapo_temperature_pos if hasattr(trainer.args, "sapo_temperature_pos") else None - kwargs["sapo_temperature_neg"] = trainer.args.sapo_temperature_neg if hasattr(trainer.args, "sapo_temperature_neg") else None - kwargs["get_off_policy_mask"] = trainer.get_off_policy_mask if hasattr(trainer, "get_off_policy_mask") else None - kwargs["off_policy_mask_threshold"] = trainer.args.off_policy_mask_threshold if hasattr(trainer.args, "off_policy_mask_threshold") else None - kwargs["use_vllm"] = trainer.use_vllm - # Find closest multiple - factors = [i for i in range(1, bsz + 1) if bsz % i == 0] - if n_chunks == -1: n_chunks = bsz - n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)] - - if not hasattr(trainer, '_autocast_dtype'): - trainer._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 - if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': trainer._autocast_dtype = None - pass - os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" - - lm_head = trainer.model.get_output_embeddings().weight - dtype_bytes = 16 if trainer._autocast_dtype in [torch.float16, torch.bfloat16] else 32 - - total_rows = input_ids.shape[0] - seq_len = input_ids.shape[1] - hidden_dim = lm_head.shape[1] - vocab_dim = lm_head.shape[0] - - if trainer.args.unsloth_grpo_mini_batch is None: - if not hasattr(trainer, "_has_autotuned"): - trainer._has_autotuned = True - B, multiplier = autotune_batch_and_chunks( - total_rows, seq_len, hidden_dim, vocab_dim, dtype_bytes, trainer.args.unsloth_logit_chunk_multiplier - ) - trainer.args.unsloth_grpo_mini_batch = total_rows//B - trainer.args.unsloth_logit_chunk_multiplier = multiplier - B = trainer.args.unsloth_grpo_mini_batch - multiplier = trainer.args.unsloth_logit_chunk_multiplier - elif trainer._step % trainer.current_gradient_accumulation_steps == 0: - B = trainer.args.unsloth_grpo_mini_batch - multiplier = trainer.args.unsloth_logit_chunk_multiplier - del trainer._has_autotuned - del trainer.args.unsloth_grpo_mini_batch - del trainer.args.unsloth_logit_chunk_multiplier - else: - B = trainer.unsloth_grpo_mini_batch - multiplier = trainer.args.unsloth_logit_chunk_multiplier - else: - if trainer.args.unsloth_grpo_mini_batch > total_rows: - B = total_rows - else: - B = trainer.args.unsloth_grpo_mini_batch - - if trainer.args.unsloth_logit_chunk_multiplier is None: - multiplier = max(4, seq_len // 4096) - else: - multiplier = trainer.args.unsloth_logit_chunk_multiplier - - if pixel_values is None: - left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(input_ids, logits_to_keep, trainer.processing_class.pad_token_id) - - # Determine max_left_pad from precomputed logprobs shape for consistency - if old_logps is not None: - max_left_pad = old_logps.shape[1] - logits_to_keep - elif ref_logps is not None: - max_left_pad = ref_logps.shape[1] - logits_to_keep - else: - max_left_pad = torch.max(left_pad_tokens_per_prompt).item() - - input_ids = left_pack_padding(input_ids, trainer.processing_class.pad_token_id) - - completion_input_ids = input_ids[:, -(logits_to_keep +max_left_pad):] - - completion_mask = create_completion_attention_mask(completion_input_ids, left_pad_tokens_per_prompt, max_left_pad, trainer.processing_class.pad_token_id).to(attention_mask.dtype) - - if trainer.use_vllm and sampling_per_token_logps is not None and getattr(trainer, "vllm_importance_sampling_correction", False): - sampling_per_token_logps = align_logprobs_with_mask(sampling_per_token_logps, completion_mask) - else: - sampling_per_token_logps = None - attention_mask = input_ids != trainer.processing_class.pad_token_id - attention_mask = attention_mask.to(attention_mask.dtype) - else: - completion_input_ids = input_ids[:, -logits_to_keep:] - - unwrapped_model = trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False) - - for module in unwrapped_model.modules(): - if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_decice"): - module._hf_hook.io_same_decice = False - pass - - all_logprobs_list = [] - - attention_mask_chunks = torch.chunk(attention_mask, chunks=B, dim=0) - completion_ids_chunks = torch.chunk(completion_input_ids, chunks=B, dim=0) - - def chunk_optional(tensor, chunks): - if tensor is None: - return [None] * chunks - return torch.chunk(tensor, chunks=chunks, dim=0) - - import math - total_samples = input_ids.shape[0] - batch_size = math.ceil(total_samples / B) - - input_ids_chunks = [] - attention_mask_chunks = [] - pixel_values_chunks = [] - image_grid_thw_chunks = [] - pixel_attention_mask_chunks = [] - - current_pixel_idx = 0 - #TRL 0.23.0 batching logic - for start in range(0, total_samples, batch_size): - end = start + batch_size - - input_ids_chunks.append(input_ids[start:end]) - attention_mask_chunks.append(attention_mask[start:end]) - - if image_grid_thw is not None and pixel_values is not None: - - grid_slice = image_grid_thw[start:end] - image_grid_thw_chunks.append(grid_slice) - batch_pixel_count = grid_slice.prod(dim=-1).sum().item() - - start_pixel_idx = current_pixel_idx - end_pixel_idx = current_pixel_idx + batch_pixel_count - - pixel_values_chunks.append(pixel_values[start_pixel_idx:end_pixel_idx]) - - if pixel_attention_mask is not None: - pixel_attention_mask_chunks.append( - pixel_attention_mask[start_pixel_idx:end_pixel_idx] - ) - else: - pixel_attention_mask_chunks.append(None) - - current_pixel_idx = end_pixel_idx - - else: - pixel_values_chunks.append(None) - image_grid_thw_chunks.append(None) - pixel_attention_mask_chunks.append(None) - - if image_sizes is not None and not isinstance(image_sizes, torch.Tensor): - image_sizes_chunks = [[size] for size in image_sizes] - else: - image_sizes_chunks = chunk_optional(image_sizes, B) - - zipped_inputs = zip( - input_ids_chunks, - attention_mask_chunks, - pixel_values_chunks, - image_grid_thw_chunks, - pixel_attention_mask_chunks, - image_sizes_chunks, - completion_ids_chunks - ) - - if trainer._autocast_dtype is None: - autocaster = nullcontext() - else: - autocaster = torch.amp.autocast(device_type = trainer.model.device.type, dtype = trainer._autocast_dtype) - - def to_device(tensor, device, non_blocking=True): - if tensor is None: return None - return tensor.to(device, non_blocking=non_blocking) - - class Unsloth_Offloaded_Log_Softmax(torch.autograd.Function): - """ - Manual Gradient Checkpointing/CPU Offloading for Log Softmax. - """ - @staticmethod - def forward(ctx, hidden_states, lm_head, index, chunks, - logit_scale_multiply, logit_scale_divide, - logit_softcapping, temperature): - - ctx.saved_hidden_states = to_device(hidden_states, "cpu", non_blocking=True) - ctx.device = hidden_states.device - ctx.dtype = hidden_states.dtype - - ctx.lm_head = lm_head - ctx.lm_head_requires_grad = lm_head.requires_grad - ctx.index = index - ctx.args = (chunks, logit_scale_multiply, logit_scale_divide, logit_softcapping, temperature) - - with torch.no_grad(): - output = chunked_hidden_states_selective_log_softmax( - hidden_states, lm_head, index, *ctx.args - ) - - return output - - @staticmethod - def backward(ctx, grad_output): - hidden_states = to_device(ctx.saved_hidden_states, ctx.device) - hidden_states = hidden_states.to(ctx.dtype) - hidden_states.requires_grad_(True) - - lm_head = ctx.lm_head - # #Possibly redundant lines - # if ctx.lm_head_requires_grad: - # hidden_states.requires_grad_(True) - # else: - # lm_head = lm_head.detach() - - index = ctx.index - - with torch.enable_grad(): - output = chunked_hidden_states_selective_log_softmax( - hidden_states, lm_head, index, *ctx.args - ) - - torch.autograd.backward(output, grad_output) - - return ( - hidden_states.grad, - lm_head.grad if ctx.lm_head_requires_grad else None, - None, - None, - None, - None, - None, - None, - ) - - def efficient_log_softmax(hidden_states, lm_head, index, chunks=32, - logit_scale_multiply=0.0, logit_scale_divide=0.0, - logit_softcapping=0.0, temperature=1, batch_size=8): - if (index.shape[1] <= 1024 and batch_size <= 8) or batch_size==1: - #We save a gigabyte or speed with the normal path under these specific conditions - return chunked_hidden_states_selective_log_softmax( - hidden_states, - lm_head, - index, - chunks, - logit_scale_multiply, - logit_scale_divide, - logit_softcapping, - temperature - ) - else: - return Unsloth_Offloaded_Log_Softmax.apply( - hidden_states, lm_head, index, chunks, - logit_scale_multiply, logit_scale_divide, - logit_softcapping, temperature - ) - for ( - input_ids_chunk, - attention_mask_chunk, - pixel_values_chunk, - image_grid_thw_chunk, - pixel_attention_mask_chunk, - image_sizes_chunk, - completion_ids - ) in zipped_inputs: - with autocaster: - if pixel_values is None: - new_hidden_states_chunk = unwrapped_model( - input_ids = input_ids_chunk, - attention_mask = attention_mask_chunk, - pixel_values = pixel_values_chunk, - image_grid_thw = image_grid_thw_chunk, - pixel_attention_mask = pixel_attention_mask_chunk, - image_sizes = image_sizes_chunk, - ).logits - - new_hidden_states_chunk = new_hidden_states_chunk[:, -(logits_to_keep + max_left_pad + 1): , :] - new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :] - else: - new_hidden_states_chunk = unwrapped_model( - input_ids = input_ids_chunk, - attention_mask = attention_mask_chunk, - pixel_values = pixel_values_chunk, - image_grid_thw = image_grid_thw_chunk, - pixel_attention_mask = pixel_attention_mask_chunk, - image_sizes = image_sizes_chunk, - logits_to_keep = logits_to_keep + 1, - ).logits - - new_hidden_states_chunk = new_hidden_states_chunk[:, :-1, :] - - logprobs_chunk = efficient_log_softmax( - new_hidden_states_chunk, - lm_head, - completion_ids, - chunks=input_ids_chunk.shape[0]*multiplier, - logit_scale_multiply=logit_scale_multiply, - logit_scale_divide=logit_scale_divide, - logit_softcapping=logit_softcapping, - temperature=temperature, - batch_size = B - ) - #This is needed to avoid race conditions with GPT OSS offload_embbed=True - #However, it seems that this line does not slow down or disrupt models. - device_synchronize() - all_logprobs_list.append(logprobs_chunk) - - new_logprobs = torch.cat(all_logprobs_list, dim=0) - - with autocaster: - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = UnslothEfficientGRPO.apply( - new_logprobs, - old_logps, - ref_logps, - sampling_per_token_logps, - lm_head, - completion_input_ids, - completion_mask, - advantages, - trainer.beta, - trainer.accelerator.scaler, - 1, - kwargs - ) - - # Must force not returning hidden states but logits otherwise gibberish - os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" - - return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 - # Old non efficient code path - new_logits = torch.matmul(new_hidden_states, lm_head.t()) - new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - old_logits = torch.matmul(old_hidden_states, lm_head.t()) - old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - loss, completion_length, mean_kl = grpo_compute_loss( - old_logits, - new_logits, - completion_input_ids, - completion_mask, - trainer.beta, - advantages, - ) - return loss, completion_length, mean_kl - pass - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options) -def grpo_compute_loss_slow( - ref, - new, - old, - sampling_per_token_logps, - input_ids, - mask, - beta, - advantages, - **kwargs -): - # All Unsloth Zoo code licensed under AGPL3 - # Set defaults for optional arguments - loss_type = kwargs.get("loss_type", "grpo") - epsilon_low = kwargs.get("epsilon_low", 0.2) - epsilon_high = kwargs.get("epsilon_high", 0.2) - max_completion_length = kwargs.get("max_completion_length", 8192) - delta = kwargs.get("delta", None) - importance_sampling_level = kwargs.get("importance_sampling_level", "token") - num_items_in_batch = kwargs.get("num_items_in_batch", None) - current_gradient_accumulation_steps = kwargs.get("current_gradient_accumulation_steps", 1) - num_processes = kwargs.get("num_processes", 1) - use_vllm = kwargs.get("use_vllm", False) - vllm_importance_sampling_cap = kwargs.get("vllm_importance_sampling_cap", 2.0) - get_sapo_token_loss = kwargs.get("get_sapo_token_loss", None) - sapo_temperature_pos = kwargs.get("sapo_temperature_pos", 1.0) - sapo_temperature_neg = kwargs.get("sapo_temperature_neg", 1.05) - get_off_policy_mask = kwargs.get("get_off_policy_mask", None) - off_policy_mask_threshold = kwargs.get("off_policy_mask_threshold", None) - input_ids = input_ids.unsqueeze(-1) - - if advantages.dim() == 1: - advantages = advantages.unsqueeze(1) - - if off_policy_mask_threshold is not None: - off_policy_mask = get_off_policy_mask( - advantages=advantages, - per_token_logps=new, - old_per_token_logps=old, - mask=mask, - off_policy_threshold=off_policy_mask_threshold, - ) - - with torch.no_grad(): - if use_vllm and sampling_per_token_logps is not None: - #must filter out extra prompt tokens in begining after making input_ids left padded - importance_sampling_ratio = torch.exp((old * mask) - sampling_per_token_logps) - importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=vllm_importance_sampling_cap - ) - pass - - # Must detach - otherwise gradients are not propagated correctly! - # exp(x - x) == 1 - # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - if old is not None: - log_ratio = new - old - else: - log_ratio = new - new.detach() - - if importance_sampling_level == "token": - log_importance_weights = log_ratio - elif importance_sampling_level == "sequence": - log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) - log_importance_weights = log_importance_weights.unsqueeze(-1) - else: - raise ValueError( - f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' " - "and 'sequence'." - ) - - coef_1 = torch.exp(log_importance_weights) - - # Reverse KL - # Note that this is a low variance low bias estimator for the KL divergence as used in GRPO paper - if beta != 0.0: - kl_i = torch.exp(ref - new) - (ref - new) - 1.0 - - else: - # set kl_i to a tensor of zeros with the correct shape - if importance_sampling_level == "sequence": - kl_i = new.new_zeros(new.size(0), 1) - else: - kl_i = torch.zeros_like(new) - # Full correct reverse KL divergence?? Missing term maybe? - # kl_i = torch.exp(new) * kl_i - - # Below is forward KL (normal KL) - # kl_i = torch.exp(old) * (old - new) - if loss_type == "cispo": - clamped_ratios = torch.clamp(coef_1, max=epsilon_high).detach() - loss_i = -clamped_ratios * advantages * new - #breakpoint() - elif loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: - coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high) - - if delta is not None: - loss_1 = torch.clamp(coef_1, max=delta) * advantages - else: - loss_1 = coef_1 * advantages - pass - loss_2 = coef_2 * advantages - loss_i = -torch.min(loss_1, loss_2) - elif loss_type == "sapo": - if get_sapo_token_loss is None: - raise Exception(f"sapo is only available in TRL 0.26.0+") - loss_i = torch.empty_like(coef_1) - positive_advantages_mask = advantages.repeat([1, coef_1.shape[1]]) > 0 - #since we have n_chunks some tensors may error if they dont have elements in them - if coef_1[positive_advantages_mask].numel() != 0: - loss_i[positive_advantages_mask] = get_sapo_token_loss( - coef_1[positive_advantages_mask], sapo_temperature_pos - ) - if coef_1[~positive_advantages_mask].numel() != 0: - loss_i[~positive_advantages_mask] = get_sapo_token_loss( - coef_1[~positive_advantages_mask], sapo_temperature_neg - ) - loss_i = -loss_i * advantages - else: - raise ValueError(f"Unknown loss type: {loss_type}") - - if off_policy_mask_threshold is not None: - loss_i = loss_i * off_policy_mask - - if use_vllm and sampling_per_token_logps is not None: - loss_i = loss_i * importance_sampling_ratio - #delta for metric - with torch.no_grad(): - delta = torch.abs(old - sampling_per_token_logps) - delta = delta * mask - flat_is_ratio = importance_sampling_ratio * mask - else: - delta = torch.tensor([]).detach() - flat_is_ratio = torch.tensor([]).detach() - if beta != 0.0: - loss_i = loss_i + beta * kl_i - - mask = mask.to(torch.float32) - n_mask_per_reward = mask.sum(1) - - # https://github.com/huggingface/trl/blob/e8b8499f1f8d76838155b515e414ee98f757d6d5/trl/trainer/grpo_trainer.py#L1624 - if loss_type in ["grpo", "sapo"]: - loss = ((loss_i * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() - loss = loss / current_gradient_accumulation_steps - elif loss_type == "bnpo": - loss = (loss_i * mask).sum() / mask.sum().clamp(min=1.0) - loss = loss / current_gradient_accumulation_steps - elif loss_type == "dr_grpo": - loss = (loss_i * mask).sum() / (loss_i.size(0) * max_completion_length) - loss = loss / current_gradient_accumulation_steps - elif loss_type in ["cispo", "dapo"]: - normalizer = num_items_in_batch/ num_processes - loss = (loss_i * mask).sum() / normalizer - else: - raise ValueError(f"Unknown loss type: {loss_type}") - - # loss = (loss_i * mask).sum() / mask.sum() - - # Get metrics as well which are folded - def masked_batch_mean(x): - with torch.inference_mode(): - completion_length = n_mask_per_reward.mean() - if x.shape[1] == 1: # when importance_sampling_level == "sequence" - return completion_length, x.mean() - else: - mean_kl_per_reward = (x * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.mean() - return completion_length, mean_kl - completion_length, mean_kl = masked_batch_mean(kl_i) - return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 - -def grpo_update_SamplingParams(SamplingParams, generation_kwargs, vllm_sampling_params = None): - good_sampling_params_keys = inspect.signature(SamplingParams).parameters.keys() - - # Filter generation_kwargs - new_generation_kwargs = {} - for key in generation_kwargs.keys(): - if key in good_sampling_params_keys: - new_generation_kwargs[key] = generation_kwargs[key] - generation_kwargs = new_generation_kwargs - - if vllm_sampling_params is not None: - for key in good_sampling_params_keys: - if hasattr(vllm_sampling_params, key): - overwrited_key = getattr(vllm_sampling_params, key) - if overwrited_key is not None and (type(overwrited_key) in (list, tuple,) and len(overwrited_key) != 0): - generation_kwargs[key] = overwrited_key - return generation_kwargs - -def _get_inference_mode_context_manager(model: torch.nn.Module): - """ - If the state dict was quantized using torchao, we will run into - the following error when calling ops like aten.t() in inference mode. - This is a bug in PyTorch that affects all tensor subclasses. - - Cannot set version_counter for inference tensor - - For now, we work around this issue by using `torch.no_grad()` in this case. - See https://github.com/pytorch/pytorch/issues/164872 for more details. - Otherwise, just return `torch.inference_mode()`. - """ - torchao_config = getattr(model, "torchao_config", None) - if torchao_config is not None and torchao_config.qat_scheme is None: - return torch.no_grad() - else: - return torch.inference_mode() - -def vLLMSamplingParams(**kwargs): - from vllm import SamplingParams - - sampling_params = SamplingParams(**kwargs) - sampling_params._set_kwargs = kwargs - return sampling_params -@dataclass -class UnslothGRPOConfig(GRPOConfig): - """ - - Configuration class for the [`GRPOTrainer`]. - - This class includes only the parameters that are specific to GRPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - > Parameters that control the model and reference model - - model_init_kwargs (`str`, `dict[str, Any]`, *optional*): - Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` - argument of the [`GRPOTrainer`] is provided as a string. - disable_dropout (`bool`, *optional*, defaults to `False`): - Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents - the model from generating different logprobs for the same input. - - > Parameters that control the data preprocessing - - remove_unused_columns (`bool`, *optional*, defaults to `False`): - Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that - requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. - num_generations (`int` or `None`, *optional*, defaults to `8`): - Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size - * gradient_accumulation_steps) must be evenly divisible by this value. - max_completion_length (`int` or `None`, *optional*, defaults to `256`): - Maximum length of the generated completion. - ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): - This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, - improving generation speed. However, disabling this option allows training models that exceed the VRAM - capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible - with vLLM generation. - shuffle_dataset (`bool`, *optional*, defaults to `True`): - Whether to shuffle the training dataset. - - > Parameters that control generation - - generation_batch_size: (`int`, *optional*): - Batch size to use for generation. If `None`, it defaults to the effective training batch size: - `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one - generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`. - steps_per_generation: (`int`, *optional*): - Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive - with `generation_batch_size`. - temperature (`float`, defaults to `1.0`): - Temperature for sampling. The higher the temperature, the more random the completions. - top_p (`float`, *optional*, defaults to `1.0`): - Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to - `1.0` to consider all tokens. - top_k (`int`, *optional*): - Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is - disabled and all tokens are considered. - min_p (`float`, *optional*): - Minimum token probability, which will be scaled by the probability of the most likely token. It must be a - value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. - repetition_penalty (`float`, *optional*, defaults to `1.0`): - Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. - Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat - tokens. - use_transformers_paged (`bool`, *optional*, defaults to `False`): - Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers` - paged implementation will be used for generation instead of the default padded implementation. This - parameter is only effective when `use_vllm` is set to `False`. - cache_implementation (`str`, *optional*): - Implementation of the cache method for faster generation when `use_vllm` is set to `False`. - generation_kwargs (`dict[str, Any]`, *optional*): - Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or - `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the - generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict - with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. - - > Parameters that control generation acceleration powered by vLLM - - use_vllm (`bool`, *optional*, defaults to `False`): - Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation - instead of the default model.generate(). Requires `vllm` to be installed. - vllm_mode (`str`, *optional*, defaults to `"server"`): - Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or - `"colocate"`. - - - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM - server is running (start with `trl vllm-serve`). - - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a - separate server but may cause resource contention with training. - vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): - Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use - the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model - implementation. - vllm_guided_decoding_regex (`str`, *optional*): - Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. - - > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) - - vllm_server_base_url (`str`, *optional*): - Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and - `vllm_server_port` are ignored. - vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): - Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. - vllm_server_port (`int`, *optional*, defaults to `8000`): - Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. - vllm_server_timeout (`float`, *optional*, defaults to `240.0`): - Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the - timeout, a `ConnectionError` is raised. - - > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) - - vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): - Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to - `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when - launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. - vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): - Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to - `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when - launching the vLLM server via the `--vllm_tensor_parallel_size` flag. - vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): - Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step and woken - for weight sync and generation. - - > Parameters that control the training - - beta (`float`, *optional*, defaults to `0.0`): - KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and improving - training speed. - num_iterations (`int`, *optional*, defaults to `1`): - Number of iterations per batch (denoted as μ in the algorithm). - epsilon (`float`, *optional*, defaults to `0.2`): - Epsilon value for clipping. - delta (`float`, *optional*): - Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` (default), standard - GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This method is introduced in - the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). - epsilon_high (`float`, *optional*): - Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound - specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. - importance_sampling_level (`str`, *optional*, defaults to `"token"`): - Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` - keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the - log-probability ratios across valid tokens to produce a single ratio per sequence. The [GSPO - paper](https://huggingface.co/papers/2507.18071) shows that sequence-level sampling often yields more - stable training and better alignment with sequence-level rewards. - reward_weights (`list[float]`, *optional*): - Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are - weighted equally with weight `1.0`. - scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): - Specifies the scaling strategy for rewards. Supported values are: - - - `True` or `"group"` (default): rewards are scaled by the standard deviation within each group, ensuring - unit variance within a group. - - `"batch"`: rewards are scaled by the standard deviation across the entire batch, as recommended in the - [PPO Lite paper](https://huggingface.co/papers/2508.08221). - - `False` or `"none"`: no scaling is applied. The [Dr. GRPO - paper](https://huggingface.co/papers/2503.20783) recommends not scaling rewards, as scaling by the - standard deviation introduces a question-level difficulty bias. - loss_type (`str`, *optional*, defaults to `"dapo"`): - Specifies the loss formulation to use. Supported values are: - - - `"grpo"`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to - length bias—this approach tends to prefer shorter completions with positive advantages and longer ones - with negative advantages. - - `"dr_grpo"`: Aggregates token-level losses by normalizing with a global constant. This method was - introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) to eliminate length bias. - The value of the constant corresponds to `max_completion_length`. - - `"dapo"` (default): Aggregates token-level losses by normalizing with the number of active token in the - global accumulated batch. This method was introduced in the [DAPO - paper](https://huggingface.co/papers/2503.14476) to eliminate length bias. - - `"bnpo"`: Aggregates token-level losses by normalizing with the number of active token in the local - batch. Note that normalization is performed over the local batch only, so results may slightly vary - depending on the local batch size, despite a constant effective batch size. When using - `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. - mask_truncated_completions (`bool`, *optional*, defaults to `False`): - When enabled, truncated completions are excluded from the loss calculation, preventing them from being - incorrectly penalized and introducing noise during training. According to the - [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability. - sync_ref_model (`bool`, *optional*, defaults to `False`): - Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using - the `ref_model_mixup_alpha` parameter. This synchronization originates from the - [TR-DPO](https://huggingface.co/papers/2404.09656) paper. - ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): - α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix - between the current policy and the previous reference policy during updates. The reference policy is - updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you - must set `sync_ref_model=True`. - ref_model_sync_steps (`int`, *optional*, defaults to `512`): - τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how - frequently the current policy is synchronized with the reference policy. To use this parameter, you must - set `sync_ref_model=True`. - top_entropy_quantile (`float`, *optional*, defaults to `1.0`): - ρ parameter from [Beyond the 80/20 Rule](https://huggingface.co/papers/2506.01939). Keeps in the policy - loss term only the top-ρ quantile of tokens by entropy of the probability distribution at each sequence - position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token; - `1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with - `mask_truncated_completions=True`, only tokens from non-truncated completions are considered. - use_liger_loss (`bool`, *optional*, defaults to `False`): - Whether to use the Liger GRPO loss. - vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`): - Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed - logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL - Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework - (such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation - and training backends. TIS is proposed as a remedy for this issue. - vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): - Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance - sampling ratio, improving training stability. - - > Parameters that control the logging - - log_completions (`bool`, *optional*, defaults to `False`): - Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, - it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. - num_completions_to_print (`int`, *optional*): - Number of completions to print with `rich`. If `None`, all completions are logged. - wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): - Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all prompts - are logged. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = False, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - model_init_kwargs = None, - disable_dropout = False, - max_prompt_length = 512, - num_generations = 8, - max_completion_length = 256, - ds3_gather_for_generation = True, - shuffle_dataset = True, - generation_batch_size = None, - steps_per_generation = None, - temperature = 1.0, - top_p = 1.0, - top_k = None, - min_p = None, - generation_kwargs = {}, - repetition_penalty = 1.0, - use_transformers_paged = False, - cache_implementation = None, - use_vllm = False, - vllm_mode = 'colocate', - vllm_model_impl = 'vllm', - vllm_enable_sleep_mode = False, - vllm_guided_decoding_regex = None, - vllm_server_base_url = None, - vllm_server_host = '0.0.0.0', - vllm_server_port = 8000, - vllm_server_timeout = 240.0, - vllm_gpu_memory_utilization = 0.3, - vllm_tensor_parallel_size = 1, - beta = 0.001, - num_iterations = 1, - epsilon = 0.2, - delta = None, - epsilon_high = None, - importance_sampling_level = 'token', - reward_weights = None, - scale_rewards = 'group', - loss_type = 'bnpo', - mask_truncated_completions = False, - sync_ref_model = False, - ref_model_mixup_alpha = 0.6, - ref_model_sync_steps = 512, - top_entropy_quantile = 1.0, - use_liger_loss = False, - vllm_importance_sampling_correction = False, - vllm_importance_sampling_cap = 2.0, - log_completions = False, - num_completions_to_print = None, - wandb_log_unique_prompts = False, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - if loss_type.lower() == 'dr_grpo': - loss_type = 'dr_grpo' - elif loss_type.lower() == 'dapo': - loss_type = 'dapo' - if loss_type.lower() == 'dr_grpo': - if scale_rewards == None: - scale_rewards = True - elif scale_rewards == True: - print('Unsloth: The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.') - scale_rewards = False - elif loss_type.lower() == 'dapo': - if mask_truncated_completions != True: - print('Unsloth: The DAPO paper recommends `mask_truncated_completions = True` - we will set it.') - if epsilon_high != 0.28: - print('Unsloth: The DAPO paper recommends `epsilon_high = 0.28` - we will set it.') - if beta != 0.0: - print(f'[WARNING] Unsloth: The DAPO paper recommends setting `beta = 0.0` to remove the KL term - You have set it to {beta}.') - mask_truncated_completions = True - epsilon_high = 0.28 - - if steps_per_generation is None and generation_batch_size is None: - ga = gradient_accumulation_steps - world_size = int(os.environ.get('WORLD_SIZE', '1')) - if (ga * world_size * per_device_train_batch_size) % num_generations != 0: - print('Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations)) - per_device_train_batch_size = num_generations - - if temperature <= 0: - raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') - elif temperature >= 10: - raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') - - if use_vllm and (top_k is None or top_k == 0): top_k = -1 - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - model_init_kwargs = model_init_kwargs, - disable_dropout = disable_dropout, - max_prompt_length = max_prompt_length, - num_generations = num_generations, - max_completion_length = max_completion_length, - ds3_gather_for_generation = ds3_gather_for_generation, - shuffle_dataset = shuffle_dataset, - generation_batch_size = generation_batch_size, - steps_per_generation = steps_per_generation, - temperature = temperature, - top_p = top_p, - top_k = top_k, - min_p = min_p, - generation_kwargs = generation_kwargs, - repetition_penalty = repetition_penalty, - use_transformers_paged = use_transformers_paged, - cache_implementation = cache_implementation, - use_vllm = use_vllm, - vllm_mode = vllm_mode, - vllm_model_impl = vllm_model_impl, - vllm_enable_sleep_mode = vllm_enable_sleep_mode, - vllm_guided_decoding_regex = vllm_guided_decoding_regex, - vllm_server_base_url = vllm_server_base_url, - vllm_server_host = vllm_server_host, - vllm_server_port = vllm_server_port, - vllm_server_timeout = vllm_server_timeout, - vllm_gpu_memory_utilization = vllm_gpu_memory_utilization, - vllm_tensor_parallel_size = vllm_tensor_parallel_size, - beta = beta, - num_iterations = num_iterations, - epsilon = epsilon, - delta = delta, - epsilon_high = epsilon_high, - importance_sampling_level = importance_sampling_level, - reward_weights = reward_weights, - scale_rewards = scale_rewards, - loss_type = loss_type, - mask_truncated_completions = mask_truncated_completions, - sync_ref_model = sync_ref_model, - ref_model_mixup_alpha = ref_model_mixup_alpha, - ref_model_sync_steps = ref_model_sync_steps, - top_entropy_quantile = top_entropy_quantile, - use_liger_loss = use_liger_loss, - vllm_importance_sampling_correction = vllm_importance_sampling_correction, - vllm_importance_sampling_cap = vllm_importance_sampling_cap, - log_completions = log_completions, - num_completions_to_print = num_completions_to_print, - wandb_log_unique_prompts = wandb_log_unique_prompts,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - - -pass - -class _UnslothGRPOTrainer(BaseTrainer): - """""" - - _tag_names = ["trl", "grpo"] - _name = "GRPO" - _paper = { - "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", - "id": "2402.03300", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{shao2024deepseekmath, - title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, - author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, - year = 2024, - eprint = {arXiv:2402.03300}, - } - """), - } - - def __init__( - self, - model: Union[str, PreTrainedModel], - reward_funcs: Union[RewardFunc, list[RewardFunc]], - args: Optional[GRPOConfig] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, - processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, - reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), - peft_config: Optional["PeftConfig"] = None, - ): - - if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'): - if (getattr(args, 'use_vllm', False) == False): - args.use_vllm = True - args.vllm_mode='colocate' - if os.environ.get('UNSLOTH_VLLM_STANDBY', '0') == '1': - args.vllm_enable_sleep_mode=True - # Args - if args is None: - model_name = model if isinstance(model, str) else model.config._name_or_path - model_name = model_name.split("/")[-1] - args = GRPOConfig(f"{model_name}-GRPO") - - # Models - # Trained model - model_init_kwargs = args.model_init_kwargs or {} - if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass # dtype is already a torch.dtype or "auto" or None - elif isinstance(dtype, str): # it's a str, but not "auto" - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype - else: - raise ValueError( - "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - # Disable caching if gradient checkpointing is enabled [not supported] - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - model = architecture.from_pretrained(model_id, **model_init_kwargs) - else: - model_id = model.config._name_or_path - if args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - - # Some models [SmolVLM/Idefics3] don't support `logits_to_keep` argument and error out if we pass it - # Inspect the forward method before we wrap the model with PEFT - self.model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature(model.get_base_model().forward).parameters.keys() - ) - - if False: - pass - - # Processing class - if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left") - - # Handle pad token for processors or tokenizers - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token = tokenizer.pad_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - - # Reward functions - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - self.reward_func_names = [] - for i, reward_func in enumerate(reward_funcs): - if isinstance(reward_func, str): - reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( - reward_func, num_labels=1, **model_init_kwargs - ) - if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models - self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1]) - else: - self.reward_func_names.append(reward_funcs[i].__name__) - self.reward_funcs = reward_funcs - - # Reward weights - if args.reward_weights is not None: - if len(args.reward_weights) != len(reward_funcs): - raise ValueError( - f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " - f"functions ({len(reward_funcs)})" - ) - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) - else: - self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) - - # Reward processing class - if reward_processing_classes is None: - reward_processing_classes = [None] * len(reward_funcs) - elif not isinstance(reward_processing_classes, list): - reward_processing_classes = [reward_processing_classes] - if len(reward_processing_classes) != len(reward_funcs): - raise ValueError( - f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of " - f"reward functions ({len(reward_funcs)})." - ) - - for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): - if isinstance(reward_func, PreTrainedModel): - if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) - if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = reward_processing_class.eos_token - # The reward model computes the reward for the latest non-padded token in the input sequence. - # So it's important to set the pad token ID to the padding token ID of the processing class. - reward_func.config.pad_token_id = reward_processing_class.pad_token_id - reward_processing_classes[i] = reward_processing_class - - self.reward_processing_classes = reward_processing_classes - - # Training arguments - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper - self.num_generations = args.num_generations # = G in the GRPO paper - self.temperature = args.temperature - self.top_p = args.top_p - self.top_k = args.top_k - self.min_p = args.min_p - self.repetition_penalty = args.repetition_penalty - self.use_transformers_paged = args.use_transformers_paged - self.use_vllm = args.use_vllm - self.vllm_mode = args.vllm_mode - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode - self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction - self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap - self.use_liger_loss = args.use_liger_loss - self.loss_type = args.loss_type - self.scale_rewards = args.scale_rewards - self.importance_sampling_level = args.importance_sampling_level - self.mask_truncated_completions = args.mask_truncated_completions - self.top_entropy_quantile = args.top_entropy_quantile - if self.use_liger_loss and self.top_entropy_quantile < 1.0: - raise NotImplementedError( - "Liger Kernels don't currently support masking token positions based on entropy." - ) - if self.use_liger_loss and not self.importance_sampling_level == "token": - raise NotImplementedError( - "Liger Kernels currently only support token-level importance sampling. Please set" - "`importance_sampling_level` to 'token'." - ) - - # Datasets - self.shuffle_dataset = args.shuffle_dataset - - if ( - isinstance(train_dataset, IterableDataset) - or isinstance(eval_dataset, IterableDataset) - or ( - isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) - ) - ): - # See https://github.com/huggingface/trl/issues/3213 - raise NotImplementedError( - "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead." - ) - - # Multi-step - self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper - self.epsilon_low = args.epsilon - self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon - # Tracks the number of iterations [forward + backward passes], including those within a grad accum cycle - self._step = 0 - # Buffer the batch to reuse generated outputs across multiple updates. For more details, see - # `_get_train_sampler` and `_prepare_inputs`. - self._buffered_inputs = None - - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the - # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: - # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To - # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. - # This acts as a flag to indicate that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - super().__init__( - model=model, - args=args, - data_collator=identity, # No data collation is needed in GRPO - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - callbacks=callbacks, - optimizers=optimizers, - # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func` - # is None. For DAPO, loss scaling instead depends on the total number of completions tokens across the - # global accumulated batch. To control scaling ourselves, we must disable Trainer’s built-in scaling. The - # simplest [though a bit hacky] way is to set `compute_loss_func` to any non-None value, which bypasses - # that behavior without rewriting `training_step`. - compute_loss_func="non-None value to disable scaling", - ) - - # Reference model - self.beta = args.beta - if self.beta == 0.0: - # If beta is 0.0, the reference model is not needed - self.ref_model = None - elif is_peft_model(model): - # If PEFT is used, the reference model is not needed since the adapter can be disabled - # to revert to the initial model. - self.ref_model = None - else: - # For deepspeed, fsdp or non-distributed models, create a reference model from scratch - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) - - # Disable dropout in the models - if args.disable_dropout: - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - - # Liger loss - if self.use_liger_loss: - if not is_liger_kernel_available(): - raise ImportError( - "Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`." - ) - # redirect the model.module forward to the model forward to ensure pre-forward hooks are called - self._forward_redirection = _ForwardRedirection() - - self.liger_grpo_loss = LigerFusedLinearGRPOLoss( - beta=self.beta, - epsilon_low=self.epsilon_low, - epsilon_high=self.epsilon_high, - temperature=self.temperature, - use_ref_model=self.beta != 0.0, - loss_type=self.loss_type, - max_completion_length=self.max_completion_length, - ) - - # Initialize the metrics - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self._total_train_tokens = 0 - self.log_completions = args.log_completions - self.wandb_log_unique_prompts = args.wandb_log_unique_prompts - self.num_completions_to_print = args.num_completions_to_print - # Keep logs sized to the generation batch to record only outputs from the latest model update. - self._logs = { - "images": deque(maxlen=args.generation_batch_size), - "prompt": deque(maxlen=args.generation_batch_size), - "completion": deque(maxlen=args.generation_batch_size), - "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), - "advantages": deque(maxlen=args.generation_batch_size), - } - - # Ensure each process receives a unique seed to prevent duplicate completions when generating with - # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but - # it's safer to set it in all cases. - set_seed(args.seed, device_specific=True) - - if self.use_vllm: - if not is_vllm_available(): - raise ImportError( - "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install trl[vllm]` to use it." - ) - - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" - self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) - - elif self.vllm_mode == "colocate": - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " - f"({self.accelerator.num_processes}) evenly." - ) - - if self.vllm_tensor_parallel_size > 1: - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ] - ) - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - ensure_master_addr_port() - - if self.max_prompt_length is not None and self.max_completion_length is not None: - max_model_len = self.max_prompt_length + self.max_completion_length - else: - max_model_len = None - self.llm = model.vllm_engine - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - else: - raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") - self.guided_decoding_regex = args.vllm_guided_decoding_regex - - self._last_loaded_step = -1 - self.accelerator.wait_for_everyone() - else: - generation_kwargs = { - "max_new_tokens": self.max_completion_length, - "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "min_p": self.min_p, - "repetition_penalty": self.repetition_penalty, - "cache_implementation": args.cache_implementation, - } - if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags to the model - self.model.add_model_tags(self._tag_names) - - if self.ref_model is not None: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - elif self.is_fsdp_enabled: - self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - if args.sync_ref_model: - self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) - - for i, reward_func in enumerate(self.reward_funcs): - if isinstance(reward_func, PreTrainedModel): - if self.is_deepspeed_enabled: - self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - else: - # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp - self.reward_funcs[i] = self.accelerator.prepare_model( - reward_func, evaluation_mode=True, device_placement=True - ) - - def _set_signature_columns_if_needed(self): - # If `self.args.remove_unused_columns` is True, non-signature columns are removed. - # By default, this method sets `self._signature_columns` to the model's expected inputs. - # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. - # Instead, we set them to the columns expected by the `training_step` method, hence the override. - if self._signature_columns is None: - self._signature_columns = ["prompt", "image", "images"] - - # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. - # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an - # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions - # once every steps_per_generation step—rather than once per accumulation step—which is significantly more - # efficient. The only change from the original implementation is multiplying the batch size by - # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the - # splitting internally. - # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line - # modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line - # apart from the super method, ensuring easier maintenance in the future. - def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - - def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler: - # Returns a sampler that - # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are - # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt - # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies - # in group formation. - # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to - # _prepare_inputs to see how the generations are stored and reused. - - # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the - # second row shows the second sampled batch, and so on. - # - # | GPU 0 | GPU 1 | - # - # global_step step <-───> num_generations=2 - # <-───────> per_device_train_batch_size=3 - # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss - # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss - # | - # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss - # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss - # - # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss - # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss - # ... - if dataset is None: - dataset = self.train_dataset - return RepeatSampler( - data_source=dataset, - mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.num_iterations * self.args.steps_per_generation, - shuffle=self.shuffle_dataset, - seed=self.args.seed, - ) - - def _get_eval_sampler(self, eval_dataset) -> Sampler: - # See _get_train_sampler for an explanation of the sampler. - return RepeatSampler( - data_source=eval_dataset, - mini_repeat_count=self.num_generations, - seed=self.args.seed, - ) - - @profiling_decorator - def _get_last_hidden_state( - self, - unwrapped_model, - input_ids, - attention_mask, - logits_to_keep, - pixel_values=None, - image_grid_thw=None, - pixel_attention_mask=None, - image_sizes=None, - ): - if is_peft_model(unwrapped_model): - unwrapped_model = unwrapped_model.base_model.model - - # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - - # For Qwen models: - if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw - # For Gemma, SmolVLM2, LLaVa-Next etc.: - if pixel_values is not None: - model_inputs["pixel_values"] = pixel_values - # For SmolVLM2 - if pixel_attention_mask is not None: - model_inputs["pixel_attention_mask"] = pixel_attention_mask - # For LLaVa-Next - if image_sizes is not None: - model_inputs["image_sizes"] = image_sizes - - # Only add logits_to_keep if the model supports it - if "logits_to_keep" in self.model_kwarg_keys: - # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - model_inputs["logits_to_keep"] = logits_to_keep + 1 - - model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings - - last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state - # Exclude the last value: it corresponds to the next token pred - last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) - # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. - last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) - return last_hidden_state - - def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: - """ - Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold. - - Args: - entropies (`torch.Tensor`): - Tensor of shape (batch_size, seq_len) with per-token entropy values. - mask (`torch.Tensor`): - Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding. - threshold (`float`): - Quantile threshold between `0.0` and `1.0` to select high-entropy tokens. - - Returns: - `torch.Tensor`: - Boolean mask of shape (batch_size, seq_len), where `True` indicates tokens with entropy >= threshold - and `False` otherwise. - """ - local = entropies[mask.bool()].float() - - # Use a negative pad_value as a sentinel because entropy values are always >= 0. - # This guarantees that the sentinel cannot collide with any real entropy value. - pad_value = -1e9 - - # Pad across processes so that every rank has the same tensor length - padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value) - gathered = self.accelerator.gather(padded) - - # Drop sentinel values (safe because no entropy can be negative) - gathered = gathered[gathered != pad_value] - - if gathered.numel() == 0: - return torch.zeros_like(entropies, dtype=torch.bool) - - entropy_threshold = torch.quantile(gathered, threshold) - masked_entropies = entropies * mask.float() - entropy_mask = masked_entropies >= entropy_threshold - return entropy_mask & mask.bool() # ensure padding tokens are always masked out - - def _get_per_token_logps_and_entropies( - self, - model, - input_ids, - attention_mask, - logits_to_keep, - batch_size = None, - compute_entropy = False, - compute_efficient = False, - *args, - **kwargs, - ): - # All Unsloth code here in this function is licensed under AGPL3 - # if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': - # return None, None # logps, entropies Unsloth efficient GRPO - if compute_efficient: - return None, None - else: - if not hasattr(self, "_autocast_dtype"): - self._autocast_dtype = ( - torch.float16 - if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16" - else torch.bfloat16 - ) - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - self._autocast_dtype = torch.float16 - - pixel_values, image_grid_thw = ( - kwargs.get("pixel_values", None), - kwargs.get("image_grid_thw", None), - ) - pixel_attention_mask, image_sizes = ( - kwargs.get("pixel_attention_mask", None), - kwargs.get("image_sizes", None), - ) - - unwrapped_model = self.accelerator.unwrap_model( - model, keep_fp32_wrapper = False - ) - - lm_head = self.model.get_output_embeddings().weight - - dtype_bytes = ( - 16 if self._autocast_dtype in [torch.float16, torch.bfloat16] else 32 - ) - total_rows = input_ids.shape[0] - seq_len = input_ids.shape[1] - hidden_dim = lm_head.shape[1] - vocab_dim = lm_head.shape[0] - - if self.args.unsloth_grpo_mini_batch is None: - B, multiplier = autotune_batch_and_chunks( - total_rows, - seq_len, - hidden_dim, - vocab_dim, - dtype_bytes, - self.args.unsloth_logit_chunk_multiplier, - ) - B = total_rows // B - else: - B = self.args.unsloth_grpo_mini_batch - - if self.args.unsloth_logit_chunk_multiplier is None: - multiplier = max(4, seq_len // 4096) - else: - multiplier = self.args.unsloth_logit_chunk_multiplier - - all_logprobs_list = [] - if pixel_values is None: - left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt( - input_ids, logits_to_keep, self.processing_class.pad_token_id - ) - max_left_pad = torch.max(left_pad_tokens_per_prompt).item() - input_ids = left_pack_padding( - input_ids, self.processing_class.pad_token_id - ) - attention_mask = input_ids != self.processing_class.pad_token_id - attention_mask = attention_mask.to(attention_mask.dtype) - else: - max_left_pad = 0 - - # input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0) - attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0) - - def chunk_optional(tensor, chunks): - if tensor is None: - return [None] * chunks - return torch.chunk(tensor, chunks = chunks, dim = 0) - - import math - - total_samples = input_ids.shape[0] - batch_size = math.ceil(total_samples / B) - - input_ids_chunks = [] - attention_mask_chunks = [] - pixel_values_chunks = [] - image_grid_thw_chunks = [] - pixel_attention_mask_chunks = [] - - current_pixel_idx = 0 - # TRL 0.23.0 batching logic - for start in range(0, total_samples, batch_size): - end = start + batch_size - - input_ids_chunks.append(input_ids[start:end]) - attention_mask_chunks.append(attention_mask[start:end]) - - if image_grid_thw is not None and pixel_values is not None: - grid_slice = image_grid_thw[start:end] - image_grid_thw_chunks.append(grid_slice) - - batch_pixel_count = grid_slice.prod(dim = -1).sum().item() - - start_pixel_idx = current_pixel_idx - end_pixel_idx = current_pixel_idx + batch_pixel_count - - pixel_values_chunks.append( - pixel_values[start_pixel_idx:end_pixel_idx] - ) - - if pixel_attention_mask is not None: - pixel_attention_mask_chunks.append( - pixel_attention_mask[start_pixel_idx:end_pixel_idx] - ) - else: - pixel_attention_mask_chunks.append(None) - - current_pixel_idx = end_pixel_idx - - else: - pixel_values_chunks.append(None) - image_grid_thw_chunks.append(None) - pixel_attention_mask_chunks.append(None) - - if image_sizes is not None and not isinstance(image_sizes, torch.Tensor): - image_sizes_chunks = [[size] for size in image_sizes] - else: - image_sizes_chunks = chunk_optional(image_sizes, B) - - temperature = self.temperature - logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) - if logit_softcapping is None: - logit_softcapping = 0 - logit_scale_multiply = getattr(model.config, "logit_scale", 0) - if logit_scale_multiply is None: - logit_scale_multiply = 0 - logit_scale_divide = getattr(model.config, "logits_scaling", 0) - if logit_scale_divide is None: - logit_scale_divide = 0 - - zipped_inputs = zip( - input_ids_chunks, - attention_mask_chunks, - pixel_values_chunks, - image_grid_thw_chunks, - pixel_attention_mask_chunks, - image_sizes_chunks, - ) - os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" - - with _get_inference_mode_context_manager(model): - for ( - input_ids_chunk, - attention_mask_chunk, - pixel_values_chunk, - image_grid_thw_chunk, - pixel_attention_mask_chunk, - image_sizes_chunk, - ) in zipped_inputs: - with torch.amp.autocast( - device_type = "cuda", dtype = self._autocast_dtype - ): - if pixel_values is None: - logits_chunk = unwrapped_model( - input_ids = input_ids_chunk, - attention_mask = attention_mask_chunk, - pixel_values = pixel_values_chunk, - image_grid_thw = image_grid_thw_chunk, - pixel_attention_mask = pixel_attention_mask_chunk, - image_sizes = image_sizes_chunk, - ).logits - - completion_input_ids_chunk = input_ids_chunk[ - :, -(logits_to_keep + max_left_pad) : - ] - logits_chunk = logits_chunk[ - :, -(logits_to_keep + max_left_pad + 1) :, : - ] - logits_chunk = logits_chunk[:, :-1, :] - else: - # Essentially, for VLMs we do not go via the optimized path in models/, - # so we don't encounter the Flash Attn left-padding issue. - logits_chunk = unwrapped_model( - input_ids = input_ids_chunk, - attention_mask = attention_mask_chunk, - pixel_values = pixel_values_chunk, - image_grid_thw = image_grid_thw_chunk, - pixel_attention_mask = pixel_attention_mask_chunk, - image_sizes = image_sizes_chunk, - logits_to_keep = logits_to_keep + 1, - ).logits - - logits_chunk = logits_chunk[:, :-1, :] - completion_input_ids_chunk = input_ids_chunk[ - :, -logits_to_keep: - ] - - logprobs_chunk = chunked_hidden_states_selective_log_softmax( - logits_chunk, - lm_head, - completion_input_ids_chunk, - chunks = input_ids_chunk.shape[0] * multiplier, - logit_scale_multiply = logit_scale_multiply, - logit_scale_divide = logit_scale_divide, - logit_softcapping = logit_softcapping, - temperature = temperature, - ) - # This is needed to avoid race conditions with GPT OSS offload_embbed=True - # However, it seems that this line does not slow down or disrupt models. - device_synchronize() - all_logprobs_list.append(logprobs_chunk) - logprobs = torch.cat(all_logprobs_list, dim = 0) - entropies = None - - os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" - - return logprobs.detach(), entropies # logps, entropies - # input_ids = input_ids[:, -logits_to_keep:] - # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. - # See https://github.com/huggingface/trl/issues/2770 - # logits = logits[:, -logits_to_keep:] - # return logits - # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details - # logits = logits / self.temperature - # logps = selective_log_softmax(logits, input_ids) - - # row_indices, col_indices = torch.where(logps < -20) - - # # Method 1: Check if tensors have elements - # if len(row_indices) > 0 and len(col_indices) > 0: - # breakpoint() # Breakpoint triggered here - # print("Found high values!") - # return logps # compute logprobs for the input tokens - - def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None): - extra_prefixes = extra_prefixes or [] - prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes - for prefix in prefixes: - name = name.replace(prefix, "") - return name - - def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): - """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" - # For FSDP1, we need to recurse into children and also use summon_full_params - if visited is None: - visited = set() - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix else child_name - self._sync_fsdp1_params_to_vllm( - child_module, prefix=child_prefix, visited=visited - ) # recurse into the child - - if isinstance(module, FSDP): - with FSDP.summon_full_params(module, recurse=False, writeback=False): - for param_name, param in module.named_parameters(): - full_name = f"{prefix}.{param_name}" if prefix else param_name - full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) - - if full_name in visited: - continue # skip FSDP subtrees already traversed - visited.add(full_name) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(full_name, param.data) - elif self.vllm_mode == "colocate": - - pass - - pass - - def _sync_fsdp2_params_to_vllm(self, module: nn.Module): - # For FSDP2, module already covers all parameters, so no need for recursion - for name, param in module.items(): - if param.is_cpu: - param = param.to(torch.device("cuda")) - param = param.full_tensor() - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param) - elif self.vllm_mode == "colocate": - - pass - - pass - - def _move_model_to_vllm(self, *args, **kwargs): - return None - - @profiling_decorator - def _prepare_inputs( - self, generation_batch: dict[str, Union[torch.Tensor, Any]] - ) -> dict[str, Union[torch.Tensor, Any]]: - # Prepares inputs for model training/evaluation by managing completion generation and batch handling. - # During training: - # - Receives the local generation batch (Per-GPU batch size × steps per generation) - # from the modified training dataloader instead of the standard local batch - # - Generates completions once for the entire generation batch and splits it into batches of size - # `per_device_train_batch_size` - # - Buffers these completions and returns the appropriate slice for the current accumulation step - # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) - # During evaluation: - # - The input is treated as a standard local batch (no accumulation, no multiple iterations) - # - Completions are generated for each batch without buffering or reuse - # Returns a single local batch in both cases. - - mode = "train" if self.model.training else "eval" - if mode == "train": - generate_every = self.args.steps_per_generation * self.num_iterations - if self._step % generate_every == 0 or self._buffered_inputs is None: - # self._buffered_inputs=None can occur when resuming from a checkpoint - generation_batch = self._generate_and_score_completions(generation_batch) - generation_batch = split_pixel_values_by_grid(generation_batch) - - try: generation_batch = shuffle_sequence_dict(generation_batch) - - except: pass - generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) - self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] - inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] - self._step += 1 - else: - # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence - # local generation batch == local eval batch - inputs = self._generate_and_score_completions(generation_batch) - return inputs - - @profiling_decorator - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): - device = self.accelerator.device - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - - # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations - keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] - reward_kwargs = {key: [example[key] for example in inputs] for key in keys} - - # This allows for dynamic reward shaping based on training progress. - reward_kwargs["trainer_state"] = self.state - - for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names) - ): - with profiling_context(self, reward_func_name): - if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models - if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions)] - texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] - else: - texts = [p + c for p, c in zip(prompts, completions)] - reward_inputs = reward_processing_class( - text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False - ) - reward_inputs = super()._prepare_inputs(reward_inputs) - with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) - else: - output_reward_func = reward_func( - prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs - ) - # Convert None values to NaN - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - # If all reward functions return None for a given row, issue a detailed warning - if torch.isnan(rewards_per_func).all(dim=1).any(): - nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] - row_reward_kwargs = { - key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state" - } - row_reward_kwargs["prompt"] = prompts[nan_row_idx] - row_reward_kwargs["completion"] = completions[nan_row_idx] - logger.warning( - f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n" - "Please ensure that at least one reward function returns a valid reward." - ) - - # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the - # completions may be distributed across processes - rewards_per_func = gather(rewards_per_func) - return rewards_per_func - - def _generate_single_turn(self, prompts: list[str], images: Optional[list]): - device = self.accelerator.device - - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - - _chat_template_ = getattr(self.processing_class, "chat_template", None) - if _chat_template_ is None: _chat_template_ = "" - _supported_keys_ = set(("prompt", "chosen", "rejected", "completion", "messages", "label")) - _batch_chat_kwargs_ = getattr(self, "_unsloth_batch_chat_kwargs", None) - - prompts_text = [] - for _idx_, _example_ in enumerate(prompts): - _tokenizer_kwargs_ = {} - if type(_example_) is not dict: - _example_ = {"prompt": _example_} - _left_keys_ = _example_.keys() - _supported_keys_ - for k in _left_keys_: - if k in _chat_template_: - v = _example_[k] - if type(v) is str: - _tokenizer_kwargs_[k] = v - if _batch_chat_kwargs_ is not None and _idx_ < len(_batch_chat_kwargs_): - for _bk_, _bv_ in _batch_chat_kwargs_[_idx_].items(): - if _bk_ not in _tokenizer_kwargs_: - _tokenizer_kwargs_[_bk_] = _bv_ - _x_ = maybe_apply_chat_template(_example_, self.processing_class, **_tokenizer_kwargs_)["prompt"] - prompts_text.append(_x_) - if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} - - # Generate completions using either vLLM or regular generation - if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up() - - # First, update the vLLM weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) - - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - - with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) - else: - payload = None - - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0] - - # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times - all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] - - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - prompt_ids = all_prompt_ids[process_slice] - completion_ids = all_completion_ids[process_slice] - logprobs = all_logprobs[process_slice] - - # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts - elif self.vllm_mode == "colocate": - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) - else: - guided_decoding = None - - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "truncate_prompt_tokens": self.max_prompt_length, - "guided_decoding": guided_decoding, - "logprobs": 0, # only return the logprob of the generated token - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**grpo_update_SamplingParams(SamplingParams, generation_kwargs, getattr(self.args, 'vllm_sampling_params', None))) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text - - with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True)) - - all_prompt_ids = [output.prompt_token_ids for output in all_outputs] - all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - prompt_ids = all_prompt_ids[tp_slice] - completion_ids = all_completion_ids[tp_slice] - logprobs = all_logprobs[tp_slice] - else: - prompt_ids = all_prompt_ids - completion_ids = all_completion_ids - logprobs = all_logprobs - - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - - elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" - else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" - with ( - profiling_context(self, "transformers.generate_batch"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - # Cast to the appropriate dtype based on training configuration - if self.args.bf16: - unwrapped_model.to(torch.bfloat16) - elif self.args.fp16: - unwrapped_model.to(torch.float16) - with torch.inference_mode(): - all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False - ) - unwrapped_model.train() # restore training mode, as generate_batch forces eval mode - completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = paged_prompt_inputs.input_ids - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn - logprobs = None # not used in this case - - else: - # Regular generation path - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - **kwargs, - ) - generate_inputs = super()._prepare_inputs(generate_inputs) - - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config, disable_compile=True - ) - # Compute prompt length and extract completion ids - prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] - completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] - logprobs = None # not used in this case - - return prompt_ids, completion_ids, logprobs, forward_kwargs - - def _generate(self, prompts: list[str], images: Optional[list]): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images) - - # Get completion length per sequence, used for logging - prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) - completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) - agg_prompt_lengths = self.accelerator.gather(prompt_lengths) - agg_completion_lengths = self.accelerator.gather(completion_lengths) - total_prompt_tokens = agg_prompt_lengths.sum() - total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss - - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) - agg_is_truncated = self.accelerator.gather(is_truncated) - self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) - term_completion_lengths = agg_completion_lengths[~agg_is_truncated] - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs - - def _generate_and_score_completions( - self, inputs: list[dict[str, Union[torch.Tensor, Any]]] - ) -> dict[str, Union[torch.Tensor, Any]]: - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - prompts = [x["prompt"] for x in inputs] - # Unsloth: Extract per-sample chat_template_kwargs before metadata is lost - _ct_ = getattr(self.processing_class, 'chat_template', None) or '' - _sk_ = {'prompt', 'chosen', 'rejected', 'completion', 'messages', 'label', - 'images', 'image', 'videos', 'video', 'audios', 'audio'} - self._unsloth_batch_chat_kwargs = [] - for _inp_ in inputs: - _kw_ = {} - if isinstance(_inp_, dict): - for _k_ in _inp_.keys() - _sk_: - if _k_ in _ct_ and isinstance(_inp_[_k_], str): - _kw_[_k_] = _inp_[_k_] - self._unsloth_batch_chat_kwargs.append(_kw_) - if "images" in inputs[0]: - images = [example.get("images") for example in inputs] - elif "image" in inputs[0]: - images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] - else: - images = None - # Transformers requires at least one image in the batch, otherwise it throws an error - if images is not None and all(img_list == [] for img_list in images): - images = None - - ( - prompt_ids_list, - completion_ids_list, - num_items_in_batch, - sampling_per_token_logps_list, - forward_kwargs, - ) = self._generate(prompts, images) - - # Convert lists of token IDs to padded tensors - prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - completion_mask = pad(completion_mask, padding_value=0, padding_side="right") - if sampling_per_token_logps_list is not None: - sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] - sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") - else: - sampling_per_token_logps = None - - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - - # Concatenate prompt_mask with completion_mask for logit computation - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) - # If token_type_ids are used, extend them with zeros for the completion part - if "token_type_ids" in forward_kwargs: - token_type_ids = forward_kwargs["token_type_ids"] - forward_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 - ) - - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - - max_left_pad = None - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - try: - # TRL 0.23.1 and below path - if not has_images: - # Left pad prompt before calculation old and ref hidden states - left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id) - max_left_pad = torch.max(left_pad_tokens_per_prompt).item() - except: - # TRL 0.24.0 and below path - if images is None: - # Left pad prompt before calculation old and ref hidden states - left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id) - max_left_pad = torch.max(left_pad_tokens_per_prompt).item() - self.model.for_training() - - num_images = [len(img_list) for img_list in images] if images is not None else None - - with torch.no_grad(): - # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of - # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the - # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps - # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set - # old_per_token_logps to None. - # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the - # distribution mismatch between vLLM and the training model can be large and harm the training. - generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency - - if self.args.gradient_accumulation_steps % generate_every != 0 or ( - self.use_vllm - ): - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - old_per_token_logps = None - - # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch - if False and self.use_vllm and self.vllm_importance_sampling_correction: - importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) - importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=self.vllm_importance_sampling_cap - ) - - # Compute the per-token log probabilities for the reference model - if self.beta != 0.0: - if self.ref_model is not None: - ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.ref_model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - with self.accelerator.unwrap_model(self.model).disable_adapter(): - ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - ref_per_token_logps = None - - # Decode - prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) - completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational(inputs[0]): - completions = [] - for prompt, completion in zip(prompts, completions_text): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" - completions.append([{"role": "assistant", "content": bootstrap + completion}]) - else: - completions = completions_text - - # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is - # important because rewards will be normalized per group, and completions are distributed. We will later slice - # rewards_per_func to extract each process's subset. - if images is not None: - rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list) - else: - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) - - # Apply weights to each reward function's output and sum - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - - # Compute grouped-wise rewards - mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) - - # Normalize the rewards to compute the advantages - mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) - advantages = rewards - mean_grouped_rewards - - if self.scale_rewards in ["group", "none"]: - # If self.scale_rewards = "none", we'll still log group level std - std_rewards = rewards.view(-1, self.num_generations).std(dim=1) - std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0) - elif self.scale_rewards == "batch": - # Compute global std - std_rewards = rewards.std().expand_as(rewards) - else: - raise ValueError( - f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." - ) - - is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) - if self.scale_rewards != "none": - advantages = advantages / (std_rewards + 1e-4) - - # Slice to keep only the local part of the data - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - all_process_advantages = advantages.clone() # keep the aggregated advantages for logging - advantages = advantages[process_slice] - - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) - for i, reward_func_name in enumerate(self.reward_func_names): - mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) - std_func_rewards = nanstd(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) - self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) - self._metrics[mode]["reward_std"].append(std_rewards.mean().item()) - self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) - - # Log prompt and completion texts - self._logs["prompt"].extend(gather_object(prompts_text)) - self._logs["completion"].extend(gather_object(completions_text)) - for i, name in enumerate(self.reward_func_names): - self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) - self._logs["advantages"].extend(all_process_advantages.tolist()) - - if images is not None: - self._logs["images"].extend(gather_object(images)) - - if False and self.use_vllm and self.vllm_importance_sampling_correction: - delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - delta = delta[completion_mask.bool()] - mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) - self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( - self.accelerator.gather(mean_delta).mean().item() - ) - self._metrics[mode]["sampling/sampling_logp_difference/max"].append( - self.accelerator.gather(max_delta).max().item() - ) - - flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] - min_importance_sampling_ratio = ( - torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - mean_importance_sampling_ratio = ( - torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - max_importance_sampling_ratio = ( - torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) - ) - self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( - nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( - self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( - nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() - ) - - output = { - "prompt_ids": prompt_ids, - "prompt_mask": prompt_mask, - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "advantages": advantages, - "num_items_in_batch": num_items_in_batch, - } - if old_per_token_logps is not None: - output["old_per_token_logps"] = old_per_token_logps - if False and self.use_vllm and self.vllm_importance_sampling_correction: - output["importance_sampling_ratio"] = importance_sampling_ratio - if ref_per_token_logps is not None: - output["ref_per_token_logps"] = ref_per_token_logps - if "pixel_values" in forward_kwargs: - output["pixel_values"] = forward_kwargs["pixel_values"] - if "image_grid_thw" in forward_kwargs: - output["image_grid_thw"] = forward_kwargs["image_grid_thw"] - if "pixel_attention_mask" in forward_kwargs: - output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] - if "image_sizes" in forward_kwargs: - output["image_sizes"] = forward_kwargs["image_sizes"] - if "token_type_ids" in forward_kwargs: - output["token_type_ids"] = forward_kwargs["token_type_ids"] - if images is not None: - output["num_images"] = num_images - if max_left_pad is not None: - output["max_left_pad"] = torch.tensor(prompt_ids.shape[0] * [max_left_pad]).unsqueeze(-1) - try: - if self.use_vllm and getattr(self, "vllm_importance_sampling_correction", False): - output["sampling_per_token_logps"] = sampling_per_token_logps - except NameError: - output["sampling_per_token_logps"] = None - return output - - def compute_liger_loss(self, unwrapped_model, inputs): - # Compute the per-token log probabilities for the model - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - - # Get the last hidden state of the model - last_hidden_state = self._get_last_hidden_state( - unwrapped_model, - input_ids, - attention_mask, - logits_to_keep, - inputs.get("pixel_values"), - inputs.get("image_grid_thw"), - inputs.get("pixel_attention_mask"), - inputs.get("image_sizes"), - ) - - # compute loss and metrics using liger grpo loss - loss, metrics = self.liger_grpo_loss( - _input=last_hidden_state, - lin_weight=unwrapped_model.lm_head.weight, - selected_token_ids=completion_ids, - attention_mask=completion_mask, - advantages=inputs["advantages"], - bias=unwrapped_model.lm_head.bias, - old_per_token_logps=inputs.get("old_per_token_logps"), - ref_per_token_logps=inputs.get("ref_per_token_logps"), - ) - # Extract metrics from the liger_grpo_loss output - # KL divergence is the first metric when beta is non-zero - mean_kl = metrics[0] if self.beta != 0.0 else None - clip_ratio = metrics[-1] - - mode = "train" if self.model.training else "eval" - if self.beta != 0.0: - self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item()) - self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item()) - return loss / self.current_gradient_accumulation_steps - - def compute_loss( - self, model, inputs, return_outputs = False, num_items_in_batch = None - ): - if return_outputs: - raise ValueError("The GRPOTrainer does not support returning outputs") - # Compute the per-token log probabilities for the model - - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = ( - inputs["completion_ids"], - inputs["completion_mask"], - ) - pixel_values, image_grid_thw = ( - inputs.get("pixel_values", None), - inputs.get("image_grid_thw", None), - ) - pixel_attention_mask, image_sizes = ( - inputs.get("pixel_attention_mask", None), - inputs.get("image_sizes", None), - ) - num_items_in_batch = inputs.get("num_items_in_batch", None) - sampling_per_token_logps = inputs.get("sampling_per_token_logps", None) - current_gradient_accumulation_steps = self.current_gradient_accumulation_steps - num_processes = self.accelerator.num_processes - - input_ids = torch.cat([prompt_ids, completion_ids], dim = 1) - bsz, qlen = input_ids.shape - attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1) - # attention_mask = None - logits_to_keep = completion_ids.size( - 1 - ) # we only need to compute the logits for the completion tokens - _input_ids = input_ids - _logits_to_keep = logits_to_keep - - get_logps_func = ( - lambda model, - input_ids, - attention_mask, - logits_to_keep, - batch_size = None, - compute_entropy = False, - compute_efficient = False: self._get_per_token_logps( - model, input_ids, attention_mask, logits_to_keep, compute_efficient - ) - if hasattr(self, "_get_per_token_logps") - else self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - batch_size, - compute_entropy, - compute_efficient, - )[0] - ) # logps - - per_token_logps = get_logps_func( - model, input_ids, attention_mask, logits_to_keep, compute_efficient = True - ) - # Compute the KL divergence between the model and the reference model - # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. - # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328 - # if self.beta != 0.0: - # with torch.inference_mode(), model.disable_adapter(): - # ref_per_token_logps = per_token_logps = get_logps_func(model, input_ids, attention_mask, logits_to_keep) - # else: - # ref_per_token_logps = None - ref_logps = inputs.get("ref_per_token_logps", None) - # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 - # x - x.detach() allows for preserving gradients from x - advantages = inputs["advantages"] - # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - # per_token_loss = -(per_token_loss - self.beta * per_token_kl) - # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() - old_logps = inputs.get("old_per_token_logps", None) - - input_ids = input_ids[:, -logits_to_keep:] - - # Get logit softcapping and logit scale - logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma - if logit_softcapping is None: - logit_softcapping = 0 - logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere - if logit_scale_multiply is None: - logit_scale_multiply = 0 - logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite - if logit_scale_divide is None: - logit_scale_divide = 0 - - max_left_pad = inputs.get("max_left_pad", 0) - if per_token_logps is not None: - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( - grpo_compute_loss_slow( - ref_logps, - per_token_logps, - old_logps, - input_ids, - completion_mask, - self.beta, - advantages, - pixel_values = pixel_values, - image_grid_thw = image_grid_thw, - loss_type = self.args.loss_type, - importance_sampling_level = self.importance_sampling_level, - epsilon_low = self.epsilon_low, - epsilon_high = self.epsilon_high, - max_completion_length = self.args.max_completion_length, - delta = self.args.delta, - temperature = self.args.temperature, - max_left_pad = max_left_pad, - logit_softcapping = logit_softcapping, - logit_scale_multiply = logit_scale_multiply, - logit_scale_divide = logit_scale_divide, - num_items_in_batch = num_items_in_batch, - current_gradient_accumulation_steps = current_gradient_accumulation_steps, - num_processes = num_processes, - sampling_per_token_logps = sampling_per_token_logps, - ) - ) - else: - if hasattr(self.args, "loss_type"): - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( - grpo_accumulated_loss( - trainer = self, - input_ids = _input_ids, - pixel_values = pixel_values, - image_grid_thw = image_grid_thw, - logits_to_keep = logits_to_keep, - completion_mask = completion_mask, - advantages = advantages, - old_logps = old_logps, - ref_logps = ref_logps, - n_chunks = self.args.unsloth_num_chunks, - loss_type = self.args.loss_type, - importance_sampling_level = self.importance_sampling_level, - epsilon_low = self.epsilon_low, - epsilon_high = self.epsilon_high, - max_completion_length = self.args.max_completion_length, - delta = self.args.delta, - temperature = self.args.temperature, - max_left_pad = max_left_pad, - logit_softcapping = logit_softcapping, - logit_scale_multiply = logit_scale_multiply, - logit_scale_divide = logit_scale_divide, - attention_mask = attention_mask, - num_items_in_batch = num_items_in_batch, - current_gradient_accumulation_steps = current_gradient_accumulation_steps, - num_processes = num_processes, - sampling_per_token_logps = sampling_per_token_logps, - ) - ) - else: - # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 - loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss( - trainer = self, - input_ids = _input_ids, - logits_to_keep = logits_to_keep, - completion_mask = completion_mask, - advantages = advantages, - old_logps = old_logps, - ref_logps = ref_logps, - n_chunks = self.args.unsloth_num_chunks, - temperature = self.args.temperature, - logit_softcapping = logit_softcapping, - logit_scale_multiply = logit_scale_multiply, - logit_scale_divide = logit_scale_divide, - attention_mask = attention_mask, - ) - if "train" in self._metrics: - mode = "eval" if self.control.should_evaluate else "train" - self._metrics[mode]["completion_length"].append(completion_length.item()) - self._metrics[mode]["kl"].append(mean_kl.item()) - else: - self._metrics["completion_length"].append(completion_length.item()) - self._metrics["kl"].append(mean_kl.item()) - - if ( - self.use_vllm - and delta is not None - and getattr(self, "vllm_importance_sampling_correction", False) - ): - mean_delta = ( - torch.mean(delta) - if delta.numel() > 0 - else torch.tensor(0.0, device = self.model.device) - ) - max_delta = ( - torch.max(delta) - if delta.numel() > 0 - else torch.tensor(0.0, device = self.model.device) - ) - self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( - self.accelerator.gather(mean_delta).mean().item() - ) - self._metrics[mode]["sampling/sampling_logp_difference/max"].append( - self.accelerator.gather(max_delta).max().item() - ) - - min_importance_sampling_ratio = ( - torch.min(flat_is_ratio) - if flat_is_ratio.numel() > 0 - else torch.tensor(0.0, device = self.model.device) - ) - mean_importance_sampling_ratio = ( - torch.mean(flat_is_ratio) - if flat_is_ratio.numel() > 0 - else torch.tensor(0.0, device = self.model.device) - ) - max_importance_sampling_ratio = ( - torch.max(flat_is_ratio) - if flat_is_ratio.numel() > 0 - else torch.tensor(0.0, device = self.model.device) - ) - self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( - self.accelerator.gather(min_importance_sampling_ratio) - .nan_to_num(nan = float("inf")) - .min() - .item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( - self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() - ) - self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( - self.accelerator.gather(max_importance_sampling_ratio) - .nan_to_num(nan = float("-inf")) - .max() - .item() - ) - - completion_token_count = completion_mask.sum().clamp(min = 1.0) - - def masked_batch_mean(x): - if x.shape[1] == 1: # when importance_sampling_level == "sequence" - return x.mean() - else: - return (x * completion_mask).sum() / completion_token_count - - if advantages.dim() == 1: - advantages = advantages.unsqueeze(1) - - if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: - # Compute the clipped probability ratios - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) - is_region_clipped = is_low_clipped | is_high_clipped - - low_clip = masked_batch_mean(is_low_clipped.float()) - high_clip = masked_batch_mean(is_high_clipped.float()) - clip_ratio = masked_batch_mean(is_region_clipped.float()) - - gathered_low_clip = self.accelerator.gather(low_clip) - self._metrics[mode]["clip_ratio/low_mean"].append( - gathered_low_clip.nanmean().item() - ) - self._metrics[mode]["clip_ratio/low_min"].append( - nanmin(gathered_low_clip).item() - ) - gathered_high_clip = self.accelerator.gather(high_clip) - self._metrics[mode]["clip_ratio/high_mean"].append( - gathered_high_clip.nanmean().item() - ) - self._metrics[mode]["clip_ratio/high_max"].append( - nanmax(gathered_high_clip).item() - ) - gathered_clip_ratio = self.accelerator.gather(clip_ratio) - self._metrics[mode]["clip_ratio/region_mean"].append( - gathered_clip_ratio.nanmean().item() - ) - elif self.loss_type == "cispo": - is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) - cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) - gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) - self._metrics[mode]["cispo_clip_ratio"].append( - gathered_cispo_clip_ratio.nanmean().item() - ) - - return loss - - def _compute_loss(self, model, inputs): - # Compute the per-token log probabilities for the model - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - - # Compute the per_token_logps and the entropy at each position in the completion - per_token_logps, entropies = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=True, - pixel_values=inputs.get("pixel_values"), - image_grid_thw=inputs.get("image_grid_thw"), - num_images=inputs.get("num_images"), - pixel_attention_mask=inputs.get("pixel_attention_mask"), - image_sizes=inputs.get("image_sizes"), - token_type_ids=inputs.get("token_type_ids"), - ) - - if self.top_entropy_quantile < 1.0: - entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) - else: - entropy_mask = None - - # Compute the KL divergence between the model and the reference model - if self.beta != 0.0: - ref_per_token_logps = inputs["ref_per_token_logps"] - per_token_kl = ( - torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 - ) - - # Compute the loss - advantages = inputs["advantages"] - # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, - # old_per_token_logps == per_token_logps. In this case we can skip its computation - # (see _generate_and_score_completions) and instead use per_token_logps.detach(). - # The exception is when using vLLM, where we always compute old_per_token_logps - # for importance sampling - old_per_token_logps = inputs.get("old_per_token_logps") - old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps - - log_ratio = per_token_logps - old_per_token_logps - if self.importance_sampling_level == "token": - log_importance_weights = log_ratio - elif self.importance_sampling_level == "sequence": - log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - log_importance_weights = log_importance_weights.unsqueeze(-1) - else: - raise ValueError( - f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " - "and 'sequence'." - ) - # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on - # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) - - coef_1 = torch.exp(log_importance_weights) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) - - # Two-sided clipping - if self.args.delta is not None: - coef_1 = torch.clamp(coef_1, max=self.args.delta) - - per_token_loss1 = coef_1 * advantages.unsqueeze(1) - per_token_loss2 = coef_2 * advantages.unsqueeze(1) - per_token_loss = -torch.min(per_token_loss1, per_token_loss2) - if entropy_mask is not None: - per_token_loss = per_token_loss * entropy_mask - - if self.use_vllm and self.vllm_importance_sampling_correction: - per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] - - if self.beta != 0.0: - per_token_loss = per_token_loss + self.beta * per_token_kl - - if self.loss_type == "grpo": - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() - loss = loss / self.current_gradient_accumulation_steps - elif self.loss_type == "bnpo": - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) - loss = loss / self.current_gradient_accumulation_steps - elif self.loss_type == "dr_grpo": - loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) - loss = loss / self.current_gradient_accumulation_steps - elif self.loss_type == "dapo": - normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes - loss = (per_token_loss * completion_mask).sum() / normalizer - else: - raise ValueError(f"Unknown loss type: {self.loss_type}") - - # Log the metrics - mode = "train" if self.model.training else "eval" - - completion_token_count = completion_mask.sum().clamp(min=1.0) - - def masked_batch_mean(x): - if x.shape[1] == 1: # when importance_sampling_level == "sequence" - return x.mean() - else: - return (x * completion_mask).sum() / completion_token_count - - if self.beta != 0.0: - mean_kl = masked_batch_mean(per_token_kl) - self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) - - mean_entropy = masked_batch_mean(entropies) - self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) - - # Compute the clipped probability ratios - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) - is_region_clipped = is_low_clipped | is_high_clipped - - low_clip = masked_batch_mean(is_low_clipped.float()) - high_clip = masked_batch_mean(is_high_clipped.float()) - clip_ratio = masked_batch_mean(is_region_clipped.float()) - - gathered_low_clip = self.accelerator.gather(low_clip) - self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) - gathered_high_clip = self.accelerator.gather(high_clip) - self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) - gathered_clip_ratio = self.accelerator.gather(clip_ratio) - self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) - return loss - - def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None): - inputs = self._prepare_inputs(inputs) - with torch.no_grad(): - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) - loss = loss.mean().detach() - return loss, None, None - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - mode = "train" if self.model.training else "eval" - metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics - - # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` - # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. - if mode == "eval": - metrics = {f"eval_{key}": val for key, val in metrics.items()} - - logs = {**logs, **metrics} - super().log(logs, start_time) - self._metrics[mode].clear() - - if self.accelerator.is_main_process and self.log_completions: - if is_rich_available(): - print_prompt_completions_sample( - self._logs["prompt"], - self._logs["completion"], - self._logs["rewards"], - self._logs["advantages"], - self.state.global_step, - self.num_completions_to_print, - ) - - if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: - import pandas as pd - - table = { - "step": [str(self.state.global_step)] * len(self._logs["prompt"]), - "prompt": self._logs["prompt"], - "completion": self._logs["completion"], - **self._logs["rewards"], - "advantage": self._logs["advantages"], - } - - if self._logs["images"]: - table["images"] = [] - for image_list in self._logs["images"]: - # Convert images to wandb Image objects for proper visualization - table["images"].append([wandb.Image(image) for image in image_list]) - - df = pd.DataFrame(table) - if self.wandb_log_unique_prompts: - df = df.drop_duplicates(subset=["prompt"]) - wandb.log({"completions": wandb.Table(dataframe=df)}) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothGRPOTrainer(_UnslothGRPOTrainer): - """ - - Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the - paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language - Models](https://huggingface.co/papers/2402.03300). - - Example: - - ```python - from datasets import load_dataset - from trl import GRPOTrainer - - dataset = load_dataset("trl-lib/tldr", split="train") - def reward_func(completions, **kwargs): - # Dummy reward function that rewards completions with more unique letters. - return [float(len(set(completion))) for completion in completions] - trainer = GRPOTrainer( - model="Qwen/Qwen2-0.5B-Instruct", - reward_funcs=reward_func, - train_dataset=dataset, - ) - - trainer.train() - ``` - - Args: - model (`Union[str, PreTrainedModel]`): - Model to be trained. Can be either: - - - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in - `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. - reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): - Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward - functions with the prompts and completions and sum the rewards. Can be either: - - - A single reward function, such as: - - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the - keyword arguments in `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. - - A custom reward function: The function is provided with the prompts and the generated completions, - plus any additional columns in the dataset. It should return a list of rewards. Custom reward - functions can also return `None` when the reward is not applicable to those samples. This is useful - for multi-task training where different reward functions apply to different types of samples. When a - reward function returns `None` for a sample, that reward function is excluded from the reward - calculation for that sample. For more details, see [Using a custom reward - function](#using-a-custom-reward-function). - - The trainer's state is also passed to the reward function. The trainer's state is an instance of - [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the - reward function's signature. - - A list of reward functions, where each item can independently be any of the above types. Mixing different - types within the list (e.g., a string model ID and a custom reward function) is allowed. - args ([`GRPOConfig`], *optional*): - Configuration for this trainer. If `None`, a default configuration is used. - train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is - ignored. The format of the samples can be either: - - - [Standard](dataset_formats#standard): Each sample contains plain text. - - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role - and content). - eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): - Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. The padding side must be set to "left". If `None`, the - processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A - padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, - `tokenizer.eos_token` will be used as the default. - reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): - Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: - - - A single processing class: Used when `reward_funcs` contains only one reward function. - - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. - If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is - `None`, the tokenizer for the model is automatically loaded using - [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward - functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes` - are ignored. - callbacks (list of [`~transformers.TrainerCallback`], *optional*): - List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed - in [here](https://huggingface.co/docs/transformers/main_classes/callback). - - If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] - method. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): - A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your - model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration used to wrap the model. If `None`, the model is not wrapped. - - """ - def __init__( - self, - model, - reward_funcs, - args = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - reward_processing_classes = None, - callbacks = None, - peft_config = None, - **kwargs - ): - if args is None: args = UnslothGRPOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - other_metrics = [] - if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs] - else: _reward_funcs = reward_funcs - for reward_func in _reward_funcs: - try: - reward_func_name = reward_func.__name__ - if True: - other_metrics.append(f'rewards/{reward_func_name}/mean') - if True: - other_metrics.append(f'rewards/{reward_func_name}/std') - if False: - other_metrics.append(f'rewards/{reward_func_name}') - except: pass - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('grpo_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - reward_funcs = reward_funcs, - args = args, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - reward_processing_classes = reward_processing_classes, - callbacks = callbacks, - peft_config = peft_config,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothKTOTrainer.py b/unsloth_compiled_cache/UnslothKTOTrainer.py deleted file mode 100644 index c4f9194..0000000 --- a/unsloth_compiled_cache/UnslothKTOTrainer.py +++ /dev/null @@ -1,2377 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, autocast, concatenate_datasets, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, has_length, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, TrainingArguments, Union, autocast, concatenate_datasets, create_reference_model, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_liger_kernel_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, os, peft_module_casting_to_bf16, prepare_deepspeed, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch, F, nn, np, os, selective_log_softmax, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothKTOConfig(KTOConfig): - """ - - Configuration class for the [`KTOTrainer`]. - - This class includes only the parameters that are specific to KTO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int`, *optional*): - Maximum length of the completion. This argument is required if you want to use the default data collator - and your model is an encoder-decoder. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. - loss_type (`str`, *optional*, defaults to `"kto"`): - Type of loss to use. Possible values are: - - - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper. - - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the - [APO](https://huggingface.co/papers/2408.06266) paper. - - desirable_weight (`float`, *optional*, defaults to `1.0`): - Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris. - undesirable_weight (`float`, *optional*, defaults to `1.0`): - Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int`, *optional*): - Padding value to use. If `None`, the padding value of the tokenizer is used. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. - This argument is required if you want to use the default data collator. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from both the model and the reference model to W&B or Comet - during evaluation. - is_encoder_decoder (`bool`, *optional*): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): - Whether to precompute reference model log probabilities for training and evaluation datasets. This is - useful when training without the reference model to reduce the total GPU memory needed. - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - ref_model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model - from a string. - dataset_num_proc: (`int`, *optional*): - Number of processes to use for processing the dataset. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model and reference model. - use_liger_loss (`bool`, *optional*, defaults to `False`): - Whether to use Liger loss. It requires liger-kernel to be installed. - base_model_attribute_name (`str`, *optional*, defaults to `"model"`): - Name of the attribute in the model that contains the base model. This is used to get the base model from - the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - max_length = 1024, - max_prompt_length = 512, - max_completion_length = None, - beta = 0.1, - loss_type = 'kto', - desirable_weight = 1.0, - undesirable_weight = 1.0, - label_pad_token_id = -100, - padding_value = None, - truncation_mode = 'keep_end', - generate_during_eval = False, - is_encoder_decoder = None, - disable_dropout = True, - precompute_ref_log_probs = False, - model_init_kwargs = None, - ref_model_init_kwargs = None, - dataset_num_proc = None, - use_liger_loss = False, - base_model_attribute_name = 'model', - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - max_length = max_length, - max_prompt_length = max_prompt_length, - max_completion_length = max_completion_length, - beta = beta, - loss_type = loss_type, - desirable_weight = desirable_weight, - undesirable_weight = undesirable_weight, - label_pad_token_id = label_pad_token_id, - padding_value = padding_value, - truncation_mode = truncation_mode, - generate_during_eval = generate_during_eval, - is_encoder_decoder = is_encoder_decoder, - disable_dropout = disable_dropout, - precompute_ref_log_probs = precompute_ref_log_probs, - model_init_kwargs = model_init_kwargs, - ref_model_init_kwargs = ref_model_init_kwargs, - dataset_num_proc = dataset_num_proc, - use_liger_loss = use_liger_loss, - base_model_attribute_name = base_model_attribute_name,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothKTOTrainer(BaseTrainer): - r"""""" - - _tag_names = ["trl", "kto"] - _name = "KTO" - _paper = { - "title": "KTO: Model Alignment as Prospect Theoretic Optimization", - "id": "2402.01306", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{ethayarajh2024kto, - title = {{KTO: Model Alignment as Prospect Theoretic Optimization}}, - author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela}, - year = 2024, - eprint = {arXiv:2402.01306}, - }"""), - } - - def __init__( - self, - model: Union[PreTrainedModel, nn.Module, str] = None, - ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: KTOConfig = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - data_collator: Optional[DataCollator] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, - compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, - model_adapter_name: Optional[str] = None, - ref_adapter_name: Optional[str] = None, - ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if type(args) is TrainingArguments: - raise ValueError("Please use `KTOConfig` instead TrainingArguments.") - - if not isinstance(model, str) and ref_model is model: - raise ValueError( - "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " - "same as `model`, you must mass a copy of it, or `None` if you use peft." - ) - - if args.model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - dtype = model_init_kwargs.get("dtype") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - model_init_kwargs["dtype"] = dtype - - if args.ref_model_init_kwargs is None: - ref_model_init_kwargs = {} - elif not isinstance(ref_model, str): - raise ValueError( - "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated." - ) - else: - ref_model_init_kwargs = args.ref_model_init_kwargs - dtype = ref_model_init_kwargs.get("dtype") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - ref_model_init_kwargs["dtype"] = dtype - - if isinstance(model, str): - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - if isinstance(ref_model, str): - ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) - - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # get peft model with the given config - model = model - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): - raise ValueError( - "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve." - ) - - if model is not None: - self.is_encoder_decoder = model.config.is_encoder_decoder - elif args.is_encoder_decoder is None: - raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") - else: - self.is_encoder_decoder = args.is_encoder_decoder - - self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) - self.model_adapter_name = model_adapter_name - self.ref_adapter_name = ref_adapter_name - - if ref_model: - self.ref_model = ref_model - elif self.is_peft_model or args.precompute_ref_log_probs: - # The `model` with adapters turned off will be used as the reference model - self.ref_model = None - else: - self.ref_model = create_reference_model(model) - - if processing_class is None: - raise ValueError( - "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" - ) - if args.max_length is None: - logger.warning( - "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init" - " it will be set to `512` by default, but you should do it yourself in the future.", - ) - max_length = 512 - if args.max_length is not None: - max_length = args.max_length - - if args.max_prompt_length is None: - logger.warning( - "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - ) - max_prompt_length = 128 - if args.max_prompt_length is not None: - max_prompt_length = args.max_prompt_length - - max_completion_length = None - if args.max_completion_length is None and self.is_encoder_decoder: - logger.warning( - "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - ) - max_completion_length = 128 - if args.max_completion_length is not None and self.is_encoder_decoder: - max_completion_length = args.max_completion_length - - if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - if args.remove_unused_columns: - args.remove_unused_columns = False - # warn users - logger.warning( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig" - " we have set it for you, but you should do it yourself in the future.", - ) - - self.use_dpo_data_collator = True - else: - self.use_dpo_data_collator = False - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - - self.loss_type = args.loss_type - self.max_length = max_length - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.max_completion_length = max_completion_length - self.processing_class = processing_class - self.precompute_ref_log_probs = args.precompute_ref_log_probs - - # Not all losses require a KL calculation - self.calculate_KL = True - if self.loss_type in ["apo_zero_unpaired"]: - self.calculate_KL = False - - # Since ref_logs are precomputed on the first call to get_train/eval_dataloader - # keep track of first called to avoid computation of future calls - self._precomputed_train_ref_log_probs = False - self._precomputed_eval_ref_log_probs = False - - # metric - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - - # KTO parameter - self.beta = args.beta - self.desirable_weight = args.desirable_weight - self.undesirable_weight = args.undesirable_weight - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - logger.warning( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - ) - - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result, - # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point - # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's - # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been - # issued. - model.warnings_issued["estimate_tokens"] = True - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().main_process_first(): - # Extract the prompt if needed - train_dataset = train_dataset.map( - maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" - ) - # Unpair the dataset if needed - train_dataset = maybe_unpair_preference_dataset( - train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" - ) - # Apply the chat template if needed - train_dataset = train_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - desc="Applying chat template to train dataset", - ) - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" - ) - eval_dataset = maybe_unpair_preference_dataset( - eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" - ) - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - desc="Applying chat template to eval dataset", - ) - - # Tokenize and prepare the training datasets - train_dataset = train_dataset.map( - _tokenize, - batched=True, - fn_kwargs={"tokenizer": self.processing_class}, - num_proc=args.dataset_num_proc, - desc="Tokenizing train dataset", - ) - - fn_kwargs = { - "prefix": "", - "is_encoder_decoder": self.is_encoder_decoder, - "tokenizer": self.processing_class, - "max_length": self.max_length, - "truncation_mode": self.truncation_mode, - "label_pad_token_id": self.label_pad_token_id, - "max_prompt_length": self.max_prompt_length, - "max_completion_length": self.max_completion_length, - } - - train_dataset = train_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - desc="Processing tokenized train dataset", - ) - - # Tokenize and prepare the eval datasets - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - _tokenize, - fn_kwargs={"tokenizer": self.processing_class}, - batched=True, - num_proc=args.dataset_num_proc, - desc="Tokenizing eval dataset", - ) - - eval_dataset = eval_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - desc="Processing tokenized eval dataset", - ) - - # Get KL datasets if needed - if self.calculate_KL: - if args.per_device_train_batch_size <= 1: - raise ValueError( - "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward." - ) - - # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size - # i.e., [x_1, y_1], ..., [x_n, y_n] --> [x_1, y_n], ..., [x_n, y_1] = [x'_1, y'_1], ..., [x'_n, y'_n] - train_kl_dataset = train_dataset.map( - _get_kl_dataset, - batched=True, - batch_size=args.per_device_train_batch_size, - num_proc=args.dataset_num_proc, - desc="Extracting KL train dataset", - ) - - fn_kwargs["prefix"] = "KL_" - train_kl_dataset = train_kl_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names], - desc="Processing tokenized train KL dataset", - ) - - # merge the datasets - train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1) - - if eval_dataset is not None: - # Get KL dataset - eval_kl_dataset = eval_dataset.map( - _get_kl_dataset, - batched=True, - batch_size=args.per_device_train_batch_size, - num_proc=args.dataset_num_proc, - desc="Extracting eval KL dataset", - ) - - eval_kl_dataset = eval_kl_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names], - desc="Processing tokenized eval KL dataset", - ) - - # merge the datasets - eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1) - - # calculate dataset desirability balance - num_desirable = max(sum(train_dataset["label"]), 1) - num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary - - if num_desirable != num_undesirable: - # The lower and upper bounds come from Eq. [8] of https://huggingface.co/papers/2402.01306 - des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2) - des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2) - und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2) - und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2) - - des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound - und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound - - if not (des_weight_in_range or und_weight_in_range): - logger.warning( - "You have different amounts of desirable/positive and undesirable/negative examples but the " - "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based " - f"on your data, we recommend EITHER " - f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or " - f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). " - "See the documentation on how to optimally set these weights.", - ) - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - model_init=model_init, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - # Deepspeed Zero-3 does not support precompute_ref_log_probs - if self.is_deepspeed_enabled: - if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." - ) - - if self.ref_model is None: - if not (self.is_peft_model or self.precompute_ref_log_probs): - raise ValueError( - "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" - ) - else: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - # Import Liger loss if enabled - if self.args.use_liger_loss: - if not is_liger_kernel_available(): - raise ImportError( - "You set `use_liger_loss=True` but the liger kernel is not available. " - "Please install liger-kernel first: `pip install liger-kernel`" - ) - if self.loss_type in ["apo_zero_unpaired"]: - raise ValueError( - "You cannot set `loss_type='apo_zero_unpaired'` with liger-kernel." - "Only KTO loss is supported with liger-kernel." - ) - if self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set " - "`precompute_ref_log_probs=False`." - ) - if self.is_peft_model or self.ref_adapter_name is not None: - raise ValueError( - "You cannot use `use_liger_loss=True` with Peft models. Please set `use_liger_loss=False`." - ) - self.kto_loss_fn = LigerFusedLinearKTOLoss( - ignore_index=self.label_pad_token_id, beta=self.beta, use_ref_model=(self.ref_model is not None) - ) - - @contextmanager - def null_ref_context(self): - """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with ( - self.accelerator.unwrap_model(self.model).disable_adapter() - if self.is_peft_model and not self.ref_adapter_name - else nullcontext() - ): - if self.ref_adapter_name: - self.model.set_adapter(self.ref_adapter_name) - yield - if self.ref_adapter_name: - self.model.set_adapter(self.model_adapter_name or "default") - - def get_train_dataloader(self) -> DataLoader: - """ - Returns the training [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. - """ - - if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: - dataloader_params = { - "batch_size": self.args.per_device_train_batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) - reference_completion_logps = [] - reference_KL_logps = [] - - for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) - - reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) - reference_completion_logps.append(reference_completion_logp.cpu()) - - if self.calculate_KL: - reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) - reference_KL_logps.append(reference_KL_logp.cpu()) - - self.train_dataset = self.train_dataset.add_column( - name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() - ) - - if self.calculate_KL: - self.train_dataset = self.train_dataset.add_column( - name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() - ) - - self._precomputed_train_ref_log_probs = True - - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - """ - Returns the evaluation [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. - - Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: - dataloader_params = { - "batch_size": self.args.per_device_eval_batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - - reference_completion_logps = [] - reference_KL_logps = [] - - for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) - - reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) - reference_completion_logps.append(reference_completion_logp.cpu()) - - if self.calculate_KL: - reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) - reference_KL_logps.append(reference_KL_logp.cpu()) - - eval_dataset = eval_dataset.add_column( - name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() - ) - if self.calculate_KL: - eval_dataset = eval_dataset.add_column( - name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() - ) - - # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs - if self.eval_dataset is not None: - self.eval_dataset = eval_dataset - self._precomputed_eval_ref_log_probs = True - - return super().get_eval_dataloader(eval_dataset=eval_dataset) - - def compute_reference_log_probs(self, padded_batch: dict) -> dict: - """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset.""" - with torch.no_grad(): - if self.ref_model is None: - with self.null_ref_context(): - if self.is_encoder_decoder: - completion_logits = self.model( - padded_batch["prompt_input_ids"], - attention_mask=padded_batch["prompt_attention_mask"], - decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), - labels=padded_batch["completion_labels"], - ).logits - - if self.calculate_KL: - KL_logits = self.model( - padded_batch["KL_prompt_input_ids"], - attention_mask=padded_batch["KL_prompt_attention_mask"], - decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), - labels=padded_batch["KL_completion_labels"], - ).logits - else: - completion_logits = self.model( - padded_batch["completion_input_ids"], - attention_mask=padded_batch["completion_attention_mask"], - ).logits - - if self.calculate_KL: - KL_logits = self.model( - padded_batch["KL_completion_input_ids"], - attention_mask=padded_batch["KL_completion_attention_mask"], - ).logits - else: - if self.is_encoder_decoder: - completion_logits = self.ref_model( - padded_batch["prompt_input_ids"], - attention_mask=padded_batch["prompt_attention_mask"], - decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), - labels=padded_batch["completion_labels"], - ).logits - - if self.calculate_KL: - KL_logits = self.ref_model( - padded_batch["KL_prompt_input_ids"], - attention_mask=padded_batch["KL_prompt_attention_mask"], - decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), - labels=padded_batch["KL_completion_labels"], - ).logits - else: - completion_logits = self.ref_model( - padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] - ).logits - - if self.calculate_KL: - KL_logits = self.ref_model( - padded_batch["KL_completion_input_ids"], - attention_mask=padded_batch["KL_completion_attention_mask"], - ).logits - - completion_logps = self.get_batch_logps( - completion_logits, - padded_batch["completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - if self.calculate_KL: - KL_logps = self.get_batch_logps( - KL_logits, - padded_batch["KL_completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - else: - KL_logps = None - - return completion_logps, KL_logps - - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: - Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: - Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are - ignored. Shape: (batch_size, sequence_length) - average_log_prob: - If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the - log probabilities of the (non-masked) tokens. - label_pad_token_id: - The label value to ignore when computing log probabilities. - is_encoder_decoder: - Whether the model is an encoder-decoder model. If True, the labels are not shifted and the logits are - assumed to already be aligned with the labels. If False, the labels are shifted to the right by one - position, and the logits are assumed to be aligned with the shifted labels. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the - given logits. - """ - if logits.shape[:-1] != labels.shape: - # Unsloth: auto-truncate to shorter sequence length (model may have truncated input_ids) - _min_len = min(logits.shape[1], labels.shape[1]) - logits = logits[:, :_min_len, :] - labels = labels[:, :_min_len] - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - else: - # Fixes end-dec RuntimeError - labels = labels.clone() - - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == label_pad_token_id] = 0 - - per_token_logps = selective_log_softmax(logits, labels) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def forward( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - KL_logps = self._compute_kl_logps(model, batch) - - model_kwargs = ( - { - "labels": batch["completion_labels"], - "decoder_input_ids": batch.get("completion_decoder_input_ids"), - } - if self.is_encoder_decoder - else {} - ) - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - outputs = model( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], - **model_kwargs, - ) - completion_logits = outputs.logits - - completion_logps = self.get_batch_logps( - completion_logits, - batch["completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - if completion_logps.shape[0] != len(batch["label"]): - raise ValueError( - "There is a mismatch between the number of examples in this batch and the number of " - "examples for which an output sequence was predicted." - ) - - chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] - rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] - - chosen_logps = completion_logps[chosen_idx, ...] - rejected_logps = completion_logps[rejected_idx, ...] - - chosen_logits = completion_logits[chosen_idx, ...] - rejected_logits = completion_logits[rejected_idx, ...] - - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss) - else: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps) - - def kto_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - policy_KL_logps: torch.FloatTensor, - reference_chosen_logps: torch.FloatTensor, - reference_rejected_logps: torch.FloatTensor, - reference_KL_logps: torch.FloatTensor, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute the KTO loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: - Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) - policy_rejected_logps: - Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) - policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,) - reference_chosen_logps: - Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) - reference_rejected_logps: - Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in - batch_size,) - reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,) - - Returns: - A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). The losses tensor contains the KTO - loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for - the chosen and rejected responses, respectively. The KL tensor contains the detached KL divergence estimate - between the policy and reference models. - """ - if self.calculate_KL: - kl = (policy_KL_logps - reference_KL_logps).mean().detach() - kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0) - else: - kl = torch.zeros(1).to(policy_chosen_logps.device) - - # Chosen losses - if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0: - chosen_logratios = policy_chosen_logps - reference_chosen_logps - - if self.loss_type == "kto": - # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) - chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) - elif self.loss_type == "apo_zero_unpaired": - # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) - # Use this loss when you believe the chosen outputs are better than your model's default output - chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios) - - chosen_rewards = self.beta * chosen_logratios.detach() - - else: - # lists can't be empty -- if they are, then accelerate.gather will hang - chosen_losses = torch.Tensor([]).to(self.accelerator.device) - chosen_rewards = torch.Tensor([]).to(self.accelerator.device) - - # Rejected losses - if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0: - rejected_logratios = policy_rejected_logps - reference_rejected_logps - - if self.loss_type == "kto": - rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) - elif self.loss_type == "apo_zero_unpaired": - rejected_losses = F.sigmoid(self.beta * rejected_logratios) - - rejected_rewards = self.beta * rejected_logratios.detach() - else: - # lists can't be empty -- if they are, then accelerate.gather will hang - rejected_losses = torch.Tensor([]).to(self.accelerator.device) - rejected_rewards = torch.Tensor([]).to(self.accelerator.device) - - losses = torch.cat( - (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), - 0, - ) - - return losses, chosen_rewards, rejected_rewards, kl - - def _compute_kl_logps(self, model, batch): - """Compute KL log probabilities for a given batch.""" - KL_logps = None - if self.calculate_KL: - if self.is_encoder_decoder: - KL_model_kwargs = { - "input_ids": batch["KL_prompt_input_ids"], - "attention_mask": batch["KL_prompt_attention_mask"], - "labels": batch["KL_completion_labels"], - "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"), - } - else: - KL_model_kwargs = { - "input_ids": batch["KL_completion_input_ids"], - "attention_mask": batch["KL_completion_attention_mask"], - } - - with torch.no_grad(): - KL_logits = model(**KL_model_kwargs).logits - - KL_logps = self.get_batch_logps( - KL_logits, - batch["KL_completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - return KL_logps - - def _compute_loss_liger(self, model, batch): - """ - Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss. - - Args: - model: - The policy model used for generating log probabilities and outputs. It could be an encoder-decoder - model or a regular language model. - batch: A dictionary containing the input data and labels for the batch. - - Returns: - A dictionary containing the following keys: - - "loss": The computed KTO loss for the batch. - - "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model. - - "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model. - - "chosen_logps": Log probabilities of the chosen responses from the policy model. - - "rejected_logps": Log probabilities of the rejected responses from the policy model. - - "chosen_rewards": Rewards for the chosen responses. - - "rejected_rewards": Rewards for the rejected responses. - - "kl": The KL divergence between the policy and reference models (detached). - - If auxiliary loss is enabled, the dictionary will also include: - - "aux_loss": The auxiliary loss from the model outputs. - """ - policy_KL_logps = self._compute_kl_logps(model, batch) - reference_KL_logps = self._compute_kl_logps(self.ref_model, batch) - if self.calculate_KL: - kl = (policy_KL_logps - reference_KL_logps).mean().detach() - kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0) - else: - kl = torch.zeros(1).to(self.accelerator.device) - - model_kwargs = ( - { - "labels": batch["completion_labels"], - "decoder_input_ids": batch.get("completion_decoder_input_ids"), - } - if self.is_encoder_decoder - else {} - ) - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - if self.is_encoder_decoder: - # 1. Get encoder outputs - encoder_outputs = model.get_encoder()( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], - return_dict=True, - **model_kwargs, - ) - # 2. Get decoder outputs - outputs = model.get_decoder()( - input_ids=model_kwargs["decoder_input_ids"], - encoder_hidden_states=encoder_outputs.last_hidden_state, - use_cache=False, - **model_kwargs, - ) - # 1. Get reference encoder outputs - ref_encoder_outputs = self.ref_model.get_encoder()( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], - return_dict=True, - **model_kwargs, - ) - # 2. Get reference decoder outputs - ref_outputs = self.ref_model.get_decoder()( - input_ids=model_kwargs["decoder_input_ids"], - encoder_hidden_states=ref_encoder_outputs.last_hidden_state, - use_cache=False, - **model_kwargs, - ) - else: - # skip the lm head and get the last hidden state - if hasattr(model, "get_decoder") and model.get_decoder() is not None: - base_model = model.get_decoder() - else: - base_attr = getattr(model, "base_model_prefix", self.args.base_model_attribute_name) - base_model = getattr(model, base_attr, model) - outputs = base_model( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], - use_cache=False, - **model_kwargs, - ) - - # reference model - if hasattr(self.ref_model, "get_decoder") and self.ref_model.get_decoder() is not None: - ref_base_model = self.ref_model.get_decoder() - else: - ref_attr = getattr(self.ref_model, "base_model_prefix", self.args.base_model_attribute_name) - ref_base_model = getattr(self.ref_model, ref_attr, self.ref_model) - ref_outputs = ref_base_model( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], - use_cache=False, - **model_kwargs, - ) - lm_head = model.get_output_embeddings() - ref_lm_head = self.ref_model.get_output_embeddings() - - ( - loss, - ( - chosen_logps_sum, - rejected_logps_sum, - chosen_logits_sum, - rejected_logits_sum, - chosen_rewards_sum, - rejected_rewards_sum, - ), - ) = self.kto_loss_fn( - _input=outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state, - lin_weight=lm_head.weight, - target=batch["completion_labels"][:, 1:], - bias=lm_head.bias if hasattr(lm_head, "bias") else None, - preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device), - ref_input=ref_outputs.last_hidden_state[:, :-1] - if not self.is_encoder_decoder - else outputs.last_hidden_state, - ref_weight=ref_lm_head.weight, - ref_bias=ref_lm_head.bias if hasattr(lm_head, "bias") else None, - kl=kl, - ) - - output = { - "loss": loss, - "chosen_logits_sum": chosen_logits_sum, - "rejected_logits_sum": rejected_logits_sum, - "chosen_logps_sum": chosen_logps_sum, - "rejected_logps_sum": rejected_logps_sum, - "chosen_rewards_sum": chosen_rewards_sum, - "rejected_rewards_sum": rejected_rewards_sum, - "kl": kl, - } - if self.aux_loss_enabled: - output["aux_loss"] = outputs.aux_loss - - return output - - def get_batch_loss_metrics( - self, - model, - batch: dict[str, Union[list, torch.LongTensor]], - ): - """Compute the KTO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} - - labels = torch.tensor(batch["label"]) - num_chosen = labels.sum().to(self.accelerator.device) - num_rejected = (len(labels) - num_chosen).to(self.accelerator.device) - - if self.args.use_liger_loss: - model_output = self._compute_loss_liger(model, batch) - losses = model_output["loss"] - policy_chosen_logits = model_output["chosen_logits_sum"] - policy_rejected_logits = model_output["rejected_logits_sum"] - policy_chosen_logps = model_output["chosen_logps_sum"] - policy_rejected_logps = model_output["rejected_logps_sum"] - chosen_rewards = model_output["chosen_rewards_sum"] - rejected_rewards = model_output["rejected_rewards_sum"] - kl = model_output["kl"] - if self.aux_loss_enabled: - aux_loss = model_output["aux_loss"] - else: - forward_output = self.forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_KL_logps, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] - - # if reference_logps in batch use them, otherwise use the reference model - if "reference_logps" in batch: - chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True] - rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False] - - reference_chosen_logps = batch["reference_logps"][chosen_idx, ...] - reference_rejected_logps = batch["reference_logps"][rejected_idx, ...] - if self.calculate_KL: - reference_KL_logps = batch["reference_KL_logps"] - else: - reference_KL_logps = None - else: - with torch.no_grad(): - if self.ref_model is None: - with self.null_ref_context(): - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - reference_KL_logps, - ) = self.forward(self.model, batch)[:5] - else: - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - reference_KL_logps, - ) = self.forward(self.ref_model, batch)[:5] - - losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( - policy_chosen_logps, - policy_rejected_logps, - policy_KL_logps, - reference_chosen_logps, - reference_rejected_logps, - reference_KL_logps, - ) - - metrics["kl"] = kl.item() - - all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() - all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() - - if all_num_chosen > 0: - metrics["rewards/chosen_sum"] = ( - self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() - ) - metrics["logps/chosen_sum"] = ( - self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() - ) - metrics["logits/chosen_sum"] = ( - self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() - ) - metrics["count/chosen"] = all_num_chosen - - if all_num_rejected > 0: - metrics["rewards/rejected_sum"] = ( - self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() - ) - metrics["logps/rejected_sum"] = ( - self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() - ) - metrics["logits/rejected_sum"] = ( - self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() - ) - metrics["count/rejected"] = all_num_rejected - - loss = losses.nanmean() - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss - - return loss, metrics - - def compute_loss( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs) - - # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: - loss = loss.to(self.args.device) - # force log the metrics - if self.accelerator.is_main_process: - self.store_metrics(metrics, train_eval="train") - - if return_outputs: - return (loss, metrics) - return loss - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: - if dataset is None: - dataset = self.train_dataset - if dataset is None or not has_length(dataset): - return None - return SequentialSampler(dataset) - - def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - # if reference_output in batch use that otherwise use the reference model - if "reference_output" in batch: - reference_output = batch["reference_output"] - else: - if self.ref_model is None: - with self.null_ref_context(): - reference_output = self.model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - else: - reference_output = self.ref_model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) - reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) - - return policy_output_decoded, reference_output_decoded - - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ): - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs) - - # force log the metrics - if self.accelerator.is_main_process: - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return (loss.detach(), None, None) - - # logits for the chosen and rejected samples from model - logits_dict = {} - if "logits/chosen_sum" in metrics: - logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"] - if "logits/rejected_sum" in metrics: - logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"] - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device) - target_indices = torch.where(~target_labels)[0] - target_batch = { - "prompt_input_ids": random_batch["prompt_input_ids"][target_indices], - "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices], - "prompt": itemgetter(*target_indices)(random_batch["prompt"]), - } - policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded) - ], - ) - if "wandb" in self.args.report_to: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float`, *optional*): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # train metrics should have no prefix, eval should have 'eval_' - prefix = "eval_" if train_eval == "eval" else "" - # accumulate average metrics from sums and lengths - for split in ["chosen", "rejected"]: - if f"count/{split}" in self._stored_metrics[train_eval]: - count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() - for metric in ["rewards", "logps", "logits"]: - logs[f"{prefix}{metric}/{split}"] = ( - torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() - / count_sum - ) - # delete obsolete metric - del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] - del self._stored_metrics[train_eval][f"count/{split}"] - # calculate reward margin - if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: - logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothKTOTrainer(_UnslothKTOTrainer): - """ - - Initialize KTOTrainer. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. - ref_model ([`PreTrainedModelWrapper`]): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation - and loss. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized. - args ([`KTOConfig`]): - The arguments to use for training. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - data_collator ([`~transformers.DataCollator`], *optional*): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - model_adapter_name (`str`, defaults to `None`): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str`, defaults to `None`): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - - """ - def __init__( - self, - model = None, - ref_model = None, - args = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - data_collator = None, - model_init = None, - callbacks = None, - preprocess_logits_for_metrics = None, - peft_config = None, - compute_metrics = None, - model_adapter_name = None, - ref_adapter_name = None, - **kwargs - ): - if args is None: args = UnslothKTOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('kto_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - ref_model = ref_model, - args = args, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - data_collator = data_collator, - model_init = model_init, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config, - compute_metrics = compute_metrics, - model_adapter_name = model_adapter_name, - ref_adapter_name = ref_adapter_name,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothNashMDTrainer.py b/unsloth_compiled_cache/UnslothNashMDTrainer.py deleted file mode 100644 index b44f278..0000000 --- a/unsloth_compiled_cache/UnslothNashMDTrainer.py +++ /dev/null @@ -1,1364 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothNashMDConfig(NashMDConfig): - """ - - Configuration class for the [`NashMDTrainer`]. - - Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: - - Parameters: - mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): - Logit mixture coefficient for the model and reference model. If a list of floats is provided then the - mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the - epochs. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - reward_model_path = None, - judge = None, - max_new_tokens = 64, - max_length = 512, - temperature = 0.9, - top_p = 1.0, - top_k = None, - min_p = None, - repetition_penalty = 1.0, - generation_kwargs = {}, - use_transformers_paged = False, - cache_implementation = None, - missing_eos_penalty = None, - loss_type = 'sigmoid', - disable_dropout = True, - use_vllm = False, - vllm_model_impl = 'vllm', - vllm_guided_decoding_regex = None, - vllm_gpu_memory_utilization = 0.55, - vllm_mode = 'colocate', - vllm_server_base_url = None, - vllm_server_host = '0.0.0.0', - vllm_server_port = 8000, - vllm_server_timeout = 240.0, - vllm_tensor_parallel_size = 1, - ds3_gather_for_generation = True, - model_init_kwargs = None, - reward_weights = None, - dataset_num_proc = None, - gpu_memory_utilization = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if temperature <= 0: - raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') - elif temperature >= 10: - raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - reward_model_path = reward_model_path, - judge = judge, - max_new_tokens = max_new_tokens, - max_length = max_length, - temperature = temperature, - top_p = top_p, - top_k = top_k, - min_p = min_p, - repetition_penalty = repetition_penalty, - generation_kwargs = generation_kwargs, - use_transformers_paged = use_transformers_paged, - cache_implementation = cache_implementation, - missing_eos_penalty = missing_eos_penalty, - loss_type = loss_type, - disable_dropout = disable_dropout, - use_vllm = use_vllm, - vllm_model_impl = vllm_model_impl, - vllm_guided_decoding_regex = vllm_guided_decoding_regex, - vllm_gpu_memory_utilization = vllm_gpu_memory_utilization, - vllm_mode = vllm_mode, - vllm_server_base_url = vllm_server_base_url, - vllm_server_host = vllm_server_host, - vllm_server_port = vllm_server_port, - vllm_server_timeout = vllm_server_timeout, - vllm_tensor_parallel_size = vllm_tensor_parallel_size, - ds3_gather_for_generation = ds3_gather_for_generation, - model_init_kwargs = model_init_kwargs, - reward_weights = reward_weights, - dataset_num_proc = dataset_num_proc, - gpu_memory_utilization = gpu_memory_utilization,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothNashMDTrainer(OnlineDPOTrainer): - """""" - - _tag_names = ["trl", "nash-md"] - _name = "Nash-MD" - _paper = { - "title": "Nash Learning from Human Feedback", - "id": "2312.00886", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @inproceedings{munos2024nash, - title = {{Nash Learning from Human Feedback}}, - author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot}, - year = 2024, - booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, - publisher = {OpenReview.net}, - url = {https://openreview.net/forum?id=Y5AmNYiyCQ} - }"""), - } - - def __init__( - self, - model: Union[PreTrainedModel, nn.Module] = None, - ref_model: Union[PreTrainedModel, nn.Module] = None, - reward_funcs: Union[PreTrainedModel, nn.Module, None] = None, - judge: Optional[BasePairwiseJudge] = None, - args: Optional[NashMDConfig] = None, - data_collator: Optional[Callable] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - peft_config: Optional[dict] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - # Deprecated parameters - reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None, - ) -> None: - super().__init__( - model=model, - ref_model=ref_model, - reward_funcs=reward_funcs, - judge=judge, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - reward_processing_classes=processing_class, - peft_config=peft_config, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - reward_model=reward_model, - ) - - self._mixture_coef = self.args.mixture_coef - - # Overwrite the stats dictionary to include NashMD specific statistics - self.stats = { - # Remove "non_score_reward", "rlhf_reward", "scores_margin" - # Add "mixture_coef" - "loss/kl": [], - "objective/entropy": [], - "loss/score": [], - "rewards/probabilities": [], - "rewards/accuracies": [], - "rewards/margins": [], - "logps/chosen": [], - "logps/rejected": [], - "val/model_contain_eos_token": [], - "val/ref_contain_eos_token": [], - "beta": [], - "mixture_coef": [], - } - if self.reward_funcs is not None: - if len(self.reward_funcs) != 1: - raise ValueError("NashMDTrainer only supports one reward function/model.") - self.reward_funcs = self.reward_funcs[0] - self.stats["rewards/chosen"] = [] - self.stats["rewards/rejected"] = [] - - @property - def mixture_coef(self): - if isinstance(self._mixture_coef, list): - epoch = self.state.epoch - return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1] - else: - return self._mixture_coef - - def _generate_completions(self, model, prompts): - # Generate completions from the policy model. - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx: - model_output = unwrapped_policy_for_gen_ctx.generate( - input_ids=prompts["input_ids"], - attention_mask=prompts["attention_mask"], - generation_config=self.generation_config, - ) - - # Get the DDP/FSDP unwrapped version of the main model. - # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used). - policy_model_for_gmw = self.accelerator.unwrap_model(model) - - # Determine the correct reference model for GeometricMixtureWrapper. - # This also needs to be DDP/FSDP unwrapped. - ref_model_for_gmw: torch.nn.Module - if self.ref_model is None: - # No explicit ref_model is provided. - # Use the base of the main `model` if it's a PEFT model. - # policy_model_for_gmw is already DDP-unwrapped. - if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel): - ref_model_for_gmw = policy_model_for_gmw.get_base_model() - else: - # Not a PEFT model (or PEFT not available), or already a base model. - # Use the DDP-unwrapped policy model itself as the reference. - ref_model_for_gmw = policy_model_for_gmw - else: - # An explicit ref_model is provided. Unwrap it for DDP/FSDP. - ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model) - - # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped. - with torch.no_grad(): # Ensure no_grad context for mixture model generation - mixture_model = GeometricMixtureWrapper( - model=policy_model_for_gmw, - ref_model=ref_model_for_gmw, - generation_config=self.generation_config, - mixture_coef=self.mixture_coef, - device=self.accelerator.device, - ) - - mixture_output = mixture_model.generate( - input_ids=prompts["input_ids"], - attention_mask=prompts["attention_mask"], - generation_config=self.generation_config, - ) - - return model_output, mixture_output - - def _process_completions(self, model_output, mixture_output, prompts): - context_length = prompts["input_ids"].shape[1] - - # Process model completions - model_completion_ids = model_output[:, context_length:] - model_completion_ids, model_completion_mask = truncate_right( - model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id - ) - model_data = { - "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), - "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), - "raw": prompts["raw"], - } - - # Process reference model completions - mixture_completion_ids = mixture_output[:, context_length:] - mixture_completion_ids, mixture_completion_mask = truncate_right( - mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id - ) - mixture_data = { - "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1), - "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1), - "raw": prompts["raw"], - } - - return model_data, mixture_data - - def _compute_rewards(self, model_data, mixture_data, context_length): - with torch.no_grad(): - _, model_scores, _ = get_reward( - self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length - ) - _, mixture_scores, _ = get_reward( - self.reward_funcs, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length - ) - - # Apply EOS penalty if needed - if self.args.missing_eos_penalty is not None: - model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) - mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) - model_scores[~model_contain_eos] -= self.args.missing_eos_penalty - mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty - - return model_scores, mixture_scores - - def _compute_judge(self, model_data, mixture_data, context_length): - prompts = model_data["raw"] - model_data_completions = self.processing_class.batch_decode( - model_data["input_ids"][:, context_length:], skip_special_tokens=True - ) - model_data_completions = [completion.strip() for completion in model_data_completions] - - mixture_data_completions = self.processing_class.batch_decode( - mixture_data["input_ids"][:, context_length:], skip_special_tokens=True - ) - mixture_data_completions = [completion.strip() for completion in mixture_data_completions] - if is_conversational({"prompt": prompts[0]}): - model_data_completions = [ - [{"role": "assistant", "content": completion}] for completion in model_data_completions - ] - environment = jinja2.Environment() - template = environment.from_string(SIMPLE_CHAT_TEMPLATE) - prompts = [template.render(messages=message) for message in prompts] - model_data_completions = [template.render(messages=completion) for completion in model_data_completions] - - mixture_data_completions = [ - [{"role": "assistant", "content": completion}] for completion in mixture_data_completions - ] - mixture_data_completions = [ - template.render(messages=completion) for completion in mixture_data_completions - ] - - probability = self.judge.judge( - prompts, - list(zip(model_data_completions, mixture_data_completions)), - return_scores=True, - ) - return torch.tensor(probability, device=model_data["input_ids"].device) - - def _compute_logprobs(self, model, model_data, context_length): - def compute_logprobs_for_data(m, data): - output = m(data["input_ids"], attention_mask=data["attention_mask"]) - logits = output.logits[:, context_length - 1 : -1] - token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) - return token_logprobs - - # Compute logprobs for model completions under the model - model_logprobs_model_data = compute_logprobs_for_data(model, model_data) - - # Compute logprobs of model completions under the reference model - with torch.no_grad(): - if self.ref_model is None: - with model.disable_adapter(): - ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) - else: - ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) - - # Mask padding tokens - model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 - model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) - ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) - - return (model_logprobs_model_data, ref_logprobs_model_data) - - def _compute_losses( - self, - model_logprobs_model_data, - ref_logprobs_model_data, - probability, - ): - # reinforce score where 0.5 is a control variate - score = (probability - 0.5) * model_logprobs_model_data.sum(1) - - # kl divergence via reinforce - with torch.no_grad(): - log_ratio = model_logprobs_model_data - ref_logprobs_model_data - kl_div_log = log_ratio.sum(1) - kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1) - - # final loss - loss = self.beta * kl_div_loss - score - - return loss.mean(), score, kl_div_log - - def _log_statistics( - self, - model_data, - mixture_data, - model_logprobs_model_data, - ref_logprobs_model_data, - probability, - score, - kl_div, - context_length, - model_scores=None, - mixture_scores=None, - ): - # Helper function to gather and compute mean - def gather_mean(tensor): - return self.accelerator.gather_for_metrics(tensor).mean().item() - - # Log score - self.stats["loss/score"].append(gather_mean(score)) - # Log KL divergence - self.stats["loss/kl"].append(gather_mean(kl_div)) - - # Log logprobs - model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) - ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) - - self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum)) - self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum)) - - # Log rewards - if self.reward_funcs is not None: - self.stats["rewards/chosen"].append(gather_mean(model_scores)) - self.stats["rewards/rejected"].append(gather_mean(mixture_scores)) - - # Log probabilities - self.stats["rewards/probabilities"].append(gather_mean(probability)) - - # Calculate entropy for model data - entropy_model_data = -model_logprobs_model_data.sum(1) - self.stats["objective/entropy"].append(gather_mean(entropy_model_data)) - - # Calculate margins - margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum - self.stats["rewards/margins"].append(gather_mean(margin)) - - # Calculate accuracy - accuracy = (margin > 0).float() - self.stats["rewards/accuracies"].append(gather_mean(accuracy)) - - # Log EOS token statistics - model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) - mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) - self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) - self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float())) - - # Log beta and mixture coef - self.stats["beta"].append(self.beta) - self.stats["mixture_coef"].append(self.mixture_coef) - - def training_step( - self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None - ) -> torch.Tensor: - model.train() - - # Apply chat template and tokenize the input - batch_size = len(next(iter(inputs.values()))) - prompts = inputs["prompt"] - inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] - inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] - inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] - inputs = self.data_collator(inputs) - - # need the prompt_ only - inputs = self._prepare_inputs(inputs) - context_length = inputs["prompt_input_ids"].shape[1] - prompts = { - "input_ids": inputs["prompt_input_ids"], - "attention_mask": inputs["prompt_attention_mask"], - "raw": prompts, - } - del inputs - - # Sample completions from both the model and the reference model - model_output, mixture_output = self._generate_completions(model, prompts) - - # Process model completions - model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts) - - # Compute rewards - if self.reward_funcs is not None: - model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length) - # probability of the model data vs the mixture data - probability = F.sigmoid(model_scores - mixture_scores) - else: - model_scores, mixture_scores = None, None - probability = self._compute_judge(model_data, mixture_data, context_length) - - # Compute logprobs - model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length) - - # Compute loss - loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability) - - # Log everything - self._log_statistics( - model_data, - mixture_data, - model_logprobs_model_data.detach(), - ref_logprobs_model_data, - probability, - score.detach(), - kl_div.detach(), - context_length, - model_scores, - mixture_scores, - ) - - if ( - self.args.torch_empty_cache_steps is not None - and self.state.global_step % self.args.torch_empty_cache_steps == 0 - ): - empty_cache() - - kwargs = {} - # For LOMO optimizers you need to explicitly use the learning rate - if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: - kwargs["learning_rate"] = self._get_learning_rate() - - if self.args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training - - self.accelerator.backward(loss, **kwargs) - - return loss.detach() / self.args.gradient_accumulation_steps -class UnslothNashMDTrainer(_UnslothNashMDTrainer): - """ - - Trainer for the Nash-MD method. - - It is implemented as a subclass of [`OnlineDPOTrainer`]. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an `AutoModelForCausalLM`. - ref_model ([`PreTrainedModelWrapper`]): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation - and loss. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized. - reward_funcs ([`~transformers.PreTrainedModel`]): - The reward model to score completions with, preferably an - [`~transformers.AutoModelForSequenceClassification`]. - judge ([`BasePairwiseJudge`]): - The judge to use for pairwise comparison of model completions. - args ([`NashMDConfig`]): - The NashMD config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - peft_config (`dict`): - The peft config to use for training. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - - reward_model: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead. - - - - """ - def __init__( - self, - model = None, - ref_model = None, - reward_funcs = None, - judge = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - peft_config = None, - compute_metrics = None, - callbacks = None, - preprocess_logits_for_metrics = None, - reward_model = None, - **kwargs - ): - if args is None: args = UnslothNashMDConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('nash_md_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - ref_model = ref_model, - reward_funcs = reward_funcs, - judge = judge, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - peft_config = peft_config, - compute_metrics = compute_metrics, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - reward_model = reward_model,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass diff --git a/unsloth_compiled_cache/UnslothORPOTrainer.py b/unsloth_compiled_cache/UnslothORPOTrainer.py deleted file mode 100644 index 1ed2b5a..0000000 --- a/unsloth_compiled_cache/UnslothORPOTrainer.py +++ /dev/null @@ -1,1884 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, logger, logging, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, warnings, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, autocast, defaultdict, disable_dropout_in_model, inspect, is_comet_available, is_peft_available, is_wandb_available, logger, maybe_apply_chat_template, maybe_extract_prompt, nn, np, os, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, warnings, F, Optional, PeftModel, PreTrainedModel, is_peft_available, logger, os, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothORPOConfig(ORPOConfig): - """ - - Configuration class for the [`ORPOTrainer`]. - - This class includes only the parameters that are specific to ORPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. - max_completion_length (`int`, *optional*): - Maximum length of the completion. This argument is required if you want to use the default data collator - and your model is an encoder-decoder. - beta (`float`, *optional*, defaults to `0.1`): - Parameter controlling the relative ratio loss weight in the ORPO loss. In the - [paper](https://huggingface.co/papers/2403.07691), it is denoted by λ. In the - [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. - label_pad_token_id (`int`, *optional*, defaults to `-100`): - Label pad token id. This argument is required if you want to use the default data collator. - padding_value (`int`, *optional*): - Padding value to use. If `None`, the padding value of the tokenizer is used. - truncation_mode (`str`, *optional*, defaults to `"keep_end"`): - Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. - This argument is required if you want to use the default data collator. - generate_during_eval (`bool`, *optional*, defaults to `False`): - If `True`, generates and logs completions from the model to W&B or Comet during evaluation. - is_encoder_decoder (`bool`, *optional*): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - max_length = 1024, - max_prompt_length = 512, - max_completion_length = None, - beta = 0.1, - disable_dropout = True, - label_pad_token_id = -100, - padding_value = None, - truncation_mode = 'keep_end', - generate_during_eval = False, - is_encoder_decoder = None, - model_init_kwargs = None, - dataset_num_proc = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - max_length = max_length, - max_prompt_length = max_prompt_length, - max_completion_length = max_completion_length, - beta = beta, - disable_dropout = disable_dropout, - label_pad_token_id = label_pad_token_id, - padding_value = padding_value, - truncation_mode = truncation_mode, - generate_during_eval = generate_during_eval, - is_encoder_decoder = is_encoder_decoder, - model_init_kwargs = model_init_kwargs, - dataset_num_proc = dataset_num_proc,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothORPOTrainer(BaseTrainer): - r"""""" - - _tag_names = ["trl", "orpo"] - _name = "ORPO" - _paper = { - "title": "ORPO: Monolithic Preference Optimization without Reference Model", - "id": "2403.07691", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{hong2024orpo, - title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, - author = {Jiwoo Hong and Noah Lee and James Thorne}, - year = 2024, - eprint = {arXiv:2403.07691} - }"""), - } - - def __init__( - self, - model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, - args: Optional[ORPOConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, - compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, - ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if args.model_init_kwargs is None: - model_init_kwargs = {} - elif not isinstance(model, str): - raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") - else: - model_init_kwargs = args.model_init_kwargs - dtype = model_init_kwargs.get("dtype") - if dtype is not None: - # Convert to `torch.dtype` if an str is passed - if isinstance(dtype, str) and dtype != "auto": - dtype = getattr(torch, dtype) - if dtype != "auto" and not isinstance(dtype, torch.dtype): - raise ValueError( - f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." - ) - model_init_kwargs["dtype"] = dtype - - if isinstance(model, str): - model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) - - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_config, we merge and unload it first - if isinstance(model, PeftModel): - model = model.merge_and_unload() - - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # get peft model with the given config - model = model - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): - raise ValueError( - "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." - " Please install `wandb` or `comet-ml` to resolve." - ) - - if model is not None: - self.is_encoder_decoder = model.config.is_encoder_decoder - elif args.is_encoder_decoder is None: - raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") - else: - self.is_encoder_decoder = args.is_encoder_decoder - - if self.is_encoder_decoder: - self.decoder_start_token_id = model.config.decoder_start_token_id - self.pad_token_id = model.config.pad_token_id - - if processing_class is None: - raise ValueError("processing_class must be specified to tokenize a ORPO dataset.") - if args.max_length is None: - logger.warning( - "`max_length` is not set in the ORPOConfig's init" - " it will default to `512` by default, but you should do it yourself in the future.", - ) - max_length = 512 - else: - max_length = args.max_length - if args.max_prompt_length is None: - logger.warning( - "`max_prompt_length` is not set in the ORPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - max_prompt_length = 128 - else: - max_prompt_length = args.max_prompt_length - - if args.max_completion_length is None and self.is_encoder_decoder: - logger.warning( - "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" - " it will default to `128` by default, but you should do it yourself in the future.", - ) - self.max_completion_length = 128 - else: - self.max_completion_length = args.max_completion_length - - if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - if args.remove_unused_columns: - args.remove_unused_columns = False - # warn users - logger.warning( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" - " we have set it for you, but you should do it yourself in the future.", - ) - - self.use_dpo_data_collator = True - else: - self.use_dpo_data_collator = False - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - - self.max_length = max_length - self.generate_during_eval = args.generate_during_eval - self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.processing_class = processing_class - - self.beta = args.beta - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - logger.warning( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - ) - - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the - # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and - # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens - # of the input, floating-point operations will not be computed." To suppress this warning, we set the - # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate - # that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().main_process_first(): - # Extract the prompt if needed, and apply the chat template if needed - train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - train_dataset = train_dataset.map( - maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc - ) - train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - if eval_dataset is not None: - eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - ) - eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - model_init=model_init, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - if not hasattr(self, "accelerator"): - raise AttributeError( - "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." - ) - - def build_tokenized_answer(self, prompt, answer): - """ - Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + - b)[len(enc(a)):]`. Reference: - https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - """ - - full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) - prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] - - answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] - answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] - - # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` - full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) - - # Prepare input tokens for token by token comparison - full_input_ids = np.array(full_tokenized["input_ids"]) - - if len(full_input_ids) != len(full_concat_input_ids): - raise ValueError("Prompt input ids and answer input ids should have the same length.") - - # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens - # can be merged together when tokenizing prompt+answer. This could result - # on the last token from the prompt being different when tokenized on its own - # vs when done as prompt+answer. - response_token_ids_start_idx = len(prompt_input_ids) - - # If tokenized prompt is different than both prompt+answer, then it means the - # last token has changed due to merging. - if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: - response_token_ids_start_idx -= 1 - - prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] - prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] - - if len(prompt_input_ids) != len(prompt_attention_mask): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] - answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] - - return dict( - prompt_input_ids=prompt_input_ids, - prompt_attention_mask=prompt_attention_mask, - input_ids=answer_input_ids, - attention_mask=answer_attention_mask, - ) - - def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict: - """Tokenize a single row from a ORPO specific dataset. - - At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + - chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, - we truncate the chosen/rejected. - - We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length - of the prompt and the chosen/rejected response, with label_pad_token_id for the prompt tokens. - """ - batch = {} - prompt = feature["prompt"] - chosen = feature["chosen"] - rejected = feature["rejected"] - - if not self.is_encoder_decoder: - # Check issues below for more details - # 1. https://github.com/huggingface/trl/issues/907 - # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - # 3. https://github.com/LianjiaTech/BELLE/issues/337 - - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)}") - prompt_tokens = self.processing_class(prompt, add_special_tokens=False) - prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} - - if not isinstance(chosen, str): - raise ValueError(f"chosen should be an str but got {type(chosen)}") - chosen_tokens = self.build_tokenized_answer(prompt, chosen) - - if not isinstance(rejected, str): - raise ValueError(f"rejected should be an str but got {type(rejected)}") - rejected_tokens = self.build_tokenized_answer(prompt, rejected) - - # Last prompt token might get merged by tokenizer and - # it should not be included for generation if that happens - prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) - - chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) - rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) - prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) - - for k, v in prompt_tokens.items(): - prompt_tokens[k] = v[:prompt_len_input_ids] - - # Make sure prompts only have one different token at most an - # and length only differs by 1 at most - num_diff_tokens = sum( - a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"]) - ) - num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) - if num_diff_tokens > 1 or num_diff_len > 1: - raise ValueError( - "Chosen and rejected prompt_input_ids might only differ on the " - "last token due to tokenizer merge ops." - ) - - # add BOS token to head of prompt. Avoid adding if it's already there - prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( - self.processing_class.bos_token_id, - prompt_len_input_ids, - prompt_tokens, - chosen_prompt_len_input_ids, - chosen_tokens, - rejected_prompt_len_input_ids, - rejected_tokens, - ) - - # add EOS token to end of answer. Avoid adding if it's already there - chosen_tokens, rejected_tokens = add_eos_token_if_needed( - self.processing_class.eos_token_id, chosen_tokens, rejected_tokens - ) - - longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) - - # if combined sequence is too long, truncate the prompt - for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - if self.truncation_mode == "keep_start": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] - elif self.truncation_mode == "keep_end": - for k in ["prompt_input_ids", "prompt_attention_mask"]: - answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] - else: - raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") - - # if that's still too long, truncate the response - for answer_tokens in [chosen_tokens, rejected_tokens]: - if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: - for k in ["input_ids", "attention_mask"]: - answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] - - # Create labels - chosen_sequence_tokens = { - k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] - } - rejected_sequence_tokens = { - k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] - } - chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] - chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(chosen_tokens["prompt_input_ids"]) - rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] - rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ - self.label_pad_token_id - ] * len(rejected_tokens["prompt_input_ids"]) - - for k, toks in { - "chosen_": chosen_sequence_tokens, - "rejected_": rejected_sequence_tokens, - "": prompt_tokens, - }.items(): - for type_key, tokens in toks.items(): - if type_key == "token_type_ids": - continue - batch[f"{k}{type_key}"] = tokens - - else: - chosen_tokens = self.processing_class( - chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - rejected_tokens = self.processing_class( - rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True - ) - prompt_tokens = self.processing_class( - prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True - ) - - batch["chosen_labels"] = chosen_tokens["input_ids"] - batch["rejected_labels"] = rejected_tokens["input_ids"] - batch["prompt_input_ids"] = prompt_tokens["input_ids"] - batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] - - if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): - batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["rejected_labels"]) - ) - batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["chosen_labels"]) - ) - - if is_torch_xla_available(): - # Pad the sequences to global max_length to avoid TorchXLA recompilation - for k in batch: - if "labels" in k or self.is_encoder_decoder: - pad_value = self.label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = self.padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k])) - return batch - - @staticmethod - def concatenated_inputs( - batch: dict[str, Union[list, torch.LongTensor]], - is_encoder_decoder: bool = False, - label_pad_token_id: int = -100, - padding_value: int = 0, - device: Optional[torch.device] = None, - ) -> dict[str, torch.LongTensor]: - """Concatenate the chosen and rejected inputs into a single tensor. - - Args: - batch: - A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors - of shape (batch_size, sequence_length). - is_encoder_decoder: - Whether the model is an encoder-decoder model. - label_pad_token_id: - The label pad token id. - padding_value: - The padding value to use for the concatenated inputs_ids. - device: - The device for the concatenated inputs. - - Returns: - A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. - """ - concatenated_batch = {} - - if is_encoder_decoder: - max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) - else: - max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) - - for k in batch: - if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("chosen", "concatenated") - concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) - for k in batch: - if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): - if "labels" in k or is_encoder_decoder: - pad_value = label_pad_token_id - elif k.endswith("_input_ids"): - pad_value = padding_value - elif k.endswith("_attention_mask"): - pad_value = 0 - concatenated_key = k.replace("rejected", "concatenated") - concatenated_batch[concatenated_key] = torch.cat( - ( - concatenated_batch[concatenated_key], - pad_to_length(batch[k], max_length, pad_value=pad_value), - ), - dim=0, - ).to(device=device) - - if is_encoder_decoder: - concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) - concatenated_batch["concatenated_attention_mask"] = ( - batch["prompt_attention_mask"].repeat(2, 1).to(device=device) - ) - - return concatenated_batch - - def odds_ratio_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: - Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) - policy_rejected_logps: - Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) - - Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO - loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for - the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the - rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes. - """ - - # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) - log_odds = (policy_chosen_logps - policy_rejected_logps) - ( - torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) - ) - ratio = F.logsigmoid(log_odds) - losses = self.beta * ratio - - chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() - rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() - - return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds) - - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: - Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are - ignored. Shape: (batch_size, sequence_length) - average_log_prob: - If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the - log probabilities of the (non-masked) tokens. - label_pad_token_id: The label pad token id. - is_encoder_decoder: Whether the model is an encoder-decoder model. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the - given logits. - """ - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels = torch.where(labels == label_pad_token_id, 0, labels) - - per_token_logps = selective_log_softmax(logits, labels) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - - def concatenated_forward( - self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - - We do this to avoid doing two forward passes, because it's faster for FSDP. - """ - concatenated_batch = self.concatenated_inputs( - batch, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - padding_value=self.padding_value, - device=self.accelerator.device, - ) - len_chosen = batch["chosen_labels"].shape[0] - - model_kwargs = ( - { - "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), - } - if self.is_encoder_decoder - else {} - ) - - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True - - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - **model_kwargs, - ) - all_logits = outputs.logits - - def cross_entropy_loss(logits, labels): - if not self.is_encoder_decoder: - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - - if self.is_encoder_decoder: - labels = concatenated_batch["concatenated_labels"].clone() - else: - labels = concatenated_batch["concatenated_input_ids"].clone() - attention_mask = concatenated_batch["concatenated_attention_mask"] - labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) - # orpo chosen nll loss is computed over the full prompt and response - chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) - - all_logps = self.get_batch_logps( - all_logits, - concatenated_batch["concatenated_labels"], - average_log_prob=True, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] - - if not self.is_encoder_decoder: - chosen_logits = all_logits[:len_chosen, :-1, :] - rejected_logits = all_logits[len_chosen:, :-1, :] - else: - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] - - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) - - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) - - def get_batch_loss_metrics( - self, - model, - batch: dict[str, Union[list, torch.LongTensor]], - train_eval: Literal["train", "eval"] = "train", - ): - """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" - metrics = {} - - forward_output = self.concatenated_forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] - - losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( - policy_chosen_logps, policy_rejected_logps - ) - # full ORPO loss - loss = policy_nll_loss - losses.mean() - - reward_accuracies = (chosen_rewards > rejected_rewards).float() - - prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() - metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() - metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() - metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( - chosen_rewards - rejected_rewards - ).mean() - metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() - metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() - metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics( - policy_rejected_logits.detach().mean() - ).mean() - metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics( - policy_chosen_logits.detach().mean() - ).mean() - metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() - metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean() - metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean() - if is_torch_xla_available(): - xm.mark_step() # needed because .item() calls - for k, v in metrics.items(): - metrics[k] = v.item() - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss - - return loss, metrics - - def compute_loss( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs=False, - num_items_in_batch=None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") - - # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: - loss = loss.to(self.args.device) - - # force log the metrics - self.store_metrics(metrics, train_eval="train") - - if return_outputs: - return (loss, metrics) - return loss - - def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: - """Generate samples from the model and reference model for the given batch of inputs.""" - - # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with - # the torch amp context manager as some hidden states are silently casted to full precision. - generate_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with generate_context_manager: - policy_output = model.generate( - input_ids=batch["prompt_input_ids"], - attention_mask=batch["prompt_attention_mask"], - max_length=self.max_length, - do_sample=True, - pad_token_id=self.processing_class.pad_token_id, - ) - - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) - policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - - return policy_output_decoded - - def prediction_step( - self, - model: Union[PreTrainedModel, nn.Module], - inputs: dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ): - if not self.use_dpo_data_collator: - logger.warning( - "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " - "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" - ) - if ignore_keys is None: - if hasattr(model, "config"): - ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) - else: - ignore_keys = [] - - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with torch.no_grad(), prediction_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") - - # force log the metrics - self.store_metrics(metrics, train_eval="eval") - - if prediction_loss_only: - return (loss.detach(), None, None) - - # logits for the chosen and rejected samples from model - logits_dict = { - "eval_logits/chosen": metrics["eval_logits/chosen"], - "eval_logits/rejected": metrics["eval_logits/rejected"], - } - logits = [v for k, v in logits_dict.items() if k not in ignore_keys] - logits = torch.tensor(logits, device=self.accelerator.device) - labels = torch.zeros(logits.shape[0], device=self.accelerator.device) - - return (loss.detach(), logits, labels) - - def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by - `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. - """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded = self.generate_from_model(self.model, random_batch) - - table = pd.DataFrame( - columns=["Prompt", "Policy"], - data=[ - [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded) - ], - ) - if "wandb" in self.args.report_to: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super().evaluation_loop( - dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix - ) - - return initial_output - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float`, *optional*): - Start time of the training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - return super().log(logs, start_time) - - def _shift_right(self, input_ids): - if self.decoder_start_token_id is None: - raise ValueError( - "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." - ) - - # shift inputs to the right - if is_torch_fx_proxy(input_ids): - # Item assignment is not supported natively for proxies. - shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) - shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) - else: - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() - shifted_input_ids[..., 0] = self.decoder_start_token_id - - if self.pad_token_id is None: - raise ValueError("model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) - - return shifted_input_ids - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothORPOTrainer(_UnslothORPOTrainer): - """ - - Initialize ORPOTrainer. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. - args ([`ORPOConfig`]): - The ORPO config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - - """ - def __init__( - self, - model = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - model_init = None, - callbacks = None, - preprocess_logits_for_metrics = None, - peft_config = None, - compute_metrics = None, - **kwargs - ): - if args is None: args = UnslothORPOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('orpo_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - model_init = model_init, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config, - compute_metrics = compute_metrics,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py b/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py deleted file mode 100644 index 2921887..0000000 --- a/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +++ /dev/null @@ -1,2467 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.online_dpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BasePairwiseJudge, BaseTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FSDP, GenerationConfig, IterableDataset, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardFunc, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, VLLMClient, apply_chat_template, broadcast_object_list, create_reference_model, disable_dropout_in_model, empty_cache, ensure_master_addr_port, gather_object, is_conversational, is_flash_attn_2_available, is_peft_model, is_vllm_available, jinja2, logger, logging, maybe_apply_chat_template, nn, nullcontext, os, pad, prepare_deepspeed, prepare_fsdp, profiling_context, re, seed_worker, textwrap, torch, truncate_right, unwrap_model_for_generation, version, warnings, wraps, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, Dataset, EvalPrediction, F, GenerationConfig, IterableDataset, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, OnlineDPOConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardFunc, Trainer, TrainerCallback, Union, VLLMClient, create_reference_model, disable_dropout_in_model, ensure_master_addr_port, is_vllm_available, logger, nn, os, pad, prepare_deepspeed, prepare_fsdp, re, torch, version, warnings, F, apply_chat_template, is_conversational, re, F, FSDP, is_peft_model, nn, nullcontext, os, re, version, F, Optional, PreTrainedModel, Trainer, logger, os, re, torch, F, FSDP, nn, os, re, F, FSDP, nn, re, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -def vLLMSamplingParams(**kwargs): - from vllm import SamplingParams - - sampling_params = SamplingParams(**kwargs) - sampling_params._set_kwargs = kwargs - return sampling_params -@dataclass -class UnslothOnlineDPOConfig(OnlineDPOConfig): - """ - - Configuration class for the [`OnlineDPOTrainer`]. - - This class includes only the parameters that are specific to Online DPO training. For a full list of training - arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this - class may differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - reward_model_path (`str`, *optional*): - Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both. - judge (`str`, *optional*): - Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both. - max_new_tokens (`int`, *optional*, defaults to `64`): - Maximum number of tokens to generate per completion. - max_length (`int`, *optional*, defaults to `256`): - Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the - sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as - possible. - temperature (`float`, *optional*, defaults to `0.9`): - Temperature for sampling. The higher the temperature, the more random the completions. - missing_eos_penalty (`float`, *optional*): - Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage to - generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive - value. This parameter only works when using `reward_funcs` and not when using `judge`. - beta (`float` or `list[float]`, *optional*, defaults to `0.1`): - Parameter controlling the deviation from the reference model. Higher β means less deviation from the - reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is - selected for each new epoch and the last β is used for the rest of the epochs. - loss_type (`str`, *optional*, defaults to `"sigmoid"`): - Type of loss to use. Possible values are: - - - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. - - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. - - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - - - - This parameter is deprecated and will be removed in version 0.25.0. Since OnlineDPO does not involve - dataset preparation, you can safely remove it. - - - - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model and reference model. - - > Parameters that control generation - - top_p (`float`, *optional*, defaults to `1.0`): - Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to - `1.0` to consider all tokens. - top_k (`int`, *optional*): - Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is - disabled and all tokens are considered. - min_p (`float`, *optional*): - Minimum token probability, which will be scaled by the probability of the most likely token. It must be a - value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. - repetition_penalty (`float`, *optional*, defaults to `1.0`): - Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. - Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat - tokens. - use_transformers_paged (`bool`, *optional*, defaults to `False`): - Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers` - paged implementation will be used for generation instead of the default padded implementation. This - parameter is only effective when `use_vllm` is set to `False`. - cache_implementation (`str`, *optional*): - Implementation of the cache method for faster generation when `use_vllm` is set to `False`. - generation_kwargs (`dict[str, Any]`, *optional*): - Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or - `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the - generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict - with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. - - > Parameters that control generation acceleration powered by vLLM - - use_vllm (`bool`, *optional*, defaults to `False`): - Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation - instead of the default model.generate(). Requires `vllm` to be installed. - vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): - Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use - the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model - implementation. - vllm_mode (`str`, *optional*, defaults to `"server"`): - Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or - `"colocate"`. - - - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM - server is running (start with `trl vllm-serve`). - - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a - separate server but may cause resource contention with training. - vllm_guided_decoding_regex (`str`, *optional*): - Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. - - > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) - - vllm_server_base_url (`str`, *optional*): - Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and - `vllm_server_port` are ignored. - vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): - Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. - vllm_server_port (`int`, *optional*, defaults to `8000`): - Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. - vllm_server_timeout (`float`, *optional*, defaults to `240.0`): - Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the - timeout, a `ConnectionError` is raised. - - > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) - - vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.55`): - Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to - `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when - launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. - vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): - Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to - `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when - launching the vLLM server via the `--vllm_tensor_parallel_size` flag. - - > Other parameters - - ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): - This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, - improving generation speed. However, disabling this option allows training models that exceed the VRAM - capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible - with vLLM generation. - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a - string. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - reward_model_path = None, - judge = None, - max_new_tokens = 64, - max_length = 512, - temperature = 0.9, - top_p = 1.0, - top_k = None, - min_p = None, - repetition_penalty = 1.0, - generation_kwargs = {}, - use_transformers_paged = False, - cache_implementation = None, - missing_eos_penalty = None, - loss_type = 'sigmoid', - disable_dropout = True, - use_vllm = False, - vllm_model_impl = 'vllm', - vllm_guided_decoding_regex = None, - vllm_gpu_memory_utilization = 0.55, - vllm_mode = 'colocate', - vllm_server_base_url = None, - vllm_server_host = '0.0.0.0', - vllm_server_port = 8000, - vllm_server_timeout = 240.0, - vllm_tensor_parallel_size = 1, - ds3_gather_for_generation = True, - model_init_kwargs = None, - reward_weights = None, - dataset_num_proc = None, - gpu_memory_utilization = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if temperature <= 0: - raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') - elif temperature >= 10: - raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - reward_model_path = reward_model_path, - judge = judge, - max_new_tokens = max_new_tokens, - max_length = max_length, - temperature = temperature, - top_p = top_p, - top_k = top_k, - min_p = min_p, - repetition_penalty = repetition_penalty, - generation_kwargs = generation_kwargs, - use_transformers_paged = use_transformers_paged, - cache_implementation = cache_implementation, - missing_eos_penalty = missing_eos_penalty, - loss_type = loss_type, - disable_dropout = disable_dropout, - use_vllm = use_vllm, - vllm_model_impl = vllm_model_impl, - vllm_guided_decoding_regex = vllm_guided_decoding_regex, - vllm_gpu_memory_utilization = vllm_gpu_memory_utilization, - vllm_mode = vllm_mode, - vllm_server_base_url = vllm_server_base_url, - vllm_server_host = vllm_server_host, - vllm_server_port = vllm_server_port, - vllm_server_timeout = vllm_server_timeout, - vllm_tensor_parallel_size = vllm_tensor_parallel_size, - ds3_gather_for_generation = ds3_gather_for_generation, - model_init_kwargs = model_init_kwargs, - reward_weights = reward_weights, - dataset_num_proc = dataset_num_proc, - gpu_memory_utilization = gpu_memory_utilization,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothOnlineDPOTrainer(BaseTrainer): - r"""""" - - _tag_names = ["trl", "online-dpo"] - _name = "Online DPO" - _paper = { - "title": "Direct Language Model Alignment from Online AI Feedback", - "id": "2402.04792", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{guo2024direct, - title = {{Direct Language Model Alignment from Online AI Feedback}}, - author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel}, - year = 2024, - eprint = {arXiv:2402.04792} - }"""), - } - - def __init__( - self, - model: Union[PreTrainedModel, nn.Module, str], - ref_model: Union[PreTrainedModel, nn.Module, None] = None, - reward_funcs: Optional[Union[RewardFunc, list[RewardFunc]]] = None, - judge: Optional[BasePairwiseJudge] = None, - args: Optional[OnlineDPOConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, - processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, - reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, - peft_config: Optional["PeftConfig"] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - # Deprecated parameters - reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None, - reward_processing_class: Optional[PreTrainedTokenizerBase] = None, - ) -> None: - - if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'): - if (getattr(args, 'use_vllm', False) == False): - args.use_vllm = True - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if ref_model is model: - raise ValueError( - "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " - "same as `model`, either omit the `ref_model` argument or pass `None`." - ) - - self.ref_model = ref_model - - # Handle deprecated parameters for backward compatibility - if reward_model is not None: - warnings.warn( - "The `reward_model` parameter is deprecated and will be removed in version 0.25.0. " - "Please use `reward_funcs` instead. For example, change `reward_model=model` to `reward_funcs=model`.", - ) - # Convert old reward_model to new reward_funcs format - if reward_funcs is None: - reward_funcs = reward_model - else: - warnings.warn( - "Both `reward_model` and `reward_funcs` are provided. Using `reward_funcs` and ignoring " - "`reward_model`.", - ) - - if reward_processing_class is not None: - warnings.warn( - "The `reward_processing_class` parameter is deprecated and will be removed in version 0.25.0. " - "Please use `reward_processing_classes` instead. For example, change " - "`reward_processing_class=tokenizer` to `reward_processing_classes=tokenizer`.", - ) - # Convert old reward_processing_class to new reward_processing_classes format - if reward_processing_classes is None: - reward_processing_classes = reward_processing_class - else: - warnings.warn( - "Both `reward_processing_class` and `reward_processing_classes` are provided. Using " - "`reward_processing_classes` and ignoring `reward_processing_class`.", - ) - - # Validate reward configuration - must have exactly one of: judge, or reward_funcs - reward_configs = sum(x is not None for x in [judge, reward_funcs]) - if reward_configs == 0: - raise ValueError("One of `judge` or `reward_funcs` must be provided.") - elif reward_configs > 1: - if judge is not None: - logger.warning( - "Both `judge` and `reward_funcs` are provided. Using `judge` and ignoring `reward_funcs`.", - UserWarning, - ) - reward_funcs = None - self.judge = judge - - # Handle reward_funcs - if reward_funcs is not None: - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - self.reward_func_names = [] - - # Process reward functions [convert strings to models, collect names] - model_init_kwargs = args.model_init_kwargs or {} - for i, reward_func in enumerate(reward_funcs): - if isinstance(reward_func, str): - # Load model from string path - reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( - reward_func, num_labels=1, **model_init_kwargs - ) - if isinstance(reward_funcs[i], nn.Module): - self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1]) - else: - self.reward_func_names.append(reward_funcs[i].__name__) - self.reward_funcs = reward_funcs - - # Handle reward processing classes for reward_funcs - if reward_processing_classes is None: - reward_processing_classes = [None] * len(reward_funcs) - elif not isinstance(reward_processing_classes, list): - reward_processing_classes = [reward_processing_classes] - else: - if len(reward_processing_classes) != len(reward_funcs): - raise ValueError( - "The number of reward processing classes must match the number of reward functions." - ) - - self.reward_processing_classes = [] - for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs): - if isinstance(reward_func, PreTrainedModel): - if reward_processing_class_i is None: - reward_processing_class_i = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) - if reward_processing_class_i.pad_token_id is None: - reward_processing_class_i.pad_token = reward_processing_class_i.eos_token - # Set pad token ID on reward model config - reward_func.config.pad_token_id = reward_processing_class_i.pad_token_id - self.reward_processing_classes.append(reward_processing_class_i) - else: - self.reward_funcs = None - self.reward_func_names = [] - self.reward_processing_classes = [] - - # Handle reward_weights - if reward_funcs is not None: - if args.reward_weights is not None: - if len(args.reward_weights) != len(self.reward_funcs): - raise ValueError( - f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " - f"functions ({len(self.reward_funcs)})" - ) - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) - else: - self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) - else: - self.reward_weights = None - - if args.missing_eos_penalty is not None and reward_funcs is None and judge is None: - # Check if this is the old reward_model case - if reward_model is not None: - logger.warning( - "The `missing_eos_penalty` parameter is deprecated when used with the deprecated `reward_model` parameter. " - "Please use `reward_funcs` instead of `reward_model` to continue using this feature.", - FutureWarning, - stacklevel=2, - ) - else: - raise ValueError("`missing_eos_penalty` is only supported when `reward_funcs` is provided.") - - if args is None: - raise ValueError("`args` must be provided.") - - # Check that the processing_class is provided - if processing_class is None: - raise ValueError("`processing_class` must be provided.") - - model_init_kwargs = args.model_init_kwargs or {} - if isinstance(model, str): - model_id = model - - # Handle dtype in model_init_kwargs - dtype = model_init_kwargs.get("dtype") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass - elif isinstance(dtype, str): - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype - else: - raise ValueError( - "Invalid `dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string " - f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - - model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) - else: - if args.model_init_kwargs is not None: - raise ValueError( - "You passed `model_init_kwargs` to the `OnlineDPOConfig`, but your model is already instantiated. " - "This argument can only be used when the `model` argument is a string." - ) - self.is_encoder_decoder = model.config.is_encoder_decoder - self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys() - - if False: - pass - - # Enable gradient checkpointing if requested - if args.gradient_checkpointing: - model = self._enable_gradient_checkpointing(model, args) - - # Disable dropout in the model and reference model - if args.disable_dropout: - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - - # Handle the ref_model - # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to - # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create - # the ref model from the model by copying it and disable the gradients and set it in evaluation mode. - if ref_model is None: # No ref model provided, the most common case - if False: - self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode - else: - self.ref_model = None # we don't need a ref model here, we can just disable the adapter. - else: # rare case, the user provided a ref model - self.ref_model = ref_model - self.ref_model.eval() - - # Disable the gradient and set the reward model in eval mode - if reward_funcs is not None: - for reward_func in reward_funcs: - if isinstance(reward_func, PreTrainedModel): - reward_func.eval() - - self.max_length = args.max_length - - self.stats = { - "objective/kl": [], - "objective/entropy": [], - "objective/non_score_reward": [], - "rewards/chosen": [], - "rewards/rejected": [], - "rewards/accuracies": [], - "rewards/margins": [], - "logps/chosen": [], - "logps/rejected": [], - "val/contain_eos_token": [], - "beta": [], - } - if self.reward_funcs is not None: - self.stats["objective/rlhf_reward"] = [] - self.stats["objective/scores_margin"] = [] - self.stats["objective/scores"] = [] - - # Store generation parameters for later use - self.use_vllm = args.use_vllm - self.num_generations = 2 # Generate 2 completions per prompt for Online DPO - self.temperature = args.temperature - self.top_p = args.top_p - self.top_k = args.top_k - self.min_p = args.min_p - self.repetition_penalty = args.repetition_penalty - self.use_transformers_paged = args.use_transformers_paged - self.vllm_mode = args.vllm_mode if args.use_vllm else None - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size - self.vllm_model_impl = args.vllm_model_impl - - # Handle pad token for processors or tokenizers - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token = tokenizer.pad_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - - # Vision tokens for VLM support - self.image_token_id = getattr(processing_class, "image_token_id", None) - self.vision_start_token_id = getattr(processing_class, "vision_start_token_id", None) - self.vision_end_token_id = getattr(processing_class, "vision_end_token_id", None) - # Get the image token string for token collapsing - self.image_token = None - if self.image_token_id is not None: - self.image_token = tokenizer.decode([self.image_token_id]) - - # Define the collator if not provided - if data_collator is None: - data_collator = DPODataCollatorWithPadding(pad_token_id=self.pad_token_id) - - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include - # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens - # of the input, floating-point operations will not be computed." To suppress this warning, we set the - # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate - # that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - self._beta = args.beta - - # Set up generation configuration and vLLM after super[].__init__ - if self.use_vllm: - if not is_vllm_available(): - raise ImportError( - "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install trl[vllm]` to use it." - ) - - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" - self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) - else: - self.vllm_client = None - elif self.vllm_mode == "colocate": - vllm_kwargs = { - "model": model.name_or_path, - "tensor_parallel_size": self.vllm_tensor_parallel_size, - "gpu_memory_utilization": self.vllm_gpu_memory_utilization, - "model_impl": self.vllm_model_impl, - "max_num_seqs": self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size, - "max_model_len": args.max_length + args.max_new_tokens, - "distributed_executor_backend": "external_launcher", - "seed": self.accelerator.process_index // self.vllm_tensor_parallel_size, - "max_num_batched_tokens": 4096, - } - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - ensure_master_addr_port() - - self.llm = model.vllm_engine - else: - raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") - self.guided_decoding_regex = args.vllm_guided_decoding_regex - self._last_loaded_step = -1 - generation_params = { - "n": 2, - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": args.max_new_tokens, - "detokenize": False, - } - if args.generation_kwargs is not None: - generation_params.update(args.generation_kwargs) - if self.guided_decoding_regex: - generation_params["guided_decoding"] = GuidedDecodingParams(regex=self.guided_decoding_regex) - self.generation_config = SamplingParams(**generation_params) - self.accelerator.wait_for_everyone() - else: - # Set up transformers generation config - generation_kwargs = { - "max_new_tokens": args.max_new_tokens, - "do_sample": True, - "pad_token_id": self.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": self.eos_token_id, - "temperature": self.temperature, - "top_k": self.top_k, - "top_p": self.top_p, - "repetition_penalty": self.repetition_penalty, - "use_cache": True if not self.args.gradient_checkpointing else False, - } - # Add min_p if supported - if self.min_p is not None: - generation_kwargs["min_p"] = self.min_p - if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - # Remove None values - generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} - self.generation_config = GenerationConfig(**generation_kwargs) - - if self.ref_model is not None: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - elif self.is_fsdp_enabled: - self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - if self.reward_funcs is not None: - for i, reward_func in enumerate(self.reward_funcs): - if isinstance(reward_func, PreTrainedModel): - if self.is_deepspeed_enabled: - self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - else: - # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp - self.reward_funcs[i] = self.accelerator.prepare_model( - reward_func, evaluation_mode=True, device_placement=True - ) - - @property - def beta(self): - if isinstance(self._beta, list): - epoch = self.state.epoch - return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1] - else: - return self._beta - - @staticmethod - def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]: - """Tokenize a single row from a DPO specific dataset.""" - if not is_encoder_decoder: - batch = tokenizer(feature["prompt"], add_special_tokens=False) - # Add BOS token to head of prompt. Avoid adding if it's already there - if tokenizer.bos_token_id is not None: - prompt_len_input_ids = len(batch["input_ids"]) - if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]: - batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"] - batch["attention_mask"] = [1] + batch["attention_mask"] - else: - batch = tokenizer(feature["prompt"], add_special_tokens=True) - batch = {f"prompt_{key}": value for key, value in batch.items()} - return batch - - # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns". - @wraps(Trainer.get_train_dataloader) - def get_train_dataloader(self) -> DataLoader: - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - dataloader_params = { - "batch_size": self._train_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - - # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns". - @wraps(Trainer.get_eval_dataloader) - def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - - # If we have persistent workers, don't do a fork bomb especially as eval datasets - # don't change during training - dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" - if ( - hasattr(self, "_eval_dataloaders") - and dataloader_key in self._eval_dataloaders - and self.args.dataloader_persistent_workers - ): - return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) - - eval_dataset = ( - self.eval_dataset[eval_dataset] - if isinstance(eval_dataset, str) - else eval_dataset - if eval_dataset is not None - else self.eval_dataset - ) - data_collator = self.data_collator - - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(eval_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - # accelerator.free_memory() will destroy the references, so - # we need to store the non-prepared version - eval_dataloader = DataLoader(eval_dataset, **dataloader_params) - if self.args.dataloader_persistent_workers: - if hasattr(self, "_eval_dataloaders"): - self._eval_dataloaders[dataloader_key] = eval_dataloader - else: - self._eval_dataloaders = {dataloader_key: eval_dataloader} - - return self.accelerator.prepare(eval_dataloader) - - def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPOConfig) -> PreTrainedModel: - """Enables gradient checkpointing for the model.""" - # Ensure use_cache is disabled - model.config.use_cache = False - - # Enable gradient checkpointing on the base model for PEFT - if is_peft_model(model): - model.base_model.gradient_checkpointing_enable() - # Enable gradient checkpointing for non-PEFT models - else: - model.gradient_checkpointing_enable() - - gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} - use_reentrant = ( - "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] - ) - - if use_reentrant: - model.enable_input_require_grads() - - return model - - def _generate_vllm(self, prompts, images=None): - eos_token_id = self.eos_token_id - pad_token_id = self.pad_token_id - - # Generate completion_ids and prompt_ids based on mode - if self.vllm_mode == "server": - completion_ids, prompt_ids = self._generate_vllm_server(prompts, images) - elif self.vllm_mode == "colocate": - completion_ids, prompt_ids = self._generate_vllm_colocate(prompts, images) - - # Shared padding, masking, and tensor conversion logic - max_prompt_length = max(len(ids) for ids in prompt_ids) - prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids] - prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids] - max_tokens = self.generation_config.max_tokens - completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids] - completion_ids = [ - ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids - for ids in completion_ids - ] - completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids] - - # Convert to tensors - prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device) - prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device) - completion_ids = torch.tensor(completion_ids, device=self.accelerator.device) - completion_mask = torch.tensor(completion_mask, device=self.accelerator.device) - - return prompt_ids, prompt_mask, completion_ids, completion_mask - - def _generate_vllm_server(self, prompts, images=None): - """Generate completions using vLLM server mode""" - has_images = images is not None - - # Update vLLM server weights if needed - if hasattr(self, "_last_loaded_step") and self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - elif not hasattr(self, "_last_loaded_step"): - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Apply chat template if conversational - if is_conversational({"prompt": prompts[0]}): - prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts] - else: - prompts_text = prompts - # Gather all prompts to main process - all_prompts = gather_object(prompts_text) - if has_images: - all_images = gather_object(images) - - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts[:: self.num_generations] - if has_images: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - completion_ids = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.generation_config.max_tokens, - guided_decoding_regex=self.guided_decoding_regex if hasattr(self, "guided_decoding_regex") else None, - generation_kwargs=self.args.generation_kwargs, - ) - # Flatten: each prompt generates 2 completions - completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions] - else: - completion_ids = [None] * (len(all_prompts) * 2) - - # Broadcast completions to all processes - completion_ids = broadcast_object_list(completion_ids, from_process=0) - - # Each process takes its slice - process_slice = slice( - self.accelerator.process_index * len(prompts) * 2, - (self.accelerator.process_index + 1) * len(prompts) * 2, - ) - completion_ids = completion_ids[process_slice] - - # Create prompt_ids by tokenizing locally - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - ) - prompt_ids = [] - for prompt_tokens in prompt_inputs["input_ids"]: - prompt_ids.extend([prompt_tokens.tolist(), prompt_tokens.tolist()]) # 2 copies for 2 completions - return completion_ids, prompt_ids - - def _generate_vllm_colocate(self, prompts, images=None): - """Generate completions using vLLM colocate mode""" - # Update model weights if needed - only after gradient accumulation completes - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Apply chat template if conversational - if is_conversational({"prompt": prompts[0]}): - prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts] - else: - prompts_text = prompts - - # Prepare vLLM inputs with images if available - if images is not None: - vllm_inputs = [] - for prompt, image in zip(prompts_text, images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) - else: - vllm_inputs = prompts_text - - outputs = self.llm.generate(vllm_inputs, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True)) - - completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs] - prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs] - - return completion_ids, prompt_ids - - def _move_model_to_vllm(self): - """Synchronize model weights to vLLM server with support for PEFT, DeepSpeed, and FSDP""" - # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - - if is_peft_model(self.model): - # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as - # merging adapters in a sharded manner is not supported. - # TODO: does this work with FSDP? - with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext - # Update vLLM weights while parameters are gathered - # For PEFT with FSDP we need to use the memory efficient post-order traversal - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - # use memory-efficient post-order traversal for FSDP - self._sync_fsdp1_params_to_vllm(self.model) - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(self.model) - else: - # DeepSpeed ZeRO-3 with PEFT - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - - pass - - pass - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather (if needed) and update each parameter individually. - if self.is_fsdp_enabled: - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(self.model) - else: - for name, param in self.model.named_parameters(): - name = self._fix_param_name_to_vllm(name) - with gather_if_zero3([param]): - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - - pass - - pass - - # Reset cache on vLLM - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - elif self.vllm_mode == "colocate": - self.llm.reset_prefix_cache() - - def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): - """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" - # For FSDP1, we need to recurse into children and also use summon_full_params - if visited is None: - visited = set() - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix else child_name - self._sync_fsdp1_params_to_vllm( - child_module, prefix=child_prefix, visited=visited - ) # recurse into the child - - if isinstance(module, FSDP): - with FSDP.summon_full_params(module, recurse=False, writeback=False): - for param_name, param in module.named_parameters(): - full_name = f"{prefix}.{param_name}" if prefix else param_name - full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) - - if full_name in visited: - continue # skip FSDP subtrees already traversed - visited.add(full_name) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(full_name, param.data) - elif self.vllm_mode == "colocate": - - pass - - pass - - def _sync_fsdp2_params_to_vllm(self, module: nn.Module): - # For FSDP2, module already covers all parameters, so no need for recursion - for name, param in module.items(): - if param.is_cpu: - param = param.to(torch.device("cuda")) - param = param.full_tensor() - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param) - elif self.vllm_mode == "colocate": - - pass - - pass - - def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None): - """Clean parameter names for vLLM compatibility""" - extra_prefixes = extra_prefixes or [] - prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes - for prefix in prefixes: - name = name.replace(prefix, "") - return name - - def process_vision_row( - self, features: dict[str, Union[list, torch.Tensor]], processing_class=None - ) -> dict[str, list[int]]: - """ - Process a vision row for VLM models (adapted from DPO trainer) - """ - processor = processing_class or self.processing_class - processed_features = processor(images=[features["image"]], text=features["prompt"], add_special_tokens=False) - - prompt_input_ids = processed_features["input_ids"][0] - - # Create the output dict with required fields - output = { - "prompt_input_ids": prompt_input_ids, - "prompt_attention_mask": processed_features["attention_mask"][0], - } - - # Add vision-specific fields - if "pixel_values" in processed_features: - output["pixel_values"] = processed_features["pixel_values"][0] - if "pixel_attention_mask" in processed_features: - output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] - if "image_sizes" in processed_features: - output["image_sizes"] = processed_features["image_sizes"][0] - - return output - - def _generate(self, model, prompts, images=None): - """Generate completions using the model""" - device = next(model.parameters()).device - eos_token_id = self.eos_token_id - pad_token_id = self.pad_token_id - - # Apply chat template and tokenize the input - inputs = [{"prompt": prompt} for prompt in prompts] - - # Add images if provided (VLM support) - if images is not None: - for i, image in enumerate(images): - inputs[i]["image"] = image - - # Apply chat template to get text prompts - prompts_text = [maybe_apply_chat_template(x, self.processing_class)["prompt"] for x in inputs] - - # Handle image token collapsing/removal - # The chat template sometimes inserts a single image token into the prompt text. However, when this text is - # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the - # image size. We need to handle this properly. - if self.image_token is not None and images is not None: - escaped_img_token = re.escape(self.image_token) - # Search for the image token in the chat template - if hasattr(self.processing_class, "chat_template") and self.processing_class.chat_template: - if re.search(escaped_img_token, self.processing_class.chat_template): - # Collapse repeated image tokens back into a single token - prompts_text = [ - re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text - ] - else: - # If the chat template doesn't use the image token, remove all instances - if self.vision_end_token_id is not None: - escaped_eoi_token = re.escape( - self.processing_class.tokenizer.decode([self.vision_end_token_id]) - ) - prompts_text = [ - re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text - ] - else: - # If vision_end_token_id is None, just remove the image tokens - prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] - - # Prepare kwargs for processing class - kwargs = {} - if images is not None: - kwargs = {"images": [[img] for img in images]} - - # Process inputs using the processing class (handles both VLM and LLM) - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - **kwargs, - ) - - prompt_inputs = {k: v.to(device) for k, v in prompt_inputs.items()} - # Convert vision inputs to model's dtype for proper computation - if "pixel_values" in prompt_inputs: - # Handle DataParallel wrapped models - model_dtype = getattr(model, "dtype", None) - if model_dtype is None and hasattr(model, "module"): - model_dtype = model.module.dtype - if model_dtype is not None: - prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].to(model_dtype) - - # Sample 2 completions per prompt of size `max_new_tokens` from the model - prompt_ids = prompt_inputs["input_ids"].repeat(2, 1) - prompt_mask = prompt_inputs["attention_mask"].repeat(2, 1) - - # Prepare vision inputs if available - vision_generation_kwargs = {} - if self.is_vision_model and images is not None: - if "pixel_values" in prompt_inputs: - vision_generation_kwargs["pixel_values"] = prompt_inputs["pixel_values"].repeat(2, 1, 1, 1) - if "pixel_attention_mask" in prompt_inputs: - vision_generation_kwargs["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"].repeat(2, 1) - if "image_sizes" in prompt_inputs: - vision_generation_kwargs["image_sizes"] = prompt_inputs["image_sizes"].repeat(2, 1) - if "image_grid_thw" in prompt_inputs: - vision_generation_kwargs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(2, 1) - - if self.use_transformers_paged: - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" - else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" - with ( - profiling_context(self, "transformers.generate_batch"), - unwrap_model_for_generation( - model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - # Cast to the appropriate dtype based on training configuration - if self.args.bf16: - unwrapped_model.to(torch.bfloat16) - elif self.args.fp16: - unwrapped_model.to(torch.float16) - with torch.inference_mode(): - all_outputs = unwrapped_model.generate_batch( - prompt_ids.tolist(), - generation_config=self.generation_config, - progress_bar=False, - ) - unwrapped_model.train() # restore training mode, as generate_batch forces eval mode - completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn - - # Extract completion_ids and create completion_mask - prompt_length = prompt_ids.size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id) - - return prompt_ids, prompt_mask, completion_ids, completion_mask - else: - # Regular generation path - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - # Setup cache implementation if specified - if self.args.cache_implementation is not None: - unwrapped_model.generation_config.cache_implementation = self.args.cache_implementation - - # Standard generation - output = unwrapped_model.generate( - input_ids=prompt_ids, - attention_mask=prompt_mask, - generation_config=self.generation_config, - **vision_generation_kwargs, - ) - - completion_ids = output[:, prompt_ids.size(1) :] - completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id) - - return prompt_ids, prompt_mask, completion_ids, completion_mask - - def _calculate_rewards_from_functions(self, prompts, completions, completion_ids_list, **reward_kwargs): - """ - Calculate rewards using reward functions - """ - device = self.accelerator.device - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - - # Add trainer state to reward kwargs for dynamic reward shaping - reward_kwargs["trainer_state"] = self.state - - for i, (reward_func, reward_processing_class) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes) - ): - if isinstance(reward_func, nn.Module): # Model-based reward function - # Handle conversational vs text input - if is_conversational({"prompt": prompts[0]}): - messages = [{"messages": p + c} for p, c in zip(prompts, completions)] - texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] - else: - texts = [p + c for p, c in zip(prompts, completions)] - - # Tokenize and get reward scores - reward_inputs = reward_processing_class( - text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False - ) - reward_inputs = {k: v.to(device) for k, v in reward_inputs.items()} - - with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) - else: - # Custom reward function - output_reward_func = reward_func( - prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs - ) - # Convert None values to NaN - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - # Weight and sum across all reward functions - if self.reward_weights is not None: - total_rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - else: - total_rewards = rewards_per_func.nansum(dim=1) - - return total_rewards - - def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs=None): - # Get the number of tokens to truncate from prompt - num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0) - - # Truncate left to avoid oom - prompt_ids = prompt_ids[:, num_tokens_to_truncate:] - prompt_mask = prompt_mask[:, num_tokens_to_truncate:] - - # Concat the prompt and completion - prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1) - prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1) - - # Prepare model kwargs with vision inputs if available - model_kwargs = {"attention_mask": prompt_completion_mask} - if vision_inputs is not None: - if "pixel_values" in vision_inputs: - model_kwargs["pixel_values"] = vision_inputs["pixel_values"] - if "pixel_attention_mask" in vision_inputs: - model_kwargs["pixel_attention_mask"] = vision_inputs["pixel_attention_mask"] - if "image_sizes" in vision_inputs: - model_kwargs["image_sizes"] = vision_inputs["image_sizes"] - if "image_grid_thw" in vision_inputs: - model_kwargs["image_grid_thw"] = vision_inputs["image_grid_thw"] - - # Get the logprobs of the completions from the model - output = model(prompt_completion_ids, **model_kwargs) - - # There is 1 offset, because the model predicts the next token - prompt_len = prompt_ids.size(1) - start_idx = prompt_len - 1 if prompt_len > 0 else 0 - # Only slice off the last logit when we have a prompt, otherwise we need all logits - end_idx = -1 if prompt_len > 0 else None - logits = output.logits[:, start_idx:end_idx] - - # Take the completion tokens logprob - logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1) - return logprobs - - def training_step( - self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None - ) -> torch.Tensor: - model.train() - - prompts = inputs["prompt"] - batch_size = len(prompts) - - # Handle images for VLM support - has_images = "image" in inputs - images = None - if has_images: - images = inputs["image"] - # Convert conversational prompts to include image tokens - for prompt in prompts: - if isinstance(prompt, list): - for message in prompt: - if not isinstance(message, dict): - continue - content = message.get("content") - role = message.get("role") - if isinstance(content, str): - if role == "user": - message["content"] = [{"type": "image"}, {"type": "text", "text": content}] - elif role == "system": - message["content"] = [{"type": "text", "text": content}] - - if self.args.use_vllm: - prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(prompts, images) - else: - prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts, images) - - contain_eos_token = torch.any(completion_ids == self.eos_token_id, dim=-1) - - # Extract vision inputs if available for VLM support - vision_inputs = None - if has_images and self.is_vision_model and not self.args.use_vllm: - # For vision models with transformers generation, we need to prepare vision inputs - # Process the images to get vision inputs that can be passed through the forward pass - vision_inputs = {} - kwargs = {"images": [[img] for img in images]} - processed = self.processing_class( - text=[""] * len(images), # Dummy text for vision processing - return_tensors="pt", - **kwargs, - ) - # Handle DataParallel wrapped models - model_device = getattr(model, "device", None) - model_dtype = getattr(model, "dtype", None) - if model_device is None and hasattr(model, "module"): - model_device = model.module.device - model_dtype = model.module.dtype - # Move vision tensors to device and convert to model dtype - # Need to duplicate for 2 completions per prompt - if "pixel_values" in processed: - vision_inputs["pixel_values"] = ( - processed["pixel_values"].to(model_device, dtype=model_dtype).repeat(2, 1, 1, 1) - ) - if "pixel_attention_mask" in processed: - vision_inputs["pixel_attention_mask"] = processed["pixel_attention_mask"].to(model_device).repeat(2, 1) - if "image_sizes" in processed: - vision_inputs["image_sizes"] = processed["image_sizes"].to(model_device).repeat(2, 1) - if "image_grid_thw" in processed: - vision_inputs["image_grid_thw"] = processed["image_grid_thw"].to(model_device).repeat(2, 1) - - logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs) - with torch.no_grad(): - if self.ref_model is not None: - ref_logprobs = self._forward( - self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs - ) - else: # peft case: we just need to disable the adapter - with self.model.disable_adapter(): - ref_logprobs = self._forward( - self.model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs - ) - - # Decode the completions, and format them if the input is conversational - device = logprobs.device - completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational({"prompt": prompts[0]}): - completions = [[{"role": "assistant", "content": completion}] for completion in completions] - - # Get the reward from reward functions, judge, or deprecated reward_model - if self.reward_funcs is not None: - # First create completion_ids_list for custom reward functions - completion_ids_list = [completion_ids[i].tolist() for i in range(completion_ids.shape[0])] - - # Extract additional fields from inputs for reward functions - reward_kwargs = {} - keys = [key for key in inputs if key not in ["prompt"]] - for key in keys: - if isinstance(inputs[key], (list, tuple)): - # Repeat input fields to match number of completions (2 per prompt) - reward_kwargs[key] = inputs[key] * 2 - else: - reward_kwargs[key] = inputs[key] - - # Calculate rewards using reward functions - rewards = self._calculate_rewards_from_functions( - prompts=2 * prompts, completions=completions, completion_ids_list=completion_ids_list, **reward_kwargs - ) - - # Apply missing EOS penalty if configured - if self.args.missing_eos_penalty is not None: - rewards[~contain_eos_token] -= self.args.missing_eos_penalty - - # Split rewards into chosen/rejected pairs - first_half, second_half = rewards.split(batch_size) - mask = first_half >= second_half - elif self.judge is not None: - # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not - # directly understandable by the judge and could alter its judgment. To avoid this and make the judge - # independent of the model's chat template, we use the raw conversation data, and apply our own chat - # template to it. - if is_conversational({"prompt": prompts[0]}): - environment = jinja2.Environment() - template = environment.from_string(SIMPLE_CHAT_TEMPLATE) - prompts = [template.render(messages=prompt) for prompt in prompts] - completions = [template.render(messages=completion) for completion in completions] - - ranks_of_first_completion = self.judge.judge( - prompts, list(zip(completions[:batch_size], completions[batch_size:])) - ) - - # convert ranks to a True/False mask: - # when rank == 0, it means the first completion is the best - # when rank == 1, it means the second completion is the best - mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device) - - batch_range = torch.arange(batch_size, device=device) - chosen_indices = batch_range + (~mask * batch_size) - rejected_indices = batch_range + (mask * batch_size) - - # Build tensor so that the first half is the chosen examples and the second half the rejected examples - cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected - cr_logprobs = logprobs[cr_indices] - cr_ref_logprobs = ref_logprobs[cr_indices] - - # mask out the padding tokens - padding_mask = ~completion_mask.bool() - cr_padding_mask = padding_mask[cr_indices] - - cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1) - cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1) - - # Split the chosen and rejected examples - chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size) - chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size) - pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum - ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum - - logits = pi_logratios - ref_logratios - - if self.args.loss_type == "sigmoid": - losses = -F.logsigmoid(self.beta * logits) - elif self.args.loss_type == "ipo": - losses = (logits - 1 / (2 * self.beta)) ** 2 - else: - raise NotImplementedError(f"invalid loss type {self.loss_type}") - - loss = losses.mean() - - # Log everything - if self.reward_funcs is not None: - # When using reward_funcs, we have rewards instead of scores - scores_margin = rewards[chosen_indices] - rewards[rejected_indices] - self.stats["objective/scores_margin"].append( - self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item() - ) - self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(rewards.mean()).mean().item()) - self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item()) - self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item()) - self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item()) - - kl = logprobs - ref_logprobs - mean_kl = kl.sum(1).mean() - self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) - non_score_reward = (-self.beta * kl).sum(1) - mean_non_score_reward = non_score_reward.mean() - self.stats["objective/non_score_reward"].append( - self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() - ) - if self.reward_funcs is not None: - # Calculate RLHF reward by combining rewards with non_score_reward - rlhf_reward = rewards + non_score_reward - self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item()) - - mean_entropy = -logprobs.sum(1).mean() - self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item()) - chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) - gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards) - self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item()) - rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) - gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards) - self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item()) - margin = gathered_chosen_rewards - gathered_rejected_rewards - self.stats["rewards/margins"].append(margin.mean().item()) - accuracy = margin > 0 - self.stats["rewards/accuracies"].append(accuracy.float().mean().item()) - self.stats["beta"].append(self.beta) - - if ( - self.args.torch_empty_cache_steps is not None - and self.state.global_step % self.args.torch_empty_cache_steps == 0 - ): - empty_cache() - - kwargs = {} - - # For LOMO optimizers you need to explicitly use the learning rate - if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: - kwargs["learning_rate"] = self._get_learning_rate() - - if self.args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training - - self.accelerator.backward(loss, **kwargs) - - return loss.detach() / self.args.gradient_accumulation_steps - - # Same as Trainer._maybe_log_save_evaluate but log our metrics - def _maybe_log_save_evaluate( - self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None - ): - if self.control.should_log and self.state.global_step > self._globalstep_last_logged: - logs: dict[str, float] = {} - - # all_gather + mean() to get average loss over all processes - tr_loss_scalar = self._nested_gather(tr_loss).mean().item() - - # reset tr_loss to zero - tr_loss -= tr_loss - - logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) - if grad_norm is not None: - logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm - if learning_rate is not None: - logs["learning_rate"] = learning_rate - else: - logs["learning_rate"] = self._get_learning_rate() - - # Add our metrics - for key, val in self.stats.items(): - logs[key] = sum(val) / len(val) - self.stats = {key: [] for key in self.stats} # reset stats - - self._total_loss_scalar += tr_loss_scalar - self._globalstep_last_logged = self.state.global_step - self.store_flos() - self.log(logs, start_time) - - metrics = None - if self.control.should_evaluate: - metrics = self._evaluate(trial, ignore_keys_for_eval) - is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) - - if self.args.save_strategy == "best": - self.control.should_save = is_new_best_metric - - if self.control.should_save: - self._save_checkpoint(model, trial) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer): - """ - - Initialize OnlineDPOTrainer. - - Args: - model (`Union[str, nn.Module, PreTrainedModel]`): - Model to be trained. Can be either: - - - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in - `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. - ref_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `None`): - The reference model to use for training. If None is specified, the reference model will be created from the - model. - judge ([`BasePairwiseJudge`]): - The judge to use for pairwise comparison of model completions. - reward_funcs (`Union[RewardFunc, list[RewardFunc]]`, *optional*): - Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward - functions with the prompts and completions and sum the rewards. Can be either: - - - A single reward function: Can be a string (path to model), a [`~transformers.PreTrainedModel`], or a - custom callable function. - - A list of reward functions: Must all be of compatible types. - - Note: Only one of `judge`, or `reward_funcs` should be provided. - args ([`OnlineDPOConfig`]): - The online DPO config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): - Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: - - - A single processing class: Used when `reward_funcs` contains only one reward function. - - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. - - If set to `None`, the tokenizer for each model-based reward function is automatically loaded using - [`~transformers.AutoTokenizer.from_pretrained`]. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration used to wrap the model. If `None`, the model is not wrapped. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - - reward_model: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead. - - - - """ - def __init__( - self, - model, - ref_model = None, - reward_funcs = None, - judge = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - reward_processing_classes = None, - peft_config = None, - compute_metrics = None, - callbacks = None, - preprocess_logits_for_metrics = None, - reward_model = None, - reward_processing_class = None, - **kwargs - ): - if args is None: args = UnslothOnlineDPOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('online_dpo_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - ref_model = ref_model, - reward_funcs = reward_funcs, - judge = judge, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - reward_processing_classes = reward_processing_classes, - peft_config = peft_config, - compute_metrics = compute_metrics, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - reward_model = reward_model, - reward_processing_class = reward_processing_class,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothPPOTrainer.py b/unsloth_compiled_cache/UnslothPPOTrainer.py deleted file mode 100644 index 5226980..0000000 --- a/unsloth_compiled_cache/UnslothPPOTrainer.py +++ /dev/null @@ -1,1658 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, BaseTrainer, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, Path, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, empty_cache, exact_div, first_true_indices, forward, gather_object, gc, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_rich_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, warnings, Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, OnlineTrainerState, Optional, PPOConfig, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, TrainerCallback, TrainerControl, Union, broadcast, create_reference_model, disable_dropout_in_model, exact_div, forward, get_peft_model, get_reporting_integration_callbacks, is_peft_available, math, nn, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, time, torch, warnings, Optional, PeftModel, is_peft_available, os, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothPPOConfig(PPOConfig): - """ - - Configuration class for the [`PPOTrainer`]. - - This class includes only the parameters that are specific to PPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default - values in this class may differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): - Name of this experiment. - reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): - Path to the reward model. - model_adapter_name (`str`, *optional*): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str`, *optional*): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - num_ppo_epochs (`int`, *optional*, defaults to `4`): - Number of epochs to train. - whiten_rewards (`bool`, *optional*, defaults to `False`): - Whether to whiten the rewards. - kl_coef (`float`, *optional*, defaults to `0.05`): - KL coefficient. - kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`): - Which estimator for KL-Divergence to use from [Approximating KL - Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased - estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly - better estimator". Cannot be set to "k2", as it is used for logging purposes. - cliprange (`float`, *optional*, defaults to `0.2`): - Clip range. - vf_coef (`float`, *optional*, defaults to `0.1`): - Value function coefficient. - cliprange_value (`float`, *optional*, defaults to `0.2`): - Clip range for the value function. - gamma (`float`, *optional*, defaults to `1.0`): - Discount factor. - lam (`float`, *optional*, defaults to `0.95`): - Lambda value for GAE. - ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): - This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, - improving generation speed. However, disabling this option allows training models that exceed the VRAM - capacity of a single GPU, albeit at the cost of slower generation. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - dataset_num_proc = None, - num_mini_batches = 1, - total_episodes = None, - local_rollout_forward_batch_size = 64, - num_sample_generations = 10, - response_length = 53, - stop_token = None, - stop_token_id = None, - temperature = 0.7, - missing_eos_penalty = None, - sft_model_path = 'EleutherAI/pythia-160m', - world_size = None, - num_total_batches = None, - micro_batch_size = None, - local_batch_size = None, - batch_size = None, - local_mini_batch_size = None, - mini_batch_size = None, - exp_name = 'ppo_config', - reward_model_path = 'EleutherAI/pythia-160m', - model_adapter_name = None, - ref_adapter_name = None, - num_ppo_epochs = 4, - whiten_rewards = False, - kl_coef = 0.05, - kl_estimator = 'k1', - cliprange = 0.2, - vf_coef = 0.1, - cliprange_value = 0.2, - gamma = 1.0, - lam = 0.95, - ds3_gather_for_generation = True, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if temperature <= 0: - raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') - elif temperature >= 10: - raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - dataset_num_proc = dataset_num_proc, - num_mini_batches = num_mini_batches, - total_episodes = total_episodes, - local_rollout_forward_batch_size = local_rollout_forward_batch_size, - num_sample_generations = num_sample_generations, - response_length = response_length, - stop_token = stop_token, - stop_token_id = stop_token_id, - temperature = temperature, - missing_eos_penalty = missing_eos_penalty, - sft_model_path = sft_model_path, - world_size = world_size, - num_total_batches = num_total_batches, - micro_batch_size = micro_batch_size, - local_batch_size = local_batch_size, - batch_size = batch_size, - local_mini_batch_size = local_mini_batch_size, - mini_batch_size = mini_batch_size, - exp_name = exp_name, - reward_model_path = reward_model_path, - model_adapter_name = model_adapter_name, - ref_adapter_name = ref_adapter_name, - num_ppo_epochs = num_ppo_epochs, - whiten_rewards = whiten_rewards, - kl_coef = kl_coef, - kl_estimator = kl_estimator, - cliprange = cliprange, - vf_coef = vf_coef, - cliprange_value = cliprange_value, - gamma = gamma, - lam = lam, - ds3_gather_for_generation = ds3_gather_for_generation,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - - -pass - -class _UnslothPPOTrainer(BaseTrainer): - """""" - - _tag_names = ["trl", "ppo"] - _name = "PPO" - _paper = { - "title": "Fine-Tuning Language Models from Human Preferences", - "id": "1909.08593", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{mziegler2019fine-tuning, - title = {{Fine-Tuning Language Models from Human Preferences}}, - author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, - year = 2019, - eprint = {arXiv:1909.08593} - }"""), - } - - def __init__( - self, - args: PPOConfig, - processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], - model: nn.Module, - ref_model: Optional[nn.Module], - reward_model: nn.Module, - train_dataset: Dataset, - value_model: nn.Module, - data_collator: Optional[DataCollatorWithPadding] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - # less commonly used - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - callbacks: Optional[list[TrainerCallback]] = None, - peft_config: Optional["PeftConfig"] = None, - ) -> None: - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if ref_model is model: - raise ValueError( - "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " - "same as `model`, you must make a copy of it, or `None` if you use peft." - ) - - self.args = args - self.processing_class = processing_class - self.policy_model = model - - # Define the collator if not provided - if data_collator is None: - data_collator = DataCollatorWithPadding(self.processing_class) - - # Handle stop token settings: update policy model's generation_config to use provided stop token - if args.stop_token and args.stop_token_id: - raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") - elif args.stop_token: - if args.stop_token == "eos": - self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id - else: - raise ValueError( - f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." - ) - else: - self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int - - # Check that the kl estimator is valid - if self.args.kl_estimator not in {"k1", "k3"}: - raise ValueError( - "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, " - "appears to be a strictly better estimator). See " - "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details." - ) - - # peft support - if not is_peft_available() and peft_config is not None: - raise ImportError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_confg, we merge and unload it first - if isinstance(self.policy_model, PeftModel): - self.policy_model = self.policy_model.merge_and_unload() - - # get peft model with the given config - self.policy_model = get_peft_model(self.policy_model, peft_config) - if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(self.policy_model) - - self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel) - self.model_adapter_name = args.model_adapter_name - self.ref_adapter_name = args.ref_adapter_name - - if ref_model: - self.ref_model = ref_model - elif self.is_peft_model: - self.ref_model = None - else: - self.ref_model = create_reference_model(self.policy_model) - - self.reward_model = reward_model - self.train_dataset = train_dataset - self.train_dataset_len = len(train_dataset) - self.value_model = value_model - self.data_collator = data_collator - self.eval_dataset = eval_dataset - self.optimizer, self.lr_scheduler = optimizers - self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 - - ######### - # calculate various batch sizes - ######### - if args.total_episodes is None: # allow the users to define episodes in terms of epochs. - args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - self.accelerator = accelerator - args.world_size = accelerator.num_processes - args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps - args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) - args.batch_size = int(args.local_batch_size * args.world_size) - args.mini_batch_size = exact_div( - args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" - ) - args.local_mini_batch_size = exact_div( - args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" - ) - if args.whiten_rewards: - assert args.local_mini_batch_size >= 8, ( - f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - ) - # `per_rank_rollout_batch_size` is our `args.local_batch_size` - # `per_rank_minibatch_size` is our `args.local_mini_batch_size` - args.num_total_batches = math.ceil( - args.total_episodes / args.batch_size - ) # we may train for more than `total_episodes` - time_tensor = torch.tensor(int(time.time()), device=accelerator.device) - time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes - args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" - self.local_seed = args.seed + accelerator.process_index * 100003 # Prime - if args.num_sample_generations > 0: - self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) - self.local_dataloader_batch_size = args.local_batch_size - - ######### - # setup model, optimizer, and others - ######### - for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: - if module is not None: - disable_dropout_in_model(module) - self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) - self.model.config = self.policy_model.config # needed for pushing to hub - self.create_optimizer_and_scheduler( - num_training_steps=args.num_total_batches - ) # note that we are calling `self.lr_scheduler.step[]` manually only at the batch level - - ######### - # trainer specifics - ######### - default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = CallbackHandler( - self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler - ) - self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) - self.control = TrainerControl() - self.state = OnlineTrainerState( - is_local_process_zero=self.is_local_process_zero(), - is_world_process_zero=self.is_world_process_zero(), - stateful_callbacks=[ - cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) - ], - ) - self.current_flos = 0 - self.hp_search_backend = None - self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None - self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None - # Create distant repo and output directory if needed - self.hub_model_id = None - if self.args.push_to_hub: - self.init_hf_repo() - if self.args.should_save: - os.makedirs(self.args.output_dir, exist_ok=True) - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - ######### - # setup dataloader - ######### - self.dataloader = DataLoader( - self.train_dataset, - batch_size=self.local_dataloader_batch_size, - shuffle=True, - collate_fn=self.data_collator, - drop_last=True, # needed; otherwise the last batch will be of ragged shape - ) - # sync random states for DataLoader[shuffle=True] before `accelerator.prepare` - # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c - torch.manual_seed(args.seed) - self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) - torch.manual_seed(self.local_seed) # reset the local seed again - - self.eval_dataloader = DataLoader( - self.eval_dataset, - batch_size=args.per_device_eval_batch_size, - collate_fn=self.data_collator, - drop_last=True, - ) # no need to shuffle eval dataset - self.eval_dataloader = accelerator.prepare(self.eval_dataloader) - - if self.is_deepspeed_enabled: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) - - if self.ref_model is None: - if not self.is_peft_model: - raise ValueError("No reference model and model is not a Peft model.") - else: - self.ref_model = prepare_deepspeed( - self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) - else: - if self.ref_model is None: - if not self.is_peft_model: - raise ValueError("No reference model and model is not a Peft model.") - else: - self.ref_model = self.ref_model.to(self.accelerator.device) - self.reward_model = self.reward_model.to(self.accelerator.device) - - def get_train_dataloader(self) -> DataLoader: - return self.dataloader - - def get_eval_dataloader(self) -> DataLoader: - return self.eval_dataloader - - @contextmanager - def null_ref_context(self): - """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with ( - self.accelerator.unwrap_model(self.model.policy).disable_adapter() - if self.is_peft_model and not self.ref_adapter_name - else nullcontext() - ): - if self.ref_adapter_name: - self.model.policy.set_adapter(self.ref_adapter_name) - yield - if self.ref_adapter_name: - self.model.policy.set_adapter(self.model_adapter_name or "default") - - def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): - backup_model = self.model - self.model = self.model.policy # save only the policy - - if self.is_deepspeed_enabled: - backup_deepspeed = self.deepspeed - self.deepspeed = self.model - - super().save_model(output_dir, _internal_call) - - self.model = backup_model - - if self.is_deepspeed_enabled: - self.deepspeed = backup_deepspeed - - def train(self): - args = self.args - accelerator = self.accelerator - optimizer = self.optimizer - model = self.model - ref_policy = self.ref_model - reward_model = self.reward_model - processing_class = self.processing_class - dataloader = self.dataloader - device = accelerator.device - - def repeat_generator(): - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - generation_config = GenerationConfig( - max_new_tokens=args.response_length, - temperature=(args.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - accelerator.print("===training policy===") - start_time = time.time() - stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) - approxkl_stats = torch.zeros(stats_shape, device=device) - pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - pg_loss_stats = torch.zeros(stats_shape, device=device) - vf_loss_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropy_stats = torch.zeros(stats_shape, device=device) - ratio_stats = torch.zeros(stats_shape, device=device) - model.train() - - # trainer state initialization - self.state.global_step = 0 - self.state.episode = 0 - self.state.max_steps = args.num_total_batches - self.state.num_train_epochs = args.total_episodes / self.train_dataset_len - # Compute absolute values for logging, eval, and save if given as ratio - if args.logging_steps is not None: - if args.logging_steps < 1: - self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) - else: - self.state.logging_steps = args.logging_steps - if args.eval_steps is not None: - if args.eval_steps < 1: - self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) - else: - self.state.eval_steps = args.eval_steps - if args.save_steps is not None: - if args.save_steps < 1: - self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) - else: - self.state.save_steps = args.save_steps - self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - - # backward compatibility - if self.is_deepspeed_enabled: - self.deepspeed = self.model - self.model_wrapped = self.model - - for update in range(1, args.num_total_batches + 1): - self.state.episode += 1 * args.batch_size - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - context_length = queries.shape[1] - responses = [] - postprocessed_responses = [] - logprobs = [] - ref_logprobs = [] - scores = [] - sequence_lengths = [] - values = [] - with unwrap_model_for_generation( - self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model: - query_responses, logitss = batch_generation( - unwrapped_model.policy, - queries, - args.local_rollout_forward_batch_size, - processing_class.pad_token_id, - generation_config, - ) - - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response = query_responses[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] - logits = logitss[i : i + args.local_rollout_forward_batch_size] - logprob = selective_log_softmax(logits, response) - del logits - empty_cache() - - if ref_policy is None: - with self.null_ref_context(): - ref_output = forward(model.policy, query_response, processing_class.pad_token_id) - else: - ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_logprob = selective_log_softmax(ref_logits, response) - del ref_output, ref_logits - empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - self.stop_token_id, processing_class.pad_token_id, response - ) - - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 - unwrapped_value_model = accelerator.unwrap_model(model).value_model - full_value, _, _ = get_reward( - unwrapped_value_model, query_response, processing_class.pad_token_id, context_length - ) - value = full_value[:, context_length - 1 : -1].squeeze(-1) - _, score, _ = get_reward( - reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) - - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - values.append(value) - responses = torch.cat(responses, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) - logprobs = torch.cat(logprobs, 0) - ref_logprobs = torch.cat(ref_logprobs, 0) - sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) - values = torch.cat(values, 0) - del (logprob, ref_logprob, full_value, value, score, unwrapped_model) - empty_cache() - gc.collect() - - # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id - # Completions not passing that filter will receive a lower score. - contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) - if self.args.missing_eos_penalty is not None: - scores[~contain_eos_token] -= self.args.missing_eos_penalty - # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - sequence_lengths_p1 = sequence_lengths + 1 - padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) - values = torch.masked_fill(values, padding_mask_p1, 0) - - # 4. compute rewards - # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators - logr = ref_logprobs - logprobs - kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3 - non_score_reward = -args.kl_coef * kl - rewards = non_score_reward.clone() - actual_start = torch.arange(rewards.size(0), device=rewards.device) - actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) - rewards[[actual_start, actual_end]] += scores - - # 5. whiten rewards - if args.whiten_rewards: - rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) - rewards = torch.masked_fill(rewards, padding_mask_p1, 0) - - # 6. compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = responses.shape[1] - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.gamma * args.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = masked_whiten(advantages, ~padding_mask) - advantages = torch.masked_fill(advantages, padding_mask, 0) - empty_cache() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.random.permutation(args.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - with accelerator.accumulate(model): - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - mb_return = returns[micro_batch_inds] - mb_values = values[micro_batch_inds] - - output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.temperature + 1e-7 - new_logprobs = selective_log_softmax(logits, mb_responses) - new_logprobs = torch.masked_fill( - new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB - ) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) - vpredclipped = torch.clamp( - vpred, - mb_values - args.cliprange_value, - mb_values + args.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss_max = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) - vf_clipfrac = masked_mean( - (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] - ) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) - loss = pg_loss + args.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - with torch.no_grad(): - pg_clipfrac = masked_mean( - (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] - ) - prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype = torch.float32).to(logits.dtype) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - pg_clipfrac - ) - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - vf_clipfrac - ) - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - # del everything and empty cache - # fmt: off - del ( - output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, - vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, - mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, - ) - # fmt: on - empty_cache() - with torch.no_grad(): - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - rlhf_reward = mean_non_score_reward + scores.mean() - eps = int(self.state.episode / (time.time() - start_time)) - metrics = {} - metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = ( - self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() - ) - metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() - metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() - metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() - metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() - metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() - metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item() - metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() - metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() - metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() - metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() - metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() - metrics["lr"] = self.lr_scheduler.get_last_lr()[0] - metrics["episode"] = self.state.episode - self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log - self.state.global_step += 1 - self.log(metrics) - - self.lr_scheduler.step() - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward - empty_cache() - gc.collect() - - if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: - self.generate_completions(sampling=True) - empty_cache() - del ( - query_responses, - responses, - postprocessed_responses, - logprobs, - ref_logprobs, - values, - sequence_lengths, - contain_eos_token, - sequence_lengths_p1, - response_idxs, - padding_mask, - padding_mask_p1, - rewards, - actual_start, - actual_end, - advantages, - returns, - ) - empty_cache() - - # HF trainer specifics - self.control = self.callback_handler.on_train_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - - def generate_completions(self, sampling: bool = False): - args = self.args - processing_class = self.processing_class - generation_config = GenerationConfig( - max_new_tokens=self.args.response_length, - temperature=(0.01 + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - table = defaultdict(list) - with unwrap_model_for_generation( - self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model: - for batch in self.eval_dataloader: - query = batch["input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - query_response, _ = batch_generation( - unwrapped_model.policy, - query, - query.shape[0], - processing_class.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - postprocessed_response = response - if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - self.stop_token_id, processing_class.pad_token_id, response - ) - table["query"].extend( - gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) - ) - table["model response"].extend( - gather_object(processing_class.batch_decode(postprocessed_response)) - ) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) - - if sampling: - break - df = pd.DataFrame(table) - - if self.accelerator.is_main_process: - if is_rich_available(): - print_rich_table(df.iloc[0 : 0 + 5]) - if "wandb" in args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) - - if "comet_ml" in args.report_to: - log_table_to_comet_experiment( - name="completions.csv", - table=df, - ) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothPPOTrainer(_UnslothPPOTrainer): - """ - Trainer for Proximal Policy Optimization (PPO). - - For details on PPO, see the paper: [Proximal Policy Optimization - Algorithms](https://huggingface.co/papers/1707.06347). - - Args: - args ([`PPOConfig`]): - Training arguments. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]): - Class to process the data. - model (`torch.nn.Module`): - Model to be trained. This is the policy model. - ref_model (`torch.nn.Module`, *optional*): - Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created. - reward_model (`torch.nn.Module`): - Reward model used to compute the rewards. - train_dataset ([`~datasets.Dataset`]): - Dataset for training. - value_model (`torch.nn.Module`): - Value model used to predict the value of a state. - data_collator ([`~transformers.DataCollatorWithPadding`], *optional*): - Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created - using the `processing_class`. - eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*): - Dataset for evaluation. - optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): - Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the - optimizer and the learning rate scheduler are created using the - [`~transformers.Trainer.create_optimizer_and_scheduler`] method. - callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): - Callbacks to use during training. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model` - will be wrapped with the specified PEFT adapter. - - """ - def __init__( - self, - args, - processing_class, - model, - ref_model, - reward_model, - train_dataset, - value_model, - data_collator = None, - eval_dataset = None, - callbacks = None, - peft_config = None, - **kwargs - ): - if args is None: args = UnslothPPOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('ppo_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - args = args, - processing_class = processing_class, - model = model, - ref_model = ref_model, - reward_model = reward_model, - train_dataset = train_dataset, - value_model = value_model, - data_collator = data_collator, - eval_dataset = eval_dataset, - callbacks = callbacks, - peft_config = peft_config,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass diff --git a/unsloth_compiled_cache/UnslothPRMTrainer.py b/unsloth_compiled_cache/UnslothPRMTrainer.py deleted file mode 100644 index 8232c86..0000000 --- a/unsloth_compiled_cache/UnslothPRMTrainer.py +++ /dev/null @@ -1,1133 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.prm_trainer import (BaseImageProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, Path, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, nn, os, textwrap, torch, warnings, BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PartialState, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, Union, compute_accuracy, disable_dropout_in_model, features, nn, os, torch, warnings, Optional, PreTrainedModel, os, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothPRMConfig(PRMConfig): - """ - - Configuration class for the [`PRMTrainer`]. - - This class includes only the parameters that are specific to PRM training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the sequences (prompt + completion) used for truncation. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt used for truncation. - max_completion_length (`int`, *optional*): - Maximum length of the completion used for truncation. The completion is the concatenation of the steps. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. - step_separator (`str`, *optional*, defaults to `"\n"`): - Separator used to separate each step of the reasoning process. - train_on_last_step_only (`bool`, *optional*, defaults to `False`): - Whether to train only on the last step. - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - max_length = 1024, - max_prompt_length = 512, - max_completion_length = None, - disable_dropout = True, - step_separator = '\ -', - train_on_last_step_only = False, - dataset_num_proc = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - max_length = max_length, - max_prompt_length = max_prompt_length, - max_completion_length = max_completion_length, - disable_dropout = disable_dropout, - step_separator = step_separator, - train_on_last_step_only = train_on_last_step_only, - dataset_num_proc = dataset_num_proc,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothPRMTrainer(BaseTrainer): - """""" - - _tag_names = ["trl", "prm"] - _name = "PRM" - _paper = { - "title": "Solving math word problems with process-and outcome-based feedback", - "id": "2211.14275", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{uesato2022solving, - title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}}, - author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina}, - year = 2022, - journal = {arXiv preprint arXiv:2211.14275} - }"""), - } - - def __init__( - self, - model: Optional[Union[PreTrainedModel, nn.Module]] = None, - args: Optional[PRMConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( - None, - None, - ), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional[dict] = None, - ): - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if False: - pass - - # Disable dropout in the model - if args.disable_dropout: - disable_dropout_in_model(model) - - if compute_metrics is None: - compute_metrics = compute_accuracy - - if data_collator is None: - if processing_class is None: - raise ValueError( - "A processing_class must be specified when using the default DataCollatorForTokenClassification" - ) - data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length) - - if "input_ids" not in train_dataset.column_names: - with PartialState().main_process_first(): - fn_kwargs = { - "tokenizer": processing_class, - "step_separator": args.step_separator, - "max_length": args.max_length, - "max_prompt_length": args.max_prompt_length, - "max_completion_length": args.max_completion_length, - "train_on_last_step_only": args.train_on_last_step_only, - } - train_fn_kwargs = {**fn_kwargs, "is_eval": False} - train_dataset = train_dataset.map( - self.tokenize_row, - fn_kwargs=train_fn_kwargs, - num_proc=args.dataset_num_proc, - remove_columns=train_dataset.features, - desc="Tokenizing train dataset", - features=features.Features( # needed to avoid map to cast labels to bool - { - "labels": features.Sequence(features.Value("int64")), - "input_ids": features.Sequence(features.Value("int64")), - } - ), - ) - - eval_fn_kwargs = {**fn_kwargs, "is_eval": True} - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - self.tokenize_row, - fn_kwargs=eval_fn_kwargs, - num_proc=args.dataset_num_proc, - remove_columns=eval_dataset.features, - desc="Tokenizing eval dataset", - features=features.Features( # needed to avoid map to cast labels to bool - { - "labels": features.Sequence(features.Value("int64")), - "input_ids": features.Sequence(features.Value("int64")), - } - ), - ) - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - model_init=model_init, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - @staticmethod - def tokenize_row( - features, - tokenizer, - step_separator, - max_length, - max_prompt_length, - max_completion_length, - train_on_last_step_only, - is_eval, - ): - r""" - Tokenize a row of the dataset. - - Args: - features (`dict[str, str]`): - Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`. - tokenizer ([`~transformers.PreTrainedTokenizerBase`]): - Tokenizer used to process the data. - step_separator (`str`): - Separator between steps in the completion. - max_length (`int` or `None`): - Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated. - max_prompt_length (`int` or `None`): - Maximum length of the prompt. If `None`, the prompt is not truncated. - max_completion_length (`int` or `None`): - Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. - train_on_last_step_only (`bool`): - Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last - token of the completion. - is_eval (`bool`): - Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if - `train_on_last_step_only` is set to `True`. - - Returns: - `dict[str, list[int]]`: - Tokenized sequences with the keys `"input_ids"`, and `"labels". - - Example: - ```python - >>> from transformers import AutoTokenizer - - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") - >>> features = { - ... "prompt": "Which number is larger, 9.8 or 9.11?", - ... "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], - ... "labels": [True, False], - ... } - >>> PRMTrainer.tokenize_row( - ... features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False - ... ) - {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198], - 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]} - ``` - """ - # Tokenize the prompt and completions - prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] - completions_ids = [ - tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"] - ] - if train_on_last_step_only and not is_eval: - labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])] - else: - labels = [int(label) for label in features["labels"]] - - # Get the ID of the separator token and add it to the completions - separator_ids = tokenizer.encode(step_separator, add_special_tokens=False) - completions_ids = [completion + separator_ids for completion in completions_ids] - - # Create the label - labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)] - - # Join the completions and labels steps - completion_ids = list(chain(*completions_ids)) - labels = list(chain(*labels)) - - if tokenizer.bos_token_id is not None: - prompt_ids = [tokenizer.bos_token_id] + prompt_ids - - # Truncate prompt and completion sequences - if max_prompt_length is not None: - prompt_ids = prompt_ids[-max_prompt_length:] - if max_completion_length is not None: - completion_ids = completion_ids[:max_completion_length] - labels = labels[:max_completion_length] - - input_ids = prompt_ids + completion_ids - labels = [-100] * len(prompt_ids) + labels - - if max_length is not None: - input_ids = input_ids[:max_length] - labels = labels[:max_length] - - return {"input_ids": input_ids, "labels": labels} - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothPRMTrainer(_UnslothPRMTrainer): - """ - - Initialize PRMTrainer. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an `AutoModelForTokenClassification`. - args ([`PRMConfig`]): - The arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`~transformers.DataCollatorForTokenClassification`]) will be used which will pad the sequences to the - maximum length of the sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. - compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`): - The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) - will be used. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. - - """ - def __init__( - self, - model = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - model_init = None, - compute_metrics = None, - callbacks = None, - preprocess_logits_for_metrics = None, - peft_config = None, - **kwargs - ): - if args is None: args = UnslothPRMConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('prm_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - model_init = model_init, - compute_metrics = compute_metrics, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass diff --git a/unsloth_compiled_cache/UnslothRLOOTrainer.py b/unsloth_compiled_cache/UnslothRLOOTrainer.py deleted file mode 100644 index 740659a..0000000 --- a/unsloth_compiled_cache/UnslothRLOOTrainer.py +++ /dev/null @@ -1,2828 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.rloo_trainer import (Any, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, BaseTrainer, DataLoader, Dataset, FSDP, GenerationConfig, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RLOOConfig, RLOOTrainer, RepeatSampler, RewardFunc, Sampler, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, apply_chat_template, broadcast_object_list, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, entropy_from_logits, gather, gather_object, identity, inspect, is_conversational, is_datasets_available, is_flash_attn_2_available, is_peft_model, is_rich_available, is_vllm_available, logger, logging, maybe_apply_chat_template, nanmax, nanmin, nanstd, nn, nullcontext, os, pad, partial, prepare_deepspeed, prepare_fsdp, prepare_multimodal_messages, print_prompt_completions_sample, profiling_context, profiling_decorator, seed_worker, selective_log_softmax, set_seed, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, textwrap, torch, transformers, unsplit_pixel_values_by_grid, unwrap_model_for_generation, warnings, AutoConfig, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer, Dataset, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RLOOConfig, RLOOTrainer, RewardFunc, SyncRefModelCallback, TrainerCallback, Union, VLLMClient, datasets, defaultdict, deque, disable_dropout_in_model, ensure_master_addr_port, identity, inspect, is_peft_model, is_vllm_available, logger, nn, os, pad, prepare_deepspeed, prepare_fsdp, set_seed, torch, transformers, warnings, FSDP, Optional, apply_chat_template, broadcast_object_list, gather, gather_object, is_flash_attn_2_available, maybe_apply_chat_template, nullcontext, os, pad, prepare_multimodal_messages, profiling_context, torch, transformers, unwrap_model_for_generation, FSDP, gather, is_peft_model, nn, nullcontext, os, profiling_decorator, Any, Union, profiling_decorator, shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, torch, unsplit_pixel_values_by_grid, Optional, PreTrainedModel, logger, os, torch, FSDP, nn, os, FSDP, nn, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -def vLLMSamplingParams(**kwargs): - from vllm import SamplingParams - - sampling_params = SamplingParams(**kwargs) - sampling_params._set_kwargs = kwargs - return sampling_params -@dataclass -class UnslothRLOOConfig(RLOOConfig): - """ - - Configuration class for the [`RLOOTrainer`]. - - This class includes only the parameters that are specific to RLOO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - > Parameters that control the model and reference model - - model_init_kwargs (`str`, `dict[str, Any]`, *optional*): - Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` - argument of the [`RLOOTrainer`] is provided as a string. - disable_dropout (`bool`, *optional*, defaults to `False`): - Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents - the model from generating different logprobs for the same input. - - > Parameters that control the data preprocessing - - remove_unused_columns (`bool`, *optional*, defaults to `False`): - Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that - requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. - num_generations (`int` or `None`, *optional*, defaults to `2`): - Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size - * gradient_accumulation_steps) must be evenly divisible by this value. - max_completion_length (`int` or `None`, *optional*, defaults to `256`): - Maximum length of the generated completion. - ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): - This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, - improving generation speed. However, disabling this option allows training models that exceed the VRAM - capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible - with vLLM generation. - shuffle_dataset (`bool`, *optional*, defaults to `True`): - Whether to shuffle the training dataset. - - > Parameters that control generation - - generation_batch_size: (`int`, *optional*): - Batch size to use for generation. If `None`, it defaults to the effective training batch size: - `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one - generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`. - steps_per_generation: (`int`, *optional*): - Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive - with `generation_batch_size`. - temperature (`float`, defaults to `1.0`): - Temperature for sampling. The higher the temperature, the more random the completions. - top_p (`float`, *optional*, defaults to `1.0`): - Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to - `1.0` to consider all tokens. - top_k (`int`, *optional*): - Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is - disabled and all tokens are considered. - min_p (`float`, *optional*): - Minimum token probability, which will be scaled by the probability of the most likely token. It must be a - value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. - repetition_penalty (`float`, *optional*, defaults to `1.0`): - Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. - Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat - tokens. - use_transformers_paged (`bool`, *optional*, defaults to `False`): - Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers` - paged implementation will be used for generation instead of the default padded implementation. This - parameter is only effective when `use_vllm` is set to `False`. - cache_implementation (`str`, *optional*): - Implementation of the cache method for faster generation when `use_vllm` is set to `False`. - generation_kwargs (`dict[str, Any]`, *optional*): - Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or - `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the - generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict - with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. - - > Parameters that control generation acceleration powered by vLLM - - use_vllm (`bool`, *optional*, defaults to `False`): - Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation - instead of the default model.generate(). Requires `vllm` to be installed. - vllm_mode (`str`, *optional*, defaults to `"server"`): - Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or - `"colocate"`. - - - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM - server is running (start with `trl vllm-serve`). - - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a - separate server but may cause resource contention with training. - vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): - Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use - the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model - implementation. - vllm_guided_decoding_regex (`str`, *optional*): - Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. - - > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) - - vllm_server_base_url (`str`, *optional*): - Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and - `vllm_server_port` are ignored. - vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): - Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. - vllm_server_port (`int`, *optional*, defaults to `8000`): - Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. - vllm_server_timeout (`float`, *optional*, defaults to `240.0`): - Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the - timeout, a `ConnectionError` is raised. - - > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) - - vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): - Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to - `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when - launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. - vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): - Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to - `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when - launching the vLLM server via the `--vllm_tensor_parallel_size` flag. - vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): - Whether to enable sleep mode for vLLM. If `True`, vLLM will sleep during the optimization step and woken - for weight sync and generation. - - > Parameters that control the training - - beta (`float`, *optional*, defaults to `0.05`): - KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training - speed. - num_iterations (`int`, *optional*, defaults to `1`): - Number of iterations per batch (denoted as μ in the algorithm). - epsilon (`float`, *optional*, defaults to `0.2`): - Epsilon value for clipping. - epsilon_high (`float`, *optional*): - Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound - specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. - reward_weights (`list[float]`, *optional*): - Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are - weighted equally with weight `1.0`. - normalize_advantages (`bool`, *optional*, defaults to `False`): - Whether to normalize advantages. Normalization is done per generation batch to have mean `0.0` and standard - deviation of `1.0`. - reward_clip_range (`tuple[float, float]`, *optional*): - Clip range for rewards as (min, max). If `None`, no clipping is applied. - mask_truncated_completions (`bool`, *optional*, defaults to `False`): - When enabled, truncated completions are excluded from the loss calculation, preventing them from being - incorrectly penalized and introducing noise during training. According to the - [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability. - sync_ref_model (`bool`, *optional*, defaults to `False`): - Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using - the `ref_model_mixup_alpha` parameter. This synchronization originates from the - [TR-DPO](https://huggingface.co/papers/2404.09656) paper. - ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): - α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix - between the current policy and the previous reference policy during updates. The reference policy is - updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you - must set `sync_ref_model=True`. - ref_model_sync_steps (`int`, *optional*, defaults to `512`): - τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how - frequently the current policy is synchronized with the reference policy. To use this parameter, you must - set `sync_ref_model=True`. - - > Parameters that control the logging - - log_completions (`bool`, *optional*, defaults to `False`): - Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, - it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. - num_completions_to_print (`int`, *optional*): - Number of completions to print with `rich`. If `None`, all completions are logged. - wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): - Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all prompts - are logged. - - > Deprecated parameters - - rloo_k: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `num_generations` instead. - - - - cliprange: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `epsilon` instead. - - - - kl_coef: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `beta` instead. - - - - exp_name: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `run_name` instead. - - - - normalize_reward: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `normalize_advantages` instead. - - - - num_ppo_epochs: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `num_iterations` instead. - - - - num_mini_batches: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `steps_per_generation` instead. - - - - total_episodes: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `max_steps` instead. - - - - response_length: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `max_completion_length` instead. - - - - token_level_kl: - - - - This parameter is deprecated and will be removed in version 0.25.0. KL is now computed only at the sequence - level. - - - - dataset_num_proc: - - - - This parameter is deprecated and will be removed in version 0.25.0. This parameter was unused, you can - safely remove it from your scripts. - - - - local_rollout_forward_batch_size: - - - - This parameter is deprecated and will be removed in version 0.25.0. Now it is automatically set to - `per_device_train_batch_size` (or `per_device_eval_batch_size` during evaluation). - - - - num_sample_generations: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `logging_steps` to control - generation logging frequency. - - - - stop_token: - - - - This parameter is deprecated and will be removed in version 0.25.0. - - - - stop_token_id: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `processing_class.eos_token_id` - instead. - - - - missing_eos_penalty: - - - - This parameter is deprecated and will be removed in version 0.25.0. Replicate with a custom reward function - checking if `eos_token_id` is in `completion_ids`. - - - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = False, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - model_init_kwargs = None, - disable_dropout = False, - max_prompt_length = 512, - num_generations = 8, - max_completion_length = 256, - ds3_gather_for_generation = True, - shuffle_dataset = True, - generation_batch_size = None, - steps_per_generation = None, - temperature = 1.0, - top_p = 1.0, - top_k = None, - min_p = None, - generation_kwargs = {}, - repetition_penalty = 1.0, - use_transformers_paged = False, - cache_implementation = None, - use_vllm = False, - vllm_mode = 'colocate', - vllm_model_impl = 'vllm', - vllm_enable_sleep_mode = False, - vllm_guided_decoding_regex = None, - vllm_server_base_url = None, - vllm_server_host = '0.0.0.0', - vllm_server_port = 8000, - vllm_server_timeout = 240.0, - vllm_gpu_memory_utilization = 0.3, - vllm_tensor_parallel_size = 1, - beta = 0.05, - num_iterations = 1, - epsilon = 0.2, - epsilon_high = None, - reward_weights = None, - normalize_advantages = False, - reward_clip_range = None, - mask_truncated_completions = False, - sync_ref_model = False, - ref_model_mixup_alpha = 0.6, - ref_model_sync_steps = 512, - log_completions = False, - num_completions_to_print = None, - wandb_log_unique_prompts = False, - rloo_k = None, - cliprange = None, - kl_coef = None, - exp_name = None, - normalize_reward = None, - num_ppo_epochs = None, - num_mini_batches = None, - total_episodes = None, - response_length = None, - token_level_kl = None, - dataset_num_proc = None, - local_rollout_forward_batch_size = None, - num_sample_generations = None, - stop_token = None, - stop_token_id = None, - missing_eos_penalty = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if steps_per_generation is None and generation_batch_size is None: - ga = gradient_accumulation_steps - world_size = int(os.environ.get('WORLD_SIZE', '1')) - if (ga * world_size * per_device_train_batch_size) % num_generations != 0: - print('Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations)) - per_device_train_batch_size = num_generations - - if temperature <= 0: - raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') - elif temperature >= 10: - raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - model_init_kwargs = model_init_kwargs, - disable_dropout = disable_dropout, - max_prompt_length = max_prompt_length, - num_generations = num_generations, - max_completion_length = max_completion_length, - ds3_gather_for_generation = ds3_gather_for_generation, - shuffle_dataset = shuffle_dataset, - generation_batch_size = generation_batch_size, - steps_per_generation = steps_per_generation, - temperature = temperature, - top_p = top_p, - top_k = top_k, - min_p = min_p, - generation_kwargs = generation_kwargs, - repetition_penalty = repetition_penalty, - use_transformers_paged = use_transformers_paged, - cache_implementation = cache_implementation, - use_vllm = use_vllm, - vllm_mode = vllm_mode, - vllm_model_impl = vllm_model_impl, - vllm_enable_sleep_mode = vllm_enable_sleep_mode, - vllm_guided_decoding_regex = vllm_guided_decoding_regex, - vllm_server_base_url = vllm_server_base_url, - vllm_server_host = vllm_server_host, - vllm_server_port = vllm_server_port, - vllm_server_timeout = vllm_server_timeout, - vllm_gpu_memory_utilization = vllm_gpu_memory_utilization, - vllm_tensor_parallel_size = vllm_tensor_parallel_size, - beta = beta, - num_iterations = num_iterations, - epsilon = epsilon, - epsilon_high = epsilon_high, - reward_weights = reward_weights, - normalize_advantages = normalize_advantages, - reward_clip_range = reward_clip_range, - mask_truncated_completions = mask_truncated_completions, - sync_ref_model = sync_ref_model, - ref_model_mixup_alpha = ref_model_mixup_alpha, - ref_model_sync_steps = ref_model_sync_steps, - log_completions = log_completions, - num_completions_to_print = num_completions_to_print, - wandb_log_unique_prompts = wandb_log_unique_prompts, - rloo_k = rloo_k, - cliprange = cliprange, - kl_coef = kl_coef, - exp_name = exp_name, - normalize_reward = normalize_reward, - num_ppo_epochs = num_ppo_epochs, - num_mini_batches = num_mini_batches, - total_episodes = total_episodes, - response_length = response_length, - token_level_kl = token_level_kl, - dataset_num_proc = dataset_num_proc, - local_rollout_forward_batch_size = local_rollout_forward_batch_size, - num_sample_generations = num_sample_generations, - stop_token = stop_token, - stop_token_id = stop_token_id, - missing_eos_penalty = missing_eos_penalty,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - - -pass - -class _UnslothRLOOTrainer(BaseTrainer): - """""" - - _tag_names = ["trl", "rloo"] - _name = "RLOO" - _paper = { - "title": "Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs", - "id": "2402.14740", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @inproceedings{ahmadian2024back, - title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}}, - author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker}, - year = 2024, - booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024}, - pages = {12248--12267}, - publisher = {Association for Computational Linguistics}, - editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar}, - }"""), - } - - def __init__( - self, - # Note for dev: we can remove the default None when we remove the deprecated model parameter in version 0.25.0 - model: Union[str, PreTrainedModel] = None, - reward_funcs: Union[RewardFunc, list[RewardFunc]] = None, - args: Optional[RLOOConfig] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, - processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, - reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), - peft_config: Optional["PeftConfig"] = None, - # Deprecated parameters - config=None, - reward_model=None, - policy=None, - ref_policy=None, - data_collator=None, - ): - - if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'): - if (getattr(args, 'use_vllm', False) == False): - args.use_vllm = True - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - # Handle deprecated parameters - if config is not None: - warnings.warn( - "Parameter 'config' is deprecated and will be removed in version 0.25.0. Please use 'args' instead. " - "We are setting args=config" - ) - if args is None: - args = config - else: - raise ValueError("Cannot specify both 'config' (deprecated) and 'args'. Please use 'args' only.") - - if reward_model is not None: - warnings.warn( - "Parameter 'reward_model' is deprecated and will be removed in version 0.25.0. Please use " - "'reward_funcs' instead. We are setting reward_funcs=reward_model" - ) - if reward_funcs is None: - reward_funcs = reward_model - else: - raise ValueError( - "Cannot specify both 'reward_model' (deprecated) and 'reward_funcs'. Please use 'reward_funcs' " - "only." - ) - if policy is not None: - warnings.warn( - "Parameter 'policy' is deprecated and will be removed in version 0.25.0. Please use 'model' instead. " - "We are setting model=policy" - ) - if model is None: - model = policy - else: - raise ValueError("Cannot specify both 'policy' (deprecated) and 'model'. Please use 'model' only.") - if ref_policy is not None: - warnings.warn( - "Parameter 'ref_policy' is deprecated and will be removed in version 0.25.0. To use the initial model " - "as the reference model, simply omit this parameter. The parameter is ignored." - ) - if data_collator is not None: - warnings.warn( - "Parameter 'data_collator' is deprecated and will be removed in version 0.25.0. The RLOOTrainer does " - "not use a data collator, so this parameter is ignored." - ) - if "input_ids" in train_dataset.column_names: - warnings.warn( - "The training dataset contains a column named 'input_ids', indicating that it is pre-tokenized. " - "Support for pre-tokenized datasets is deprecated and will be removed in version 0.25. Please provide " - "the raw dataset (conversational or standard) with a 'prompt' column instead." - ) - - def decode(example, tokenizer): - return {"prompt": tokenizer.decode(example["input_ids"])} - - train_dataset = train_dataset.map(decode, fn_kwargs={"tokenizer": processing_class}) - if eval_dataset is not None and "input_ids" in eval_dataset.column_names: - warnings.warn( - "The evaluation dataset contains a column named 'input_ids', indicating that it is pre-tokenized. " - "Support for pre-tokenized datasets is deprecated and will be removed in version 0.25. Please provide " - "the raw dataset (conversational or standard) with a 'prompt' column instead." - ) - - def decode(example, tokenizer): - return {"prompt": tokenizer.decode(example["input_ids"])} - - eval_dataset = eval_dataset.map(decode, fn_kwargs={"tokenizer": processing_class}) - - # Args - if args is None: - model_name = model if isinstance(model, str) else model.config._name_or_path - model_name = model_name.split("/")[-1] - args = RLOOConfig(f"{model_name}-RLOO") - - # Models - # Trained model - model_init_kwargs = args.model_init_kwargs or {} - if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass # dtype is already a torch.dtype or "auto" or None - elif isinstance(dtype, str): # it's a str, but not "auto" - dtype = getattr(torch, dtype) - model_init_kwargs["dtype"] = dtype - else: - raise ValueError( - "Invalid `dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - # Disable caching if gradient checkpointing is enabled [not supported] - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - model = architecture.from_pretrained(model_id, **model_init_kwargs) - else: - model_id = model.config._name_or_path - if args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - - # Some models [SmolVLM/Idefics3] don't support `logits_to_keep` argument and error out if we pass it - # Inspect the forward method before we wrap the model with PEFT - self.model_kwarg_keys = ( - inspect.signature(model.forward).parameters.keys() - if not hasattr(model, "get_base_model") - else inspect.signature(model.get_base_model().forward).parameters.keys() - ) - - if False: - pass - - # Processing class - if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left") - - # Handle pad token for processors or tokenizers - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token = tokenizer.pad_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - - # Reward functions - if not isinstance(reward_funcs, list): - reward_funcs = [reward_funcs] - self.reward_func_names = [] - for i, reward_func in enumerate(reward_funcs): - if isinstance(reward_func, str): - reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( - reward_func, num_labels=1, **model_init_kwargs - ) - if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models - self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1]) - else: - self.reward_func_names.append(reward_funcs[i].__name__) - self.reward_funcs = reward_funcs - - # Reward weights - if args.reward_weights is not None: - if len(args.reward_weights) != len(reward_funcs): - raise ValueError( - f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " - f"functions ({len(reward_funcs)})" - ) - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) - else: - self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) - - # Reward processing class - if reward_processing_classes is None: - reward_processing_classes = [None] * len(reward_funcs) - elif not isinstance(reward_processing_classes, list): - reward_processing_classes = [reward_processing_classes] - if len(reward_processing_classes) != len(reward_funcs): - raise ValueError( - f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of " - f"reward functions ({len(reward_funcs)})." - ) - - for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): - if isinstance(reward_func, PreTrainedModel): - if reward_processing_class is None: - reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) - if reward_processing_class.pad_token_id is None: - reward_processing_class.pad_token = reward_processing_class.eos_token - # The reward model computes the reward for the latest non-padded token in the input sequence. - # So it's important to set the pad token ID to the padding token ID of the processing class. - reward_func.config.pad_token_id = reward_processing_class.pad_token_id - reward_processing_classes[i] = reward_processing_class - - self.reward_processing_classes = reward_processing_classes - - # Training arguments - self.max_prompt_length = args.max_prompt_length - self.max_completion_length = args.max_completion_length - self.num_generations = args.num_generations - self.temperature = args.temperature - self.top_p = args.top_p - self.top_k = args.top_k - self.min_p = args.min_p - self.repetition_penalty = args.repetition_penalty - self.use_transformers_paged = args.use_transformers_paged - self.use_vllm = args.use_vllm - self.vllm_mode = args.vllm_mode - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode - self.normalize_advantages = args.normalize_advantages - self.mask_truncated_completions = args.mask_truncated_completions - self.reward_clip_range = args.reward_clip_range - - # Datasets - self.shuffle_dataset = args.shuffle_dataset - - if ( - isinstance(train_dataset, IterableDataset) - or isinstance(eval_dataset, IterableDataset) - or ( - isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) - ) - ): - # See https://github.com/huggingface/trl/issues/3213 - raise NotImplementedError( - "Iterable datasets are not yet supported in RLOOTrainer. Please use a standard dataset instead." - ) - - # Multi-step - self.num_iterations = args.num_iterations - self.epsilon_low = args.epsilon - self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon - # Tracks the number of iterations [forward + backward passes], including those within a grad accum cycle - self._step = 0 - # Buffer the batch to reuse generated outputs across multiple updates. For more details, see - # `_get_train_sampler` and `_prepare_inputs`. - self._buffered_inputs = None - - # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the - # input tensor associated with the key "input_ids". However, in RLOO, the sampled data does not include the - # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: - # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To - # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. - # This acts as a flag to indicate that the warning has already been issued. - model.warnings_issued["estimate_tokens"] = True - - super().__init__( - model=model, - args=args, - data_collator=identity, # No data collation is needed in RLOO - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - callbacks=callbacks, - optimizers=optimizers, - ) - - # Reference model - self.beta = args.beta - if self.beta == 0.0: - # If beta is 0.0, the reference model is not needed - self.ref_model = None - elif is_peft_model(model): - # If PEFT is used, the reference model is not needed since the adapter can be disabled - # to revert to the initial model. - self.ref_model = None - else: - # For deepspeed, fsdp or non-distributed models, create a reference model from scratch - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) - - # Disable dropout in the models - if args.disable_dropout: - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) - - # Initialize the metrics - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self._total_train_tokens = 0 - self.log_completions = args.log_completions - self.wandb_log_unique_prompts = args.wandb_log_unique_prompts - self.num_completions_to_print = args.num_completions_to_print - # Keep logs sized to the generation batch to record only outputs from the latest model update. - self._logs = { - "images": deque(maxlen=args.generation_batch_size), - "prompt": deque(maxlen=args.generation_batch_size), - "completion": deque(maxlen=args.generation_batch_size), - "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), - "advantages": deque(maxlen=args.generation_batch_size), - } - - # Ensure each process receives a unique seed to prevent duplicate completions when generating with - # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but - # it's safer to set it in all cases. - set_seed(args.seed, device_specific=True) - - if self.use_vllm: - if not is_vllm_available(): - raise ImportError( - "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install trl[vllm]` to use it." - ) - - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - if args.vllm_server_base_url is not None: - base_url = args.vllm_server_base_url - else: - base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" - self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) - self.vllm_client.init_communicator(device=torch.cuda.current_device()) - - elif self.vllm_mode == "colocate": - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " - f"({self.accelerator.num_processes}) evenly." - ) - - if self.vllm_tensor_parallel_size > 1: - self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ - list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ] - ) - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - ensure_master_addr_port() - - if self.max_prompt_length is not None and self.max_completion_length is not None: - max_model_len = self.max_prompt_length + self.max_completion_length - else: - max_model_len = None - self.llm = model.vllm_engine - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - else: - raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") - self.guided_decoding_regex = args.vllm_guided_decoding_regex - - self._last_loaded_step = -1 - self.accelerator.wait_for_everyone() - else: - generation_kwargs = { - "max_new_tokens": self.max_completion_length, - "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "min_p": self.min_p, - "repetition_penalty": self.repetition_penalty, - "cache_implementation": args.cache_implementation, - } - if args.generation_kwargs is not None: - generation_kwargs.update(args.generation_kwargs) - self.generation_config = GenerationConfig(**generation_kwargs) - - # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the - # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set - # self.model_accepts_loss_kwargs to False to enable scaling. - self.model_accepts_loss_kwargs = False - - # Add tags to the model - self.model.add_model_tags(self._tag_names) - - if self.ref_model is not None: - if self.is_deepspeed_enabled: - self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) - elif self.is_fsdp_enabled: - self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) - else: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) - - if args.sync_ref_model: - self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) - - for i, reward_func in enumerate(self.reward_funcs): - if isinstance(reward_func, PreTrainedModel): - if self.is_deepspeed_enabled: - self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) - else: - # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp - self.reward_funcs[i] = self.accelerator.prepare_model( - reward_func, evaluation_mode=True, device_placement=True - ) - - def _set_signature_columns_if_needed(self): - # If `self.args.remove_unused_columns` is True, non-signature columns are removed. - # By default, this method sets `self._signature_columns` to the model's expected inputs. - # In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't work. - # Instead, we set them to the columns expected by the `training_step` method, hence the override. - if self._signature_columns is None: - self._signature_columns = ["prompt", "image", "images"] - - # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. - # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an - # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions - # once every steps_per_generation step—rather than once per accumulation step—which is significantly more - # efficient. The only change from the original implementation is multiplying the batch size by - # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the - # splitting internally. - # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line - # modification. As a result, some parts of the method aren't relevant to RLOO, but we keep them to stay one line - # apart from the super method, ensuring easier maintenance in the future. - def get_train_dataloader(self): - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - - dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) - - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - - def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler: - # Returns a sampler that - # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are - # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt - # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies - # in group formation. - # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to - # _prepare_inputs to see how the generations are stored and reused. - - # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the - # second row shows the second sampled batch, and so on. - # - # | GPU 0 | GPU 1 | - # - # global_step step <-───> num_generations=2 - # <-───────> per_device_train_batch_size=3 - # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss - # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss - # | - # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss - # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss - # - # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss - # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss - # ... - if dataset is None: - dataset = self.train_dataset - return RepeatSampler( - data_source=dataset, - mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.num_iterations * self.args.steps_per_generation, - shuffle=self.shuffle_dataset, - seed=self.args.seed, - ) - - def _get_eval_sampler(self, eval_dataset) -> Sampler: - # See _get_train_sampler for an explanation of the sampler. - return RepeatSampler( - data_source=eval_dataset, - mini_repeat_count=self.num_generations, - seed=self.args.seed, - ) - - @profiling_decorator - def _get_per_token_logps_and_entropies( - self, - model, - input_ids, - attention_mask, - logits_to_keep, - batch_size=None, - compute_entropy=False, - pixel_values=None, - image_grid_thw=None, - num_images=None, - pixel_attention_mask=None, - image_sizes=None, - token_type_ids=None, - ) -> dict[str, Optional[torch.Tensor]]: - """Compute log-probs and (optionally) entropies for each token.""" - batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak - all_logps = [] - all_entropies = [] - for start in range(0, input_ids.size(0), batch_size): - input_ids_batch = input_ids[start : start + batch_size] - attention_mask_batch = attention_mask[start : start + batch_size] - - # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) - model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - - if image_grid_thw is not None and pixel_values is not None: - rows_per_image = image_grid_thw.prod(dim=-1) - rows_per_sample = torch.split(rows_per_image, num_images) - rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) - cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) - row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() - model_inputs["pixel_values"] = pixel_values[row_start:row_end] - cum_imgs = torch.tensor([0] + num_images).cumsum(0) - img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] - model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] - elif pixel_values is not None: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] - if pixel_attention_mask is not None: - model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] - if image_sizes is not None: - model_inputs["image_sizes"] = image_sizes[start : start + batch_size] - if token_type_ids is not None: - model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] - - # Only add logits_to_keep if the model supports it - if "logits_to_keep" in self.model_kwarg_keys: - # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - model_inputs["logits_to_keep"] = logits_to_keep + 1 - - model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings - - logits = model(**model_inputs).logits - # Exclude the last value: it corresponds to the next token pred - logits = logits[:, :-1, :] # (B, L-1, H) - # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. - logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) - # Divide logits by sampling temperature. - # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details - logits = logits / self.temperature - - completion_ids = input_ids_batch[:, -logits_to_keep:] - logps = selective_log_softmax(logits, completion_ids) # compute logprobs - all_logps.append(logps) - - if compute_entropy: - with torch.no_grad(): - entropies = entropy_from_logits(logits) - all_entropies.append(entropies) - - logps = torch.cat(all_logps, dim=0) - entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None - return logps, entropies - - def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None): - extra_prefixes = extra_prefixes or [] - prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes - for prefix in prefixes: - name = name.replace(prefix, "") - return name - - def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): - """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" - # For FSDP1, we need to recurse into children and also use summon_full_params - if visited is None: - visited = set() - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix else child_name - self._sync_fsdp1_params_to_vllm( - child_module, prefix=child_prefix, visited=visited - ) # recurse into the child - - if isinstance(module, FSDP): - with FSDP.summon_full_params(module, recurse=False, writeback=False): - for param_name, param in module.named_parameters(): - full_name = f"{prefix}.{param_name}" if prefix else param_name - full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) - - if full_name in visited: - continue # skip FSDP subtrees already traversed - visited.add(full_name) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(full_name, param.data) - elif self.vllm_mode == "colocate": - - pass - - pass - - def _sync_fsdp2_params_to_vllm(self, module: nn.Module): - # For FSDP2, module already covers all parameters, so no need for recursion - for name, param in module.items(): - if param.is_cpu: - param = param.to(torch.device("cuda")) - param = param.full_tensor() - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param) - elif self.vllm_mode == "colocate": - - pass - - pass - - @profiling_decorator - def _move_model_to_vllm(self): - # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - - if is_peft_model(self.model): - # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as - # merging adapters in a sharded manner is not supported. - # TODO: does this work with FSDP? - with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext - # Update vLLM weights while parameters are gathered - # For PEFT with FSDP we need to use the memory efficient post-order traversal - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - self._sync_fsdp1_params_to_vllm( - self.model - ) # use memory-efficient post-order traversal for FSDP - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(self.model) - else: - # DeepSpeed ZeRO-3 with PEFT - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - - pass - - pass - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather (if needed) and update each parameter individually. - if self.is_fsdp_enabled: - fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) - fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 - if fsdp_version == 1: - self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP - elif fsdp_version == 2: - self._sync_fsdp2_params_to_vllm(self.model) - else: - for name, param in self.model.named_parameters(): - name = self._fix_param_name_to_vllm(name) - with gather_if_zero3([param]): - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - - pass - - pass - - # Reset cache on vLLM - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - elif self.vllm_mode == "colocate": - self.llm.reset_prefix_cache() - - @profiling_decorator - def _prepare_inputs( - self, generation_batch: dict[str, Union[torch.Tensor, Any]] - ) -> dict[str, Union[torch.Tensor, Any]]: - # Prepares inputs for model training/evaluation by managing completion generation and batch handling. - # During training: - # - Receives the local generation batch (Per-GPU batch size × steps per generation) - # from the modified training dataloader instead of the standard local batch - # - Generates completions once for the entire generation batch and splits it into batches of size - # `per_device_train_batch_size` - # - Buffers these completions and returns the appropriate slice for the current accumulation step - # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) - # During evaluation: - # - The input is treated as a standard local batch (no accumulation, no multiple iterations) - # - Completions are generated for each batch without buffering or reuse - # Returns a single local batch in both cases. - - mode = "train" if self.model.training else "eval" - if mode == "train": - generate_every = self.args.steps_per_generation * self.num_iterations - if self._step % generate_every == 0 or self._buffered_inputs is None: - # self._buffered_inputs=None can occur when resuming from a checkpoint - generation_batch = self._generate_and_score_completions(generation_batch) - generation_batch = split_pixel_values_by_grid(generation_batch) - - try: generation_batch = shuffle_sequence_dict(generation_batch) - - except: pass - generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) - self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] - inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] - self._step += 1 - else: - # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence - # local generation batch == local eval batch - inputs = self._generate_and_score_completions(generation_batch) - return inputs - - @profiling_decorator - def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): - device = self.accelerator.device - rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) - - # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations - keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] - reward_kwargs = {key: [example[key] for example in inputs] for key in keys} - - # This allows for dynamic reward shaping based on training progress. - reward_kwargs["trainer_state"] = self.state - - for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( - zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names) - ): - with profiling_context(self, reward_func_name): - if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models - if is_conversational(inputs[0]): - messages = [{"messages": p + c} for p, c in zip(prompts, completions)] - texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] - else: - texts = [p + c for p, c in zip(prompts, completions)] - reward_inputs = reward_processing_class( - text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False - ) - reward_inputs = super()._prepare_inputs(reward_inputs) - with torch.inference_mode(): - rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) - else: - output_reward_func = reward_func( - prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs - ) - # Convert None values to NaN - output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] - - rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) - - # If all reward functions return None for a given row, issue a detailed warning - if torch.isnan(rewards_per_func).all(dim=1).any(): - nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] - row_reward_kwargs = { - key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state" - } - row_reward_kwargs["prompt"] = prompts[nan_row_idx] - row_reward_kwargs["completion"] = completions[nan_row_idx] - logger.warning( - f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n" - "Please ensure that at least one reward function returns a valid reward." - ) - - # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the - # completions may be distributed across processes - rewards_per_func = gather(rewards_per_func) - return rewards_per_func - - def _generate_single_turn(self, prompts: list[str], images: Optional[list]): - device = self.accelerator.device - - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] - - if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} - - # Generate completions using either vLLM or regular generation - if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up() - - # First, update the vLLM weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) - - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - - with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) - else: - payload = None - - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - all_prompt_ids, all_completion_ids, _ = obj_list[0] - - # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times - all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] - - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - prompt_ids = all_prompt_ids[process_slice] - completion_ids = all_completion_ids[process_slice] - - # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts - elif self.vllm_mode == "colocate": - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) - else: - guided_decoding = None - - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "truncate_prompt_tokens": self.max_prompt_length, - "guided_decoding": guided_decoding, - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**grpo_update_SamplingParams(SamplingParams, generation_kwargs, getattr(self.args, 'vllm_sampling_params', None))) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text - - with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False, lora_request = self.model.load_lora('rloo_trainer_lora_model', load_tensors = True)) - - all_prompt_ids = [output.prompt_token_ids for output in all_outputs] - all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - prompt_ids = all_prompt_ids[tp_slice] - completion_ids = all_completion_ids[tp_slice] - else: - prompt_ids = all_prompt_ids - completion_ids = all_completion_ids - - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - - elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" - else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" - with ( - profiling_context(self, "transformers.generate_batch"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - # Cast to the appropriate dtype based on training configuration - if self.args.bf16: - unwrapped_model.to(torch.bfloat16) - elif self.args.fp16: - unwrapped_model.to(torch.float16) - with torch.inference_mode(): - all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False - ) - unwrapped_model.train() # restore training mode, as generate_batch forces eval mode - completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = paged_prompt_inputs.input_ids - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn - - else: - # Regular generation path - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - **kwargs, - ) - generate_inputs = super()._prepare_inputs(generate_inputs) - - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - prompt_completion_ids = unwrapped_model.generate( - **generate_inputs, generation_config=self.generation_config, disable_compile=True - ) - # Compute prompt length and extract completion ids - prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) - completion_ids = prompt_completion_ids[:, prompt_length:] - - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] - completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] - - return prompt_ids, completion_ids, forward_kwargs - - def _generate(self, prompts: list[str], images: Optional[list]): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images) - - # Get completion length per sequence, used for logging - prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) - completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) - agg_prompt_lengths = self.accelerator.gather(prompt_lengths) - agg_completion_lengths = self.accelerator.gather(completion_lengths) - total_prompt_tokens = agg_prompt_lengths.sum() - total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss - - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - agg_completion_lengths = self.accelerator.gather(completion_lengths) - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) - agg_is_truncated = self.accelerator.gather(is_truncated) - self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) - term_completion_lengths = agg_completion_lengths[~agg_is_truncated] - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - return prompt_ids, completion_ids, forward_kwargs - - def _generate_and_score_completions( - self, inputs: list[dict[str, Union[torch.Tensor, Any]]] - ) -> dict[str, Union[torch.Tensor, Any]]: - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - prompts = [x["prompt"] for x in inputs] - - if "images" in inputs[0]: - images = [example.get("images") for example in inputs] - elif "image" in inputs[0]: - images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] - else: - images = None - # Transformers requires at least one image in the batch, otherwise it throws an error - if images is not None and all(img_list == [] for img_list in images): - images = None - - prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images) - - # Convert lists of token IDs to padded tensors - prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - completion_mask = pad(completion_mask, padding_value=0, padding_side="right") - - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - - # Concatenate prompt_mask with completion_mask for logit computation - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) - # If token_type_ids are used, extend them with zeros for the completion part - if "token_type_ids" in forward_kwargs: - token_type_ids = forward_kwargs["token_type_ids"] - forward_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 - ) - - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - - num_images = [len(img_list) for img_list in images] if images is not None else None - - with torch.no_grad(): - # Compute the per-token log probabilities for the current model - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS - - # Compute the per-token log probabilities for the reference model - if self.beta != 0.0: - if self.ref_model is not None: - ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.ref_model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - with self.accelerator.unwrap_model(self.model).disable_adapter(): - ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size=batch_size, - num_images=num_images, - **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes - ) - else: - ref_per_token_logps = None - - # Decode - prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) - completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational(inputs[0]): - completions = [] - for prompt, completion in zip(prompts, completions_text): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" - completions.append([{"role": "assistant", "content": bootstrap + completion}]) - else: - completions = completions_text - - # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is - # important because rewards will be normalized per group, and completions are distributed. We will later slice - # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) - - # Apply weights to each reward function's output and sum - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) - - # Apply reward clipping if specified - if self.reward_clip_range: - rewards = rewards.clamp(min=self.reward_clip_range[0], max=self.reward_clip_range[1]) - - # Include the KL penalty in the reward - if self.beta != 0.0: - per_token_kl = old_per_token_logps - ref_per_token_logps - # Apply sequence-level KL penalty to rewards (sum KL across tokens first, then apply to each sequence) - kl = (per_token_kl * completion_mask).sum(-1) - kl = gather(kl) # rewards are gathered, so kl must be too - rewards = rewards - self.beta * kl - - grouped_rewards = rewards.view(-1, self.num_generations) - mean_grouped_rewards = grouped_rewards.mean(dim=1) - std_rewards = grouped_rewards.std(dim=1) - is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) - - # RLOO advantages computation - grouped_sum = grouped_rewards.sum(dim=1, keepdim=True) # (num_prompts, 1) - baselines = (grouped_sum - grouped_rewards) / (self.num_generations - 1) # (num_prompts, num_generations) - baselines = baselines.view(-1) # Flatten back to match rewards shape - advantages = rewards - baselines - - # Normalize advantages - if self.normalize_advantages: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-4) - - # Slice to keep only the local part of the data - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - all_process_advantages = advantages.clone() # keep the aggregated advantages for logging - advantages = advantages[process_slice] - - # Calculate and log the mean KL divergence between current and reference model - if self.beta != 0.0: - mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) - self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) - - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) - for i, reward_func_name in enumerate(self.reward_func_names): - mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) - std_func_rewards = nanstd(rewards_per_func[:, i]).item() - self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) - self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) - self._metrics[mode]["reward_std"].append(std_rewards.mean().item()) - self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) - - # Log prompt and completion texts - self._logs["prompt"].extend(gather_object(prompts_text)) - self._logs["completion"].extend(gather_object(completions_text)) - for i, name in enumerate(self.reward_func_names): - self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) - self._logs["advantages"].extend(all_process_advantages.tolist()) - - if images is not None: - self._logs["images"].extend(gather_object(images)) - - output = { - "prompt_ids": prompt_ids, - "prompt_mask": prompt_mask, - "completion_ids": completion_ids, - "completion_mask": completion_mask, - "old_logps": old_logps, - "advantages": advantages, - } - if "pixel_values" in forward_kwargs: - output["pixel_values"] = forward_kwargs["pixel_values"] - if "image_grid_thw" in forward_kwargs: - output["image_grid_thw"] = forward_kwargs["image_grid_thw"] - if "pixel_attention_mask" in forward_kwargs: - output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] - if "image_sizes" in forward_kwargs: - output["image_sizes"] = forward_kwargs["image_sizes"] - if "token_type_ids" in forward_kwargs: - output["token_type_ids"] = forward_kwargs["token_type_ids"] - if images is not None: - output["num_images"] = num_images - return output - - @profiling_decorator - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): - if return_outputs: - raise ValueError("The RLOOTrainer does not support returning outputs") - return self._compute_loss(model, inputs) - - def _compute_loss(self, model, inputs): - # Compute the per-token log probabilities for the model - prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] - completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - - # Compute the per_token_logps and the entropy at each position in the completion - per_token_logps, entropies = self._get_per_token_logps_and_entropies( - model, - input_ids, - attention_mask, - logits_to_keep, - compute_entropy=True, - pixel_values=inputs.get("pixel_values"), - image_grid_thw=inputs.get("image_grid_thw"), - num_images=inputs.get("num_images"), - pixel_attention_mask=inputs.get("pixel_attention_mask"), - image_sizes=inputs.get("image_sizes"), - token_type_ids=inputs.get("token_type_ids"), - ) - - logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS - old_logps = inputs["old_logps"] - log_ratio = logps - old_logps - - # Compute the loss - advantages = inputs["advantages"] - coef_1 = torch.exp(log_ratio) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) - per_sequence_loss1 = coef_1 * advantages - per_sequence_loss2 = coef_2 * advantages - per_sequence_loss = -torch.min(per_sequence_loss1, per_sequence_loss2) - loss = per_sequence_loss.mean() - - # Log the metrics - mode = "train" if self.model.training else "eval" - - # Entropy - mean_entropy = (entropies * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) - self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) - - # Compute the clipped probability ratios - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) - is_region_clipped = is_low_clipped | is_high_clipped - gathered_low_clip = self.accelerator.gather(is_low_clipped.float().mean()) - self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) - gathered_high_clip = self.accelerator.gather(is_high_clipped.float().mean()) - self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) - gathered_clip_ratio = self.accelerator.gather(is_region_clipped.float().mean()) - self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) - return loss - - def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None): - inputs = self._prepare_inputs(inputs) - with torch.no_grad(): - with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) - loss = loss.mean().detach() - return loss, None, None - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - mode = "train" if self.model.training else "eval" - metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics - - # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` - # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. - if mode == "eval": - metrics = {f"eval_{key}": val for key, val in metrics.items()} - - logs = {**logs, **metrics} - super().log(logs, start_time) - self._metrics[mode].clear() - - if self.accelerator.is_main_process and self.log_completions: - if is_rich_available(): - print_prompt_completions_sample( - self._logs["prompt"], - self._logs["completion"], - self._logs["rewards"], - self._logs["advantages"], - self.state.global_step, - self.num_completions_to_print, - ) - - if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: - import pandas as pd - - table = { - "step": [str(self.state.global_step)] * len(self._logs["prompt"]), - "prompt": self._logs["prompt"], - "completion": self._logs["completion"], - **self._logs["rewards"], - "advantage": self._logs["advantages"], - } - - if self._logs["images"]: - table["images"] = [] - for image_list in self._logs["images"]: - # Convert images to wandb Image objects for proper visualization - table["images"].append([wandb.Image(image) for image in image_list]) - - df = pd.DataFrame(table) - if self.wandb_log_unique_prompts: - df = df.drop_duplicates(subset=["prompt"]) - wandb.log({"completions": wandb.Table(dataframe=df)}) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothRLOOTrainer(_UnslothRLOOTrainer): - """ - - Trainer for the Reinforce Leave One Out (RLOO) method. This algorithm was initially proposed in the paper [Back to - Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in - LLMs](https://huggingface.co/papers/2402.14740). - - Example: - - ```python - from datasets import load_dataset - from trl import RLOOTrainer - - dataset = load_dataset("trl-lib/tldr", split="train") - def reward_func(completions, **kwargs): - # Dummy reward function that rewards completions with more unique letters. - return [float(len(set(completion))) for completion in completions] - trainer = RLOOTrainer( - model="Qwen/Qwen2-0.5B-Instruct", - reward_funcs=reward_func, - train_dataset=dataset, - ) - - trainer.train() - ``` - - Args: - model (`Union[str, PreTrainedModel]`): - Model to be trained. Can be either: - - - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in - `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. - reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): - Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward - functions with the prompts and completions and sum the rewards. Can be either: - - - A single reward function, such as: - - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the - keyword arguments in `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. - - A custom reward function: The function is provided with the prompts and the generated completions, - plus any additional columns in the dataset. It should return a list of rewards. Custom reward - functions can also return `None` when the reward is not applicable to those samples. This is useful - for multi-task training where different reward functions apply to different types of samples. When a - reward function returns `None` for a sample, that reward function is excluded from the reward - calculation for that sample. For more details, see [Using a custom reward - function](#using-a-custom-reward-function). - - The trainer's state is also passed to the reward function. The trainer's state is an instance of - [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the - reward function's signature. - - A list of reward functions, where each item can independently be any of the above types. Mixing different - types within the list (e.g., a string model ID and a custom reward function) is allowed. - args ([`RLOOConfig`], *optional*): - Configuration for this trainer. If `None`, a default configuration is used. - train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is - ignored. The format of the samples can be either: - - - [Standard](dataset_formats#standard): Each sample contains plain text. - - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role - and content). - eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): - Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. The padding side must be set to "left". If `None`, the - processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A - padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, - `tokenizer.eos_token` will be used as the default. - reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): - Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: - - - A single processing class: Used when `reward_funcs` contains only one reward function. - - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. - If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is - `None`, the tokenizer for the model is automatically loaded using - [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward - functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes` - are ignored. - callbacks (list of [`~transformers.TrainerCallback`], *optional*): - List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed - in [here](https://huggingface.co/docs/transformers/main_classes/callback). - - If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] - method. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): - A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your - model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration used to wrap the model. If `None`, the model is not wrapped. - - config: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `args` instead. - - - - reward_model: - - - This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead. - - - - policy: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `model` instead. - - - - ref_policy: - - - - This parameter is deprecated and will be removed in version 0.25.0. To use the initial model as the - reference model, simply omit this parameter. The parameter is ignored. - - - - data_collator: - - - - This parameter is deprecated and will be removed in version 0.25.0. The RLOOTrainer does not use a data - collator, so this parameter is ignored. - - - - """ - def __init__( - self, - model = None, - reward_funcs = None, - args = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - reward_processing_classes = None, - callbacks = None, - peft_config = None, - config = None, - reward_model = None, - policy = None, - ref_policy = None, - data_collator = None, - **kwargs - ): - if args is None: args = UnslothRLOOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('rloo_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - reward_funcs = reward_funcs, - args = args, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - reward_processing_classes = reward_processing_classes, - callbacks = callbacks, - peft_config = peft_config, - config = config, - reward_model = reward_model, - policy = policy, - ref_policy = ref_policy, - data_collator = data_collator,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothRewardTrainer.py b/unsloth_compiled_cache/UnslothRewardTrainer.py deleted file mode 100644 index 5fd7b70..0000000 --- a/unsloth_compiled_cache/UnslothRewardTrainer.py +++ /dev/null @@ -1,1351 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.reward_trainer import (Any, AutoModelForSequenceClassification, AutoTokenizer, BaseTrainer, Callable, DataCollator, DataCollatorForPreference, Dataset, EvalPrediction, IterableDataset, Optional, PartialState, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, RewardTrainer, TrainerCallback, Union, clone_chat_template, contextlib, dataclass, defaultdict, disable_dropout_in_model, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pad, re, remove_none_values, suppress_from_pretrained_warning, torch, transformers, Any, AutoModelForSequenceClassification, AutoTokenizer, Callable, DataCollator, DataCollatorForPreference, Dataset, EvalPrediction, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RewardConfig, TrainerCallback, Union, clone_chat_template, contextlib, defaultdict, disable_dropout_in_model, get_act_offloading_ctx_manager, logger, os, pad, re, suppress_from_pretrained_warning, torch, transformers, Optional, PreTrainedModel, logger, os, re, torch) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothRewardConfig(RewardConfig): - """ - - Configuration class for the [`RewardTrainer`]. - - This class includes only the parameters that are specific to Reward training. For a full list of training - arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this - class may differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - > Parameters that control the model - - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` - argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want - to include the load balancing/auxilliary loss as a part of the final loss, remember to set - `output_router_logits=True` in this dictionary. - chat_template_path (`str`, *optional*): - If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory - or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must - ensure that any special tokens referenced in the template are added to the tokenizer and that the model's - embedding layer is resized accordingly. - disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. - - > Parameters that control the data preprocessing - - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - eos_token (`str`, *optional*): - Token used to indicate the end of a turn or sequence. If `None`, it defaults to - `processing_class.eos_token`. - pad_token (`str`, *optional*): - Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, - it falls back to `processing_class.eos_token`. - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence - exceeds this value. If `None`, no filtering is applied. - pad_to_multiple_of (`int`, *optional*): - If set, the sequences will be padded to a multiple of this value. - - > Parameters that control the training - - center_rewards_coefficient (`float`, *optional*): - Coefficient to incentivize the reward model to output mean-zero rewards (proposed by - https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. - activation_offloading (`bool`, *optional*, defaults to `False`): - Whether to offload the activations to the CPU. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - model_init_kwargs = None, - chat_template_path = None, - disable_dropout = True, - dataset_num_proc = None, - eos_token = None, - pad_token = None, - max_length = 1024, - pad_to_multiple_of = None, - center_rewards_coefficient = None, - activation_offloading = False, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1': - from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION - if HAS_FLEX_ATTENTION and pad_to_multiple_of is None: - from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE - pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - model_init_kwargs = model_init_kwargs, - chat_template_path = chat_template_path, - disable_dropout = disable_dropout, - dataset_num_proc = dataset_num_proc, - eos_token = eos_token, - pad_token = pad_token, - max_length = max_length, - pad_to_multiple_of = pad_to_multiple_of, - center_rewards_coefficient = center_rewards_coefficient, - activation_offloading = activation_offloading,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothRewardTrainer(BaseTrainer): - """""" - - _tag_names = ["trl", "reward-trainer"] - _name = "Reward" - _template_file = "rm_model_card.md" - - def __init__( - self, - model: Union[str, PreTrainedModel], - args: Optional[RewardConfig] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[PreTrainedTokenizerBase] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), - optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional["PeftConfig"] = None, - ): - # Args - if args is None: - model_name = model if isinstance(model, str) else model.config._name_or_path - model_name = model_name.split("/")[-1] - args = RewardConfig(f"{model_name}-Reward") - - # Model - model_init_kwargs = args.model_init_kwargs or {} - if isinstance(model, str): - model_id = model - dtype = model_init_kwargs.get("dtype") - if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: - pass # dtype is already a torch.dtype or "auto" or None - elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: - model_init_kwargs["dtype"] = getattr(torch, dtype) - else: - raise ValueError( - "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing " - f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." - ) - with suppress_from_pretrained_warning(transformers.modeling_utils.logger): - model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs) - else: - model_id = model.config._name_or_path - if args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - - # Processing class - if processing_class is None: - processing_class = AutoTokenizer.from_pretrained(model_id) - - # Handle pad token for processors or tokenizers - if args.eos_token is not None: - eos_token = args.eos_token - eos_token_id = processing_class.convert_tokens_to_ids(eos_token) - if eos_token_id is None: - raise ValueError( - f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " - f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " - "in the vocabulary before using it as an EOS token." - ) - processing_class.eos_token_id = eos_token_id - - if args.chat_template_path is not None: - if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): - with open(args.chat_template_path, encoding="utf-8") as chat_template_file: - processing_class.chat_template = chat_template_file.read() - added_tokens = [] - else: - model, processing_class, added_tokens = clone_chat_template( - model, processing_class, args.chat_template_path - ) - else: - added_tokens = [] - - # PEFT configuration and model wrapping - if False: - if added_tokens: - # Ensure that the added tokens are trainable - if peft_config.trainable_token_indices is None: - peft_config.trainable_token_indices = {"embed_tokens": added_tokens} - elif "embed_tokens" not in peft_config.trainable_token_indices: - peft_config.trainable_token_indices["embed_tokens"] = added_tokens - else: - peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) - - # Ensure that the lm_head is trainable - if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: - logger.warning( - "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " - "`modules_to_save`. As a result, the model may not learn to generate outputs with these new " - "tokens, leading to degraded generation quality. To fix this, add " - "`modules_to_save=['lm_head']` to your PEFT configuration." - ) - - if peft_config.modules_to_save is None: - peft_config.modules_to_save = ["lm_head"] - else: - peft_config.modules_to_save.append("lm_head") - - if False: - pass - - # Disable dropout in the model - if args.disable_dropout: - disable_dropout_in_model(model) - - # Pad token [needed for SequenceClassification models] - # If not provided, use the one from the processing class or the eos token if the processing class does not have - # a pad token. - pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token - pad_token_id = processing_class.convert_tokens_to_ids(pad_token) - if pad_token_id is None: - raise ValueError( - f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " - f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " - "in the vocabulary before using it as a padding token." - ) - model.config.pad_token_id = pad_token_id - processing_class.pad_token_id = pad_token_id - - # Data collator - if data_collator is None: - data_collator = DataCollatorForPreference( - pad_token_id=pad_token_id, - pad_to_multiple_of=args.pad_to_multiple_of, - ) - - # Dataset - train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") - if eval_dataset is not None: - if isinstance(eval_dataset, dict): - eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, key) - for key, dataset in eval_dataset.items() - } - else: - eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") - - # Initialize the metrics - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self._total_train_tokens = 0 - - # Initialize the Trainer. Parent class will handle: - # - DeepSpeed configuration [through create_accelerator_and_postprocess] - # - FSDP setup - # - Distributed training setup - # - Optimizer and scheduler creation - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # During evaluation, Trainer calls compute_loss[] only if can_return_loss is True and label_names is empty. - self.can_return_loss = True - self.label_names = [] - - # Initialize activation offloading context - if self.args.activation_offloading: - self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) - else: - self.maybe_activation_offload_context = contextlib.nullcontext() - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - - def _prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class: PreTrainedTokenizerBase, - args: RewardConfig, - dataset_name: str, - ) -> Union[Dataset, IterableDataset]: - # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from - # sampled data. - if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` - dataset = dataset.with_transform(remove_none_values) - - # If the dataset is already preprocessed (tokenized), skip the processing steps. - column_names = list(next(iter(dataset)).keys()) - is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names - - # Build the kwargs for the `map` function - map_kwargs = {} - if isinstance(dataset, Dataset): # IterableDataset does not support num_proc - map_kwargs["num_proc"] = args.dataset_num_proc - - with PartialState().main_process_first(): - if not is_processed: - # Add EOS token to the end of the sequences if needed - first_example = next(iter(dataset)) - if not is_conversational(first_example): - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" - - def add_eos(example, eos_token): - if not example["chosen"].endswith(eos_token): - example["chosen"] = example["chosen"] + eos_token - if "rejected" in example and not example["rejected"].endswith(eos_token): - example["rejected"] = example["rejected"] + eos_token - return example - - dataset = dataset.map( - add_eos, - fn_kwargs={"eos_token": processing_class.eos_token}, - **map_kwargs, - ) - - # Tokenize the dataset - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - def tokenize_fn(example, processing_class): - if "prompt" in example: # explicit prompt case - example["chosen"] = example["prompt"] + example["chosen"] - example["rejected"] = example["prompt"] + example["rejected"] - - if is_conversational(example): - chosen_input_ids = processing_class.apply_chat_template( - example["chosen"], - tools=example.get("tools"), - **example.get("chat_template_kwargs", {}), - ) - rejected_input_ids = processing_class.apply_chat_template( - example["rejected"], - tools=example.get("tools"), - **example.get("chat_template_kwargs", {}), - ) - output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids} - else: - output = { - "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"], - "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"], - } - return output - - dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs) - - # Filter samples that are longer than `max_length` - if args.max_length is not None: - if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens" - dataset = dataset.filter( - lambda example: len(example["chosen_input_ids"]) <= args.max_length - and len(example["rejected_input_ids"]) <= args.max_length, - **map_kwargs, - ) - - return dataset - - def _set_signature_columns_if_needed(self): - # If `self.args.remove_unused_columns` is True, non-signature columns are removed. - # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" - # and "attention_mask"). - if self._signature_columns is None: - self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"] - - def compute_loss( - self, - model: nn.Module, - inputs: dict[str, Union[torch.Tensor, Any]], - return_outputs: bool = False, - num_items_in_batch: Optional[torch.Tensor] = None, - ): - """ - Compute training loss and additionally compute token accuracies - """ - mode = "train" if self.model.training else "eval" - - # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing - inputs["use_cache"] = False - outputs = model(**inputs) - - # Split the rewards into chosen and rejected - rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2) - - # Calculate loss, optionally modulate with margin - if "margin" in inputs: - loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() - else: - loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() - - if self.args.center_rewards_coefficient is not None: - loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2) - - if mode == "train": - num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() - self._total_train_tokens += num_tokens_in_batch - self._metrics[mode]["num_tokens"] = [self._total_train_tokens] - - # Compute min, mean, max, accuracy and margin - with torch.no_grad(): - all_rewards = self.accelerator.gather(outputs.logits) - self._metrics[mode]["min_reward"].append(all_rewards.min().item()) - self._metrics[mode]["mean_reward"].append(all_rewards.mean().item()) - self._metrics[mode]["max_reward"].append(all_rewards.max().item()) - - mean_accuracy = (rewards_chosen > rewards_rejected).float().mean() - mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item() - self._metrics[mode]["accuracy"].append(mean_accuracy) - - mean_margin = (rewards_chosen - rewards_rejected).mean() - mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean() - self._metrics[mode]["margin"].append(mean_margin.item()) - - return (loss, outputs) if return_outputs else loss - - # Override training step to add activation offloading context. - def training_step(self, *args, **kwargs): - with self.maybe_activation_offload_context: - return super().training_step(*args, **kwargs) - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - mode = "train" if self.model.training else "eval" - metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics - - # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` - # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. - if mode == "eval": - metrics = {f"eval_{key}": val for key, val in metrics.items()} - - logs.update(metrics) - super().log(logs, start_time) - self._metrics[mode].clear() - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothRewardTrainer(_UnslothRewardTrainer): - """ - - Trainer for Outcome-supervised Reward Models (ORM). - - This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. - - Example: - - ```python - from trl import RewardTrainer - from datasets import load_dataset - - dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") - - trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset) - trainer.train() - ``` - - Args: - model (`Union[str, PreTrainedModel]`): - Model to be trained. Can be either: - - - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in - `args.model_init_kwargs`. - - A sequence classification [`~transformers.PreTrainedModel`] object. - args ([`RewardConfig`], *optional*): - Configuration for this trainer. If `None`, a default configuration is used. - data_collator ([`~transformers.DataCollator`], *optional*): - Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. - Will default to [`~trainer.reward_trainer.DataCollatorForPreference`]. - train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and - explicit prompt). The format of the samples can be either: - - - [Standard](dataset_formats#standard): Each sample contains plain text. - - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role - and content). - - The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and - `rejected_input_ids` fields. - eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): - Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*): - Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with - [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be - set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the - default. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function that will be used to compute metrics at evaluation. Must take a - [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing - [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a - boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the - function needs to calculate and return the global summary statistics rather than accumulating the - batch-level statistics. - callbacks (list of [`~transformers.TrainerCallback`], *optional*): - List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed - in [here](https://huggingface.co/docs/transformers/main_classes/callback). - - If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] - method. - optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`): - A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your - model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. - optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): - A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in - `args`. Incompatible with the `optimizers` argument. - - Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before - initializing the Trainer. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): - A function that preprocess the logits right before caching them at each evaluation step. Must take two - tensors, the logits and the labels, and return the logits once processed as desired. The modifications made - by this function will be reflected in the predictions received by `compute_metrics`. - - Note that the labels (second parameter) will be `None` if the dataset does not have them. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded - model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration - to ensure that the reward head is properly trained. - - """ - def __init__( - self, - model, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - compute_metrics = None, - callbacks = None, - optimizer_cls_and_kwargs = None, - preprocess_logits_for_metrics = None, - peft_config = None, - **kwargs - ): - if args is None: args = UnslothRewardConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('reward_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - compute_metrics = compute_metrics, - callbacks = callbacks, - optimizer_cls_and_kwargs = optimizer_cls_and_kwargs, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothSFTTrainer.py b/unsloth_compiled_cache/UnslothSFTTrainer.py deleted file mode 100644 index 5719e38..0000000 --- a/unsloth_compiled_cache/UnslothSFTTrainer.py +++ /dev/null @@ -1,1612 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.sft_trainer import (Any, AutoProcessor, BaseTrainer, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, Path, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, dataclass, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, logging, nn, os, pack_dataset, pad, selective_log_softmax, torch, Any, AutoProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForVisionLanguageModeling, Dataset, EvalPrediction, FLASH_ATTENTION_VARIANTS, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, create_model_from_path, defaultdict, dft_loss, get_act_offloading_ctx_manager, is_conversational, logger, os, pad, torch, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_dataset, pad, Optional, PreTrainedModel, logger, os, torch, os) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothSFTConfig(SFTConfig): - """ - - Configuration class for the [`SFTTrainer`]. - - This class includes only the parameters that are specific to SFT training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may - differ from those in [`~transformers.TrainingArguments`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - > Parameters that control the model - - model_init_kwargs (`dict[str, Any]`, *optional*): - Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` - argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to - include the load balancing/auxilliary loss as a part of the final loss, remember to set - `output_router_logits=True` in this dictionary. - chat_template_path (`str`, *optional*): - If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory - or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must - ensure that any special tokens referenced in the template are added to the tokenizer and that the model's - embedding layer is resized accordingly. - - > Parameters that control the data preprocessing - - dataset_text_field (`str`, *optional*, defaults to `"text"`): - Name of the column that contains text data in the dataset. - dataset_kwargs (`dict[str, Any]`, *optional*): - Dictionary of optional keyword arguments for the dataset preparation. The only supported key is - `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True` - regardless of the provided value, since preprocessing is done on the fly. - dataset_num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - eos_token (`str`, *optional*): - Token used to indicate the end of a turn or sequence. If `None`, it defaults to - `processing_class.eos_token`. - pad_token (`str`, *optional*): - Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, - it falls back to `processing_class.eos_token`. - max_length (`int` or `None`, *optional*, defaults to `1024`): - Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. - If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. - packing (`bool`, *optional*, defaults to `False`): - Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce - padding. Uses `max_length` to define sequence length. - packing_strategy (`str`, *optional*, defaults to `"bfd"`): - Strategy for packing sequences. Can be either `"bfd"` (best-fit decreasing, default), or `"wrapped"`. - padding_free (`bool`, *optional*, defaults to `False`): - Whether to perform forward passes without padding by flattening all sequences in the batch into a single - continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only - supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When - packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this - parameter. - pad_to_multiple_of (`int`, *optional*): - If set, the sequences will be padded to a multiple of this value. - eval_packing (`bool`, *optional*): - Whether to pack the eval dataset. If `None`, uses the same value as `packing`. - - > Parameters that control the training - - completion_only_loss (`bool`, *optional*): - Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed - only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If - `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: - loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full - sequence for [language modeling](#language-modeling) datasets. - assistant_only_loss (`bool`, *optional*, defaults to `False`): - Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only - on the assistant responses, which is supported only for [conversational](#conversational) datasets. If - `False`, loss is computed on the entire sequence. - loss_type (`str`, *optional*, defaults to `"nll"`): - Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic - Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)). - activation_offloading (`bool`, *optional*, defaults to `False`): - Whether to offload the activations to the CPU. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - model_init_kwargs = None, - chat_template_path = None, - dataset_text_field = 'text', - dataset_kwargs = None, - dataset_num_proc = None, - eos_token = None, - pad_token = None, - max_length = 1024, - packing = False, - packing_strategy = 'bfd', - padding_free = False, - pad_to_multiple_of = None, - eval_packing = None, - completion_only_loss = None, - assistant_only_loss = False, - loss_type = 'nll', - activation_offloading = False, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1': - from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION - if HAS_FLEX_ATTENTION and pad_to_multiple_of is None: - from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE - pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - model_init_kwargs = model_init_kwargs, - chat_template_path = chat_template_path, - dataset_text_field = dataset_text_field, - dataset_kwargs = dataset_kwargs, - dataset_num_proc = dataset_num_proc, - eos_token = eos_token, - pad_token = pad_token, - max_length = max_length, - packing = packing, - packing_strategy = packing_strategy, - padding_free = padding_free, - pad_to_multiple_of = pad_to_multiple_of, - eval_packing = eval_packing, - completion_only_loss = completion_only_loss, - assistant_only_loss = assistant_only_loss, - loss_type = loss_type, - activation_offloading = activation_offloading,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothSFTTrainer(BaseTrainer): - """""" - - _tag_names = ["trl", "sft"] - _name = "SFT" - - def __init__( - self, - model: Union[str, PreTrainedModel], - args: Optional[Union[SFTConfig, TrainingArguments]] = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, - compute_loss_func: Optional[Callable] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), - optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - peft_config: Optional["PeftConfig"] = None, - formatting_func: Optional[Callable[[dict], str]] = None, - ): - # Args - if args is None: - model_name = model if isinstance(model, str) else model.config._name_or_path - model_name = model_name.split("/")[-1] - args = SFTConfig(f"{model_name}-SFT") - elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): - dict_args = args.to_dict() - dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token - dict_args.pop("push_to_hub_token", None) - args = SFTConfig(**dict_args) - - # Model - if isinstance(model, str): - model = create_model_from_path(model, **args.model_init_kwargs or {}) - else: - if args.model_init_kwargs is not None: - logger.warning( - "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. " - "The `model_init_kwargs` will be ignored." - ) - model_id = model.config._name_or_path - - # Processing class - if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model_id) - - # Handle pad token for processors or tokenizers - if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer - self._is_vlm = True - elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class - self._is_vlm = False - else: - raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - - if args.eos_token is not None: - eos_token = args.eos_token - eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) - if eos_token_id is None: - raise ValueError( - f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " - f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " - "in the vocabulary before using it as an EOS token." - ) - tokenizer.eos_token_id = eos_token_id - - if args.chat_template_path is not None: - if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): - with open(args.chat_template_path, encoding="utf-8") as chat_template_file: - processing_class.chat_template = chat_template_file.read() - added_tokens = [] - else: - model, processing_class, added_tokens = clone_chat_template( - model, processing_class, args.chat_template_path - ) - else: - added_tokens = [] - - # Catch some wrong configurations related to VLMs - if self._is_vlm and args.packing: - raise ValueError( - "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig." - ) - if self._is_vlm and args.padding_free: - raise ValueError( - "Padding-free training is yet not supported for vision-language models. Please set " - "`padding_free=False` in the `SFTConfig`." - ) - if self._is_vlm and args.assistant_only_loss: - raise ValueError( - "Assistant-only loss is not yet supported for vision-language models. Please set " - "`assistant_only_loss=False` in the `SFTConfig`." - ) - - # PEFT configuration and model wrapping - if False: - if added_tokens: - # Ensure that the added tokens are trainable - if peft_config.trainable_token_indices is None: - peft_config.trainable_token_indices = {"embed_tokens": added_tokens} - elif "embed_tokens" not in peft_config.trainable_token_indices: - peft_config.trainable_token_indices["embed_tokens"] = added_tokens - else: - peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) - - # Ensure that the lm_head is trainable - if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: - logger.warning( - "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " - "`modules_to_save`. As a result, the model may not learn to generate outputs with these new " - "tokens, leading to degraded generation quality. To fix this, add " - "`modules_to_save=['lm_head']` to your PEFT configuration." - ) - - if peft_config.modules_to_save is None: - peft_config.modules_to_save = ["lm_head"] - else: - peft_config.modules_to_save.append("lm_head") - - # In Prompt Tuning a small set of trainable virtual tokens [continuous prompt embeddings] is prepended to the - # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. - self.num_virtual_tokens = 0 - - if False: - pass - if model.active_adapter in model.peft_config: - peft_model_config = model.peft_config[model.active_adapter] - self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) - - # Data collator - # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing - # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask. - self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd") - use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS - if self.padding_free: - if data_collator is not None: - raise ValueError("Passing a custom data collator is not supported when using padding-free.") - if args.packing and args.packing_strategy == "wrapped": - logger.warning( - "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not " - "recommended. Please refer to the documentation to understand why this is not recommended." - ) - if not use_flash_attention: - logger.warning( - "Padding-free training is enabled, but the attention implementation is not set to a supported " - "flash attention variant. Padding-free training flattens batches into a single sequence, and only " - "the following implementations are known to reliably support this: " - f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to " - "unexpected behavior. To ensure compatibility, set `attn_implementation` in the model " - "configuration to one of these supported options or verify that your attention mechanism can " - "handle flattened sequences." - ) - # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format - # is prompt-completion, and False if the dataset format is language modeling. - dataset_sample = next(iter(train_dataset)) - if args.completion_only_loss is None: - self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample - else: - self.completion_only_loss = args.completion_only_loss - - self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample - # Unsloth: override _is_vlm for VLM models that pass a bare tokenizer - if not self._is_vlm and self._is_vision_dataset: - _m = model - if hasattr(_m, "model"): _m = _m.model - if hasattr(getattr(_m, "config", None), "vision_config") or\ - _m.__class__.__name__.endswith("ForConditionalGeneration"): - self._is_vlm = True - if self._is_vision_dataset and not self._is_vlm: - raise ValueError( - "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " - "model does not seem to be a vision-language model. Please check your model and dataset." - ) - - if data_collator is None and not self._is_vision_dataset: - # Get the pad token: if not provided, use the one from the processing class or the eos token - # if the processing class does not have a pad token. - pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token - pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) - if pad_token_id is None: - raise ValueError( - f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " - f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " - "in the vocabulary before using it as a padding token." - ) - data_collator = DataCollatorForLanguageModeling( - pad_token_id=pad_token_id, - completion_only_loss=self.completion_only_loss, - padding_free=self.padding_free, - pad_to_multiple_of=args.pad_to_multiple_of, - ) - elif data_collator is None and self._is_vision_dataset: - data_collator = DataCollatorForVisionLanguageModeling( - processor=processing_class, - max_length=args.max_length, - completion_only_loss=self.completion_only_loss, - pad_to_multiple_of=args.pad_to_multiple_of, - dataset_text_field=args.dataset_text_field, - ) - - if args.packing and args.packing_strategy == "bfd" and not use_flash_attention: - logger.warning( - "You are using packing, but the attention implementation is not set to a supported flash attention " - "variant. Packing gathers multiple samples into a single sequence, and only the following " - f"implementations are known to reliably support this: {', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. " - "Using other implementations may lead to cross-contamination between samples. To avoid this, either " - "disable packing by setting `packing=False`, or set `attn_implementation` in the model configuration " - "to one of these supported options." - ) - if args.assistant_only_loss and not is_conversational(dataset_sample): - raise ValueError( - "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only " - "supported for conversational datasets." - ) - - # Dataset - # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where - # preprocessing [e.g., image-to-pixel conversion] is too costly and done on the fly instead. - skip_prepare_dataset = ( - args.dataset_kwargs is not None - and args.dataset_kwargs.get("skip_prepare_dataset", False) - or self._is_vision_dataset - ) - if not skip_prepare_dataset: - if self.completion_only_loss and formatting_func: - raise ValueError( - "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " - "Using a formatter converts the dataset to a language modeling type, conflicting with " - "completion-only loss. To resolve this, apply your formatting function before passing the " - "dataset, or disable `completion_only_loss` in `SFTConfig`." - ) - self._unsloth_model_ref = model - train_dataset = self._prepare_dataset( - train_dataset, processing_class, args, args.packing, formatting_func, "train" - ) - if eval_dataset is not None: - packing = args.packing if args.eval_packing is None else args.eval_packing - if isinstance(eval_dataset, dict): - eval_dataset = { - key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) - for key, dataset in eval_dataset.items() - } - else: - eval_dataset = self._prepare_dataset( - eval_dataset, processing_class, args, packing, formatting_func, "eval" - ) - - # Loss function - if args.loss_type == "nll": - pass # use the default loss - elif args.loss_type == "dft": - if compute_loss_func is not None: - raise ValueError( - "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. " - "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a " - "`compute_loss_func` is not allowed." - ) - compute_loss_func = dft_loss - else: - raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") - - # Initialize the metrics - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} - self._total_train_tokens = 0 - - # Initialize the Trainer. Parent class will handle: - # - DeepSpeed configuration [through create_accelerator_and_postprocess] - # - FSDP setup - # - Distributed training setup - # - Optimizer and scheduler creation - - super().__init__( - model=model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - compute_loss_func=compute_loss_func, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - # Initialize activation offloading context - if self.args.activation_offloading: - self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) - else: - self.maybe_activation_offload_context = contextlib.nullcontext() - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) - - def _prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class, - args, - packing: bool, - formatting_func: Optional[Callable[[dict], str]], - dataset_name: str, - ) -> Union[Dataset, IterableDataset]: - # All Unsloth Zoo code licensed under LGPLv3 - try: - if isinstance(dataset, ConstantLengthDataset): return dataset - except: - pass - - map_kwargs = {} - use_desc = isinstance(dataset, Dataset) - is_vlm = hasattr(processing_class, "tokenizer") - tokenizer = processing_class - if is_vlm: tokenizer = processing_class.tokenizer - - # Dynamic detection: check if model's module defines a function - # that requires token_type_ids when is_training=True - import sys as _sys - _needs_token_type_ids = False - # Split to avoid compiler substring match on masking_utils names - _ccm = 'create_' + 'causal_mask_mapping' - _model = getattr(self, '_unsloth_model_ref', None) or getattr(self, 'model', None) - if _model is not None: - for _m in (_model, getattr(_model, 'model', None)): - if _m is None: continue - _mod = _sys.modules.get(type(_m).__module__) - if _mod is not None and hasattr(_mod, _ccm): - _needs_token_type_ids = True - break - - if not _needs_token_type_ids: - # Fallback: model not yet available, check processor class MRO - for _base in type(processing_class).__mro__: - _base_mod = getattr(_base, '__module__', '') - if 'transformers.models.' in _base_mod: - _modeling_mod = _base_mod.replace('.processing_', '.modeling_') - _mod = _sys.modules.get(_modeling_mod) - if _mod is not None and hasattr(_mod, _ccm): - _needs_token_type_ids = True - break - if _needs_token_type_ids and hasattr(args, 'remove_unused_columns'): - args.remove_unused_columns = False - - # Get max length - max_seq_length = getattr(args, "max_length", 0) - if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0) - if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0) - if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0) - if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!") - dataset_text_field = getattr(args, "dataset_text_field", "text") - do_truncation = max_seq_length != 0 - do_formatting_func = False - do_tokenize = True - - # Get correct column names - column_names = set(next(iter(dataset)).keys()) - used_column_names = ["input_ids"] - if "attention_mask" in column_names: - used_column_names.append("attention_mask") - if _needs_token_type_ids: - used_column_names.append("token_type_ids") - - # Check if already tokenized so skip - from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling - if "labels" in column_names: - # Most likely forgot data collator! - if is_vlm and not hasattr(tokenizer, "pad"): - # Check if processing_class has a .pad, if not, use tokenizer.tokenizer - raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!") - self.data_collator = DataCollatorForSeq2Seq(tokenizer) - used_column_names.append("labels") - do_tokenize = False - elif "input_ids" in column_names: - # Skip dataset prep, and set data collator - if is_vlm and not hasattr(tokenizer, "pad"): - # Check if processing_class has a .pad, if not, use tokenizer.tokenizer - raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!") - self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) - do_tokenize = False - elif dataset_text_field not in column_names: - do_formatting_func = True - if formatting_func is None: - raise RuntimeError("Unsloth: You must specify a `formatting_func`") - pass - - if do_tokenize: - # Check double BOS tokens - if do_formatting_func: - test_text = formatting_func(next(iter(dataset))) - if not isinstance(test_text, list): - raise ValueError( - "Unsloth: The `formatting_func` should return a list of processed strings." - ) - test_text = test_text[0] - else: - test_text = next(iter(dataset))[dataset_text_field][0] - - # Get chat template - chat_template = getattr(processing_class, 'chat_template', '') - if chat_template == '' and is_vlm: - chat_template = getattr(tokenizer, 'chat_template', '') - if chat_template is None: - chat_template = '' - - # Get bos_token - add_special_tokens = True - bos_token_1 = getattr(processing_class, 'bos_token', None) - bos_token_2 = getattr(tokenizer, 'bos_token', None) - bos_token = bos_token_1 or bos_token_2 - - if bos_token is not None: - if test_text.startswith(bos_token) or bos_token in chat_template: - add_special_tokens = False - print("Unsloth: We found double BOS tokens - we shall remove one automatically.") - pass - - # Create tokenize function - def _tokenize(example): - return tokenizer( - example[dataset_text_field] if not do_formatting_func else formatting_func(example), - truncation = do_truncation, - max_length = max_seq_length, - return_token_type_ids = _needs_token_type_ids, - add_special_tokens = add_special_tokens, - ) - pass - - if not isinstance(dataset, IterableDataset): - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - else: - dataset_num_proc = getattr(args, "dataset_num_proc", None) - if dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: - dataset_num_proc = 1 - else: - dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - map_kwargs["num_proc"] = dataset_num_proc - else: - map_kwargs["batch_size"] = dataset._ex_iterable.batch_size - - if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]' - import warnings as _w - with _w.catch_warnings(): - _w.filterwarnings("ignore", message=".*couldn't be hashed properly.*") - dataset = dataset.map(_tokenize, batched = True, remove_columns = list(column_names), **map_kwargs) - - # If VLM, switch data collator since .pad is needed! - if is_vlm and not hasattr(processing_class, "pad"): - data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) - self.data_collator = data_collator - pass - pass - if packing: - # Try using new packing which works in TRL - try: - pack_dataset - except: - print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!") - return dataset - - if max_seq_length == 0: - raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.") - - if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset" - dataset = pack_dataset( - dataset.select_columns(used_column_names), - max_seq_length, - getattr(args, "packing_strategy", "bfd"), - map_kwargs, - ) - pass - return dataset - - def _set_signature_columns_if_needed(self): - # If `self.args.remove_unused_columns` is True, non-signature columns are removed. - # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" - # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the - # dataset. So we need to override the default signature columns to include "completion_mask" as well. - if self._signature_columns is None: - if self._is_vision_dataset: - self._signature_columns = ["messages", "prompt", "completion", "images", "input_ids", "labels", "attention_mask", "seq_lengths", "completion_mask", "assistant_masks"] - else: - self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] - - def compute_loss( - self, model, inputs, return_outputs = False, num_items_in_batch = None - ): - outputs = super().compute_loss( - model, - inputs, - return_outputs = return_outputs, - num_items_in_batch = num_items_in_batch, - ) - return outputs - - # Override training step to add activation offloading context. - def training_step(self, *args, **kwargs): - with self.maybe_activation_offload_context: - return super().training_step(*args, **kwargs) - - def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - mode = "train" if self.model.training else "eval" - metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics - - # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` - # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. - if mode == "eval": - metrics = {f"eval_{key}": val for key, val in metrics.items()} - - logs.update(metrics) - super().log(logs, start_time) - self._metrics[mode].clear() - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) -class UnslothSFTTrainer(_UnslothSFTTrainer): - """ - - Trainer for Supervised Fine-Tuning (SFT) method. - - This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. - - Example: - - ```python - from datasets import load_dataset - from trl import SFTTrainer - - dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") - - trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) - trainer.train() - ``` - - Args: - model (`Union[str, PreTrainedModel]`): - Model to be trained. Can be either: - - - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a - path to a *directory* containing model weights saved using - [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded - using `.from_pretrained` (where `` is derived from the model - config) with the keyword arguments in `args.model_init_kwargs`. - - A [`~transformers.PreTrainedModel`] object. - If you're training a model with an MoE architecture and want to include the load balancing/auxilliary loss - as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`. - args ([`SFTConfig`], *optional*): - Configuration for this trainer. If `None`, a default configuration is used. - data_collator ([`~transformers.DataCollator`], *optional*): - Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. - Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model - and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. - train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and - [prompt-completion](#prompt-completion) type. The format of the samples can be either: - - - [Standard](dataset_formats#standard): Each sample contains plain text. - - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role - and content). - - The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. - eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): - Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If `None`, the processing class is loaded from the model's name - with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. - If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. - compute_loss_func (`Callable`, *optional*): - A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated - batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss - function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) - used by [`Trainer`]. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function that will be used to compute metrics at evaluation. Must take a - [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing - [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean - `compute_result` argument. This will be triggered after the last eval batch to signal that the function - needs to calculate and return the global summary statistics rather than accumulating the batch-level - statistics. - callbacks (list of [`~transformers.TrainerCallback`], *optional*): - List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed - in [here](https://huggingface.co/docs/transformers/main_classes/callback). - - If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] - method. - optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`): - A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your - model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. - optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): - A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in - `args`. Incompatible with the `optimizers` argument. - - Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before - initializing the Trainer. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): - A function that preprocess the logits right before caching them at each evaluation step. Must take two - tensors, the logits and the labels, and return the logits once processed as desired. The modifications made - by this function will be reflected in the predictions received by `compute_metrics`. - - Note that the labels (second parameter) will be `None` if the dataset does not have them. - peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration used to wrap the model. If `None`, the model is not wrapped. - formatting_func (`Callable`, *optional*): - Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly - converts the dataset into a [language modeling](#language-modeling) type. - - """ - def __init__( - self, - model, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - compute_loss_func = None, - compute_metrics = None, - callbacks = None, - optimizer_cls_and_kwargs = None, - preprocess_logits_for_metrics = None, - peft_config = None, - formatting_func = None, - **kwargs - ): - if args is None: args = UnslothSFTConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if 'max_length' not in locals() and not hasattr(args, 'max_length'): - pass - else: - if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0: - if hasattr(args, 'max_length'): - args.max_length = args.max_seq_length - max_length = args.max_length - else: - model_max_length = getattr(model, 'max_seq_length', None) - if model_max_length is None: model_max_length = getattr(model, 'max_length', None) - if model_max_length is not None: - args.max_length = model_max_length - max_length = args.max_length - elif hasattr(args, 'max_length') and args.max_length is not None: - max_length = args.max_length - # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set - setattr(model, 'max_seq_length', max_length) - else: - print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.') - args.max_length = 1024 - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('sft_trainer', other_metrics) - IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n') - from unsloth_zoo.tokenizer_utils import fix_untrained_tokens - from unsloth_zoo.training_utils import fix_zero_training_loss - if 'tokenizer' not in locals(): tokenizer = processing_class - fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16) - fix_zero_training_loss(model, tokenizer, train_dataset) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - compute_loss_func = compute_loss_func, - compute_metrics = compute_metrics, - callbacks = callbacks, - optimizer_cls_and_kwargs = optimizer_cls_and_kwargs, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - peft_config = peft_config, - formatting_func = formatting_func,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass - - -if hasattr(logger, "addFilter"): - import logging - class HideLoggingMessage(logging.Filter): - def __init__(self, text): self.text = text - def filter(self, x): return not (self.text in x.getMessage()) - pass - logger.addFilter(HideLoggingMessage("`use_cache=True`")) - diff --git a/unsloth_compiled_cache/UnslothXPOTrainer.py b/unsloth_compiled_cache/UnslothXPOTrainer.py deleted file mode 100644 index 4da51fc..0000000 --- a/unsloth_compiled_cache/UnslothXPOTrainer.py +++ /dev/null @@ -1,1409 +0,0 @@ -""" -2026.2.1 -2026.2.1 -4.57.6 -0.24.0 -__UNSLOTH_VERSIONING__ -""" - -# Unsloth auto generated code -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . - -from torch import Tensor -import torch -import torch.nn as nn -from torch.nn import functional as F -from unsloth_zoo.temporary_patches.common import torch_compile -from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable -from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, get_reward, is_conversational, is_peft_available, jinja2, maybe_apply_chat_template, nn, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation) - - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F -import inspect -from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling -from transformers.training_args import ParallelMode -from unsloth_zoo.device_type import DEVICE_TYPE, device_synchronize - -# Wrap trainer with padding to right and enable training mode -# Also patches W&B since multiple runs must use wandb.finish() -import functools -from types import MethodType -try: - from unsloth_zoo.gradient_checkpointing import reset_unsloth_gradient_checkpointing_buffers -except: - def reset_unsloth_gradient_checkpointing_buffers(): pass -def prepare_for_training_mode(f): - @functools.wraps(f) - def wrapper(self, *args, **kwargs): - # Enable training mode - _was_training = None - # Get gradient checkpointing setting from training arguments - use_gc = getattr(self.args, 'gradient_checkpointing', True) - if hasattr(self, 'model') and hasattr(self.model, "training"): - _was_training = self.model.training - if hasattr(self, 'model') and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - output = f(self, *args, **kwargs) - # Restore previous mode when possible - if hasattr(self, 'model') and hasattr(self.model, "for_inference"): - if _was_training is False: - self.model.for_inference() - elif _was_training is True and hasattr(self.model, "for_training"): - self.model.for_training(use_gradient_checkpointing=use_gc) - # Reset gradient checkpointing buffers to free memory while staying ready for next run - try: - reset_unsloth_gradient_checkpointing_buffers() - except: - pass - # Patch W&B to enable logging on future runs, otherwise it'll overwrite the first run - try: - import wandb - wandb.finish() - except: - pass - return output - return wrapper -pass - -torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : False, - "shape_padding" : True, - "trace.enabled" : False, - "triton.cudagraphs" : False, -} - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_hidden_states_selective_log_softmax( - hidden_states: torch.Tensor, - lm_head: torch.Tensor, - index: torch.Tensor, - chunks: int = 4, - logit_scale_multiply: float = 0.0, - logit_scale_divide: float = 0.0, - logit_softcapping: float = 0.0, - temperature: float = 1.0, -) -> torch.Tensor: - # All Unsloth Zoo code licensed under AGPL3 - flat_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) - flat_index = index.reshape(-1) - - chunked_hidden_states = torch.chunk(flat_hidden_states, chunks=chunks, dim=0) - chunked_index = torch.chunk(flat_index, chunks=chunks, dim=0) - - all_per_token_logps = [] - - for chunk_hidden_states, chunk_index in zip(chunked_hidden_states, chunked_index): - chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t() - - if logit_scale_multiply != 0.0: - chunk_logits = chunk_logits * logit_scale_multiply - if logit_scale_divide != 0.0: - chunk_logits = chunk_logits / logit_scale_divide - if logit_softcapping != 0.0: - chunk_logits = chunk_logits * torch.tanh(chunk_logits / logit_softcapping) - - chunk_logits = chunk_logits.to(torch.float32) - - if temperature != 1.0: - chunk_logits = chunk_logits / temperature - - selected_logits = torch.gather(chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim=-1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - - all_per_token_logps = torch.concat(all_per_token_logps) - - all_per_token_logps = all_per_token_logps.reshape((hidden_states.shape[0], hidden_states.shape[1])) - return all_per_token_logps - -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def chunked_selective_log_softmax(logits, index): - # Split into 4 chunks only - chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0) - chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0) - all_per_token_logps = [] - # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index) - for chunk_logits, chunk_index in zip(chunked_logits, chunked_index): - chunk_logits = chunk_logits.to(torch.float32) - selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.logsumexp(chunk_logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values - all_per_token_logps.append(per_token_logps) - pass - all_per_token_logps = torch.concat(all_per_token_logps) - all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1])) - return all_per_token_logps - -def calculate_pad_tokens_in_prompt( - input_ids: torch.Tensor, - logits_to_keep: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given prompt tensor, it returns all the left padded tokens in that sequence. so [pad, pad, pad, cat] = 3 tokens - """ - if logits_to_keep >= input_ids.shape[1]: - raise ValueError("logits_to_keep must be smaller than the sequence length.") - - prompt_section = input_ids[:, :-logits_to_keep] - - padding_mask = (prompt_section == pad_token_id) - - pad_token_counts = padding_mask.sum(dim=1) - - return pad_token_counts - -def create_completion_attention_mask( - completion_input_ids: torch.Tensor, - left_pad_tokens_per_prompt: torch.Tensor, - max_left_pad: int, - pad_token_id: int -) -> torch.Tensor: - """ - Given that we have a sequence, [p,p,p,c,c,c,pad,pad,pad] - - Where p are extra prompt tokens we got from slicing the torch tensor, c is completion tokens - and pad are pad tokens, this function would make a completion mask that would 0 out the pad - and p tokens. so in this example [0,0,0,1,1,1,0,0,0] - """ - batch_size, completion_len = completion_input_ids.shape - device = completion_input_ids.device - - num_tokens_to_mask = max_left_pad - left_pad_tokens_per_prompt - - indices = torch.arange(completion_len, device=device).unsqueeze(0) - shift_mask = indices >= num_tokens_to_mask.unsqueeze(1) - - non_padding_mask = (completion_input_ids != pad_token_id) - - final_mask = shift_mask & non_padding_mask - - return final_mask - -def left_pack_padding(tensor: torch.Tensor, pad_id: int) -> torch.Tensor: - """ - Moves all padding tokens in each sequence of a batch to the right. - """ - mask = (tensor != pad_id) - # Must do stable=True since binary mark is unordered - sorted_indices = torch.argsort(mask, dim=1, descending=True, stable=True) - packed_tensor = torch.gather(tensor, 1, sorted_indices) - return packed_tensor - -def align_logprobs_with_mask( - logprob_tensor: torch.Tensor, - attention_mask: torch.Tensor, - pad_value: float = 0.0 -) -> torch.Tensor: - """ - Aligns a log probability tensor with a given attention mask. - """ - - device = logprob_tensor.device - batch_size, logprob_seq_len = logprob_tensor.shape - mask_seq_len = attention_mask.shape[1] - - padded_logprobs = torch.full( - attention_mask.shape, - fill_value=pad_value, - dtype=logprob_tensor.dtype, - device=device - ) - - left_pad_counts = torch.argmax(attention_mask, dim=1) - - cols = torch.arange(logprob_seq_len, device=device) - dest_indices = left_pad_counts.unsqueeze(1) + cols - - # Create destination row indices - # Shape: [batch_size, logprob_seq_len] - row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(dest_indices) - - # --- 4. Filter out-of-bounds indices and perform assignment --- - # Create a mask to identify only the indices that are within the bounds - # of the target tensor's sequence length. - valid_mask = dest_indices < mask_seq_len - - # Use this mask to select only the valid row indices, column indices, - # and the corresponding values from the logprob tensor. - # This flattens the selected elements into 1D tensors. - valid_rows = row_indices[valid_mask] - valid_cols = dest_indices[valid_mask] - valid_vals = logprob_tensor[valid_mask] - - # Place the valid values into their correct positions in the padded tensor - # using a single, efficient advanced indexing operation. - padded_logprobs[valid_rows, valid_cols] = valid_vals - - return padded_logprobs - -def autotune_batch_and_chunks( - total_input_rows, - seq_len, - hidden_size, - vocab_size, - dtype_bytes=16, - multiplier=None -): - if multiplier is None: - final_m = max(4, seq_len // 4096) - else: - final_m = multiplier - - if torch.cuda.is_available(): - free_bytes, _ = torch.cuda.mem_get_info() - limit_gb = (free_bytes / (1024**3))*.80 - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - # For XPU: estimate free memory from total - reserved - total_mem = torch.xpu.get_device_properties(0).total_memory - reserved_mem = torch.xpu.memory_reserved() - free_bytes = total_mem - reserved_mem - limit_gb = (free_bytes / (1024**3)) * 0.80 - else: - # Fallback: assume 8GB available - limit_gb = 8.0 - - bytes_to_gb = 1024**3 - - b_vals = torch.arange(total_input_rows, 0, -1, device='cpu', dtype=torch.float32) - - hidden_gb = (b_vals * seq_len * hidden_size * dtype_bytes) / bytes_to_gb - - base_logits = ((b_vals/total_input_rows) * b_vals * seq_len * vocab_size * dtype_bytes) / bytes_to_gb - logits_gb = base_logits / final_m - - total_mem_gb = hidden_gb + logits_gb - - valid_mask = total_mem_gb <= limit_gb - valid_indices = torch.nonzero(valid_mask, as_tuple=False) - - if valid_indices.shape[0] == 0: - #This means your GPU will OOM - return 4, final_m - - best_idx = valid_indices[0].item() - final_b = int(b_vals[best_idx].item()) - - return final_b, final_m -@dataclass -class UnslothXPOConfig(XPOConfig): - """ - - Configuration class for the [`XPOTrainer`]. - - Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: - - Parameters: - alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`): - Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch - and the last alpha is used for the rest of the epochs. - - """ - vllm_sampling_params: Optional[Any] = field( - default = None, - metadata = {'help': 'vLLM SamplingParams'}, - ) - unsloth_num_chunks : Optional[int] = field( - default = -1, - metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, - ) - unsloth_logit_chunk_multiplier : Optional[int] = field( - default = None, - metadata = {'help': 'Multiplier for chunked logit computations.'}, - ) - unsloth_grpo_mini_batch : Optional[int] = field( - default = None, - metadata = {'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}, - ) - max_seq_length : Optional[int] = field( - default = None, - metadata = {'help': 'Maximum sequence length to truncate to.'}, - ) - def __init__( - self, - output_dir = None, - overwrite_output_dir = None, - do_train = False, - do_eval = False, - do_predict = False, - eval_strategy = 'no', - prediction_loss_only = False, - per_device_train_batch_size = 4, - per_device_eval_batch_size = 4, - per_gpu_train_batch_size = None, - per_gpu_eval_batch_size = None, - gradient_accumulation_steps = 2, - eval_accumulation_steps = 2, - eval_delay = 0, - torch_empty_cache_steps = 250, - learning_rate = 5e-05, - weight_decay = 0.01, - adam_beta1 = 0.9, - adam_beta2 = 0.999, - adam_epsilon = 1e-08, - max_grad_norm = 1.0, - num_train_epochs = 3.0, - max_steps = -1, - lr_scheduler_type = 'linear', - lr_scheduler_kwargs = None, - warmup_ratio = 0.1, - warmup_steps = 0, - log_level = 'passive', - log_level_replica = 'warning', - log_on_each_node = True, - logging_dir = None, - logging_strategy = 'steps', - logging_first_step = False, - logging_steps = 1, - logging_nan_inf_filter = False, - save_strategy = 'steps', - save_steps = 500, - save_total_limit = None, - save_safetensors = True, - save_on_each_node = False, - save_only_model = False, - restore_callback_states_from_checkpoint = False, - no_cuda = False, - use_cpu = False, - use_mps_device = False, - seed = 3407, - data_seed = 3407, - jit_mode_eval = False, - bf16 = False, - fp16 = False, - fp16_opt_level = 'O1', - half_precision_backend = 'auto', - bf16_full_eval = False, - fp16_full_eval = False, - tf32 = None, - local_rank = -1, - ddp_backend = None, - tpu_num_cores = None, - tpu_metrics_debug = False, - debug = '', - dataloader_drop_last = False, - eval_steps = None, - dataloader_num_workers = 0, - dataloader_prefetch_factor = None, - past_index = -1, - run_name = None, - disable_tqdm = None, - remove_unused_columns = True, - label_names = None, - load_best_model_at_end = False, - metric_for_best_model = None, - greater_is_better = None, - ignore_data_skip = False, - fsdp = None, - fsdp_min_num_params = 0, - fsdp_config = None, - fsdp_transformer_layer_cls_to_wrap = None, - accelerator_config = None, - parallelism_config = None, - deepspeed = None, - label_smoothing_factor = 0.0, - optim = 'adamw_8bit', - optim_args = None, - adafactor = False, - group_by_length = False, - length_column_name = 'length', - report_to = 'none', - project = 'huggingface', - trackio_space_id = 'trackio', - ddp_find_unused_parameters = None, - ddp_bucket_cap_mb = None, - ddp_broadcast_buffers = None, - dataloader_pin_memory = True, - dataloader_persistent_workers = False, - skip_memory_metrics = True, - use_legacy_prediction_loop = False, - push_to_hub = False, - resume_from_checkpoint = None, - hub_model_id = None, - hub_strategy = 'every_save', - hub_token = None, - hub_private_repo = None, - hub_always_push = False, - hub_revision = None, - gradient_checkpointing = True, - gradient_checkpointing_kwargs = None, - include_inputs_for_metrics = False, - eval_do_concat_batches = True, - fp16_backend = 'auto', - push_to_hub_model_id = None, - push_to_hub_organization = None, - push_to_hub_token = None, - mp_parameters = '', - auto_find_batch_size = False, - full_determinism = False, - torchdynamo = None, - ray_scope = 'last', - ddp_timeout = 1800, - torch_compile = False, - torch_compile_backend = None, - torch_compile_mode = None, - include_tokens_per_second = False, - include_num_input_tokens_seen = False, - neftune_noise_alpha = None, - optim_target_modules = None, - batch_eval_metrics = False, - eval_on_start = False, - use_liger_kernel = False, - liger_kernel_config = None, - eval_use_gather_object = False, - average_tokens_across_devices = True, - reward_model_path = None, - judge = None, - max_new_tokens = 64, - max_length = 512, - temperature = 0.9, - top_p = 1.0, - top_k = None, - min_p = None, - repetition_penalty = 1.0, - generation_kwargs = {}, - use_transformers_paged = False, - cache_implementation = None, - missing_eos_penalty = None, - loss_type = 'sigmoid', - disable_dropout = True, - use_vllm = False, - vllm_model_impl = 'vllm', - vllm_guided_decoding_regex = None, - vllm_gpu_memory_utilization = 0.55, - vllm_mode = 'colocate', - vllm_server_base_url = None, - vllm_server_host = '0.0.0.0', - vllm_server_port = 8000, - vllm_server_timeout = 240.0, - vllm_tensor_parallel_size = 1, - ds3_gather_for_generation = True, - model_init_kwargs = None, - reward_weights = None, - dataset_num_proc = None, - gpu_memory_utilization = None, - vllm_sampling_params = None, - unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, - max_seq_length = None, - **kwargs, - ): - if learning_rate < 1e-7: print(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') - if learning_rate > 1: print(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') - if num_train_epochs is None: - num_train_epochs = 3.0 # Default to 3 epochs if None, max_steps will override - if output_dir is None and save_strategy == 'steps' and save_steps == 500: - output_dir = 'unsloth_training_checkpoints' - save_strategy = 'no' - import multiprocessing as _mp - if _mp.get_start_method() != 'fork': - dataset_num_proc = None - elif dataset_num_proc is None: - import psutil - dataset_num_proc = min(max((psutil.cpu_count() or 1)+4, 2), 64) - memory_gb_left = psutil.virtual_memory().available / (1024**3) - if memory_gb_left <= 2: dataset_num_proc = 1 - else: dataset_num_proc = min(dataset_num_proc, int(memory_gb_left)) - if temperature <= 0: - raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') - elif temperature >= 10: - raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') - - - super().__init__( - output_dir = output_dir, - overwrite_output_dir = overwrite_output_dir, - do_train = do_train, - do_eval = do_eval, - do_predict = do_predict, - eval_strategy = eval_strategy, - prediction_loss_only = prediction_loss_only, - per_device_train_batch_size = per_device_train_batch_size, - per_device_eval_batch_size = per_device_eval_batch_size, - per_gpu_train_batch_size = per_gpu_train_batch_size, - per_gpu_eval_batch_size = per_gpu_eval_batch_size, - gradient_accumulation_steps = gradient_accumulation_steps, - eval_accumulation_steps = eval_accumulation_steps, - eval_delay = eval_delay, - torch_empty_cache_steps = torch_empty_cache_steps, - learning_rate = learning_rate, - weight_decay = weight_decay, - adam_beta1 = adam_beta1, - adam_beta2 = adam_beta2, - adam_epsilon = adam_epsilon, - max_grad_norm = max_grad_norm, - num_train_epochs = num_train_epochs, - max_steps = max_steps, - lr_scheduler_type = lr_scheduler_type, - lr_scheduler_kwargs = lr_scheduler_kwargs, - warmup_ratio = warmup_ratio, - warmup_steps = warmup_steps, - log_level = log_level, - log_level_replica = log_level_replica, - log_on_each_node = log_on_each_node, - logging_dir = logging_dir, - logging_strategy = logging_strategy, - logging_first_step = logging_first_step, - logging_steps = logging_steps, - logging_nan_inf_filter = logging_nan_inf_filter, - save_strategy = save_strategy, - save_steps = save_steps, - save_total_limit = save_total_limit, - save_safetensors = save_safetensors, - save_on_each_node = save_on_each_node, - save_only_model = save_only_model, - restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, - no_cuda = no_cuda, - use_cpu = use_cpu, - use_mps_device = use_mps_device, - seed = seed, - data_seed = data_seed, - jit_mode_eval = jit_mode_eval, - bf16 = bf16, - fp16 = fp16, - fp16_opt_level = fp16_opt_level, - half_precision_backend = half_precision_backend, - bf16_full_eval = bf16_full_eval, - fp16_full_eval = fp16_full_eval, - tf32 = tf32, - local_rank = local_rank, - ddp_backend = ddp_backend, - tpu_num_cores = tpu_num_cores, - tpu_metrics_debug = tpu_metrics_debug, - debug = debug, - dataloader_drop_last = dataloader_drop_last, - eval_steps = eval_steps, - dataloader_num_workers = dataloader_num_workers, - dataloader_prefetch_factor = dataloader_prefetch_factor, - past_index = past_index, - run_name = run_name, - disable_tqdm = disable_tqdm, - remove_unused_columns = remove_unused_columns, - label_names = label_names, - load_best_model_at_end = load_best_model_at_end, - metric_for_best_model = metric_for_best_model, - greater_is_better = greater_is_better, - ignore_data_skip = ignore_data_skip, - fsdp = fsdp, - fsdp_min_num_params = fsdp_min_num_params, - fsdp_config = fsdp_config, - fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, - accelerator_config = accelerator_config, - parallelism_config = parallelism_config, - deepspeed = deepspeed, - label_smoothing_factor = label_smoothing_factor, - optim = optim, - optim_args = optim_args, - adafactor = adafactor, - group_by_length = group_by_length, - length_column_name = length_column_name, - report_to = report_to, - project = project, - trackio_space_id = trackio_space_id, - ddp_find_unused_parameters = ddp_find_unused_parameters, - ddp_bucket_cap_mb = ddp_bucket_cap_mb, - ddp_broadcast_buffers = ddp_broadcast_buffers, - dataloader_pin_memory = dataloader_pin_memory, - dataloader_persistent_workers = dataloader_persistent_workers, - skip_memory_metrics = skip_memory_metrics, - use_legacy_prediction_loop = use_legacy_prediction_loop, - push_to_hub = push_to_hub, - resume_from_checkpoint = resume_from_checkpoint, - hub_model_id = hub_model_id, - hub_strategy = hub_strategy, - hub_token = hub_token, - hub_private_repo = hub_private_repo, - hub_always_push = hub_always_push, - hub_revision = hub_revision, - gradient_checkpointing = gradient_checkpointing, - gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, - include_inputs_for_metrics = include_inputs_for_metrics, - eval_do_concat_batches = eval_do_concat_batches, - fp16_backend = fp16_backend, - push_to_hub_model_id = push_to_hub_model_id, - push_to_hub_organization = push_to_hub_organization, - push_to_hub_token = push_to_hub_token, - mp_parameters = mp_parameters, - auto_find_batch_size = auto_find_batch_size, - full_determinism = full_determinism, - torchdynamo = torchdynamo, - ray_scope = ray_scope, - ddp_timeout = ddp_timeout, - torch_compile = torch_compile, - torch_compile_backend = torch_compile_backend, - torch_compile_mode = torch_compile_mode, - include_tokens_per_second = include_tokens_per_second, - include_num_input_tokens_seen = include_num_input_tokens_seen, - neftune_noise_alpha = neftune_noise_alpha, - optim_target_modules = optim_target_modules, - batch_eval_metrics = batch_eval_metrics, - eval_on_start = eval_on_start, - use_liger_kernel = use_liger_kernel, - liger_kernel_config = liger_kernel_config, - eval_use_gather_object = eval_use_gather_object, - average_tokens_across_devices = average_tokens_across_devices, - reward_model_path = reward_model_path, - judge = judge, - max_new_tokens = max_new_tokens, - max_length = max_length, - temperature = temperature, - top_p = top_p, - top_k = top_k, - min_p = min_p, - repetition_penalty = repetition_penalty, - generation_kwargs = generation_kwargs, - use_transformers_paged = use_transformers_paged, - cache_implementation = cache_implementation, - missing_eos_penalty = missing_eos_penalty, - loss_type = loss_type, - disable_dropout = disable_dropout, - use_vllm = use_vllm, - vllm_model_impl = vllm_model_impl, - vllm_guided_decoding_regex = vllm_guided_decoding_regex, - vllm_gpu_memory_utilization = vllm_gpu_memory_utilization, - vllm_mode = vllm_mode, - vllm_server_base_url = vllm_server_base_url, - vllm_server_host = vllm_server_host, - vllm_server_port = vllm_server_port, - vllm_server_timeout = vllm_server_timeout, - vllm_tensor_parallel_size = vllm_tensor_parallel_size, - ds3_gather_for_generation = ds3_gather_for_generation, - model_init_kwargs = model_init_kwargs, - reward_weights = reward_weights, - dataset_num_proc = dataset_num_proc, - gpu_memory_utilization = gpu_memory_utilization,**kwargs) - self.vllm_sampling_params = vllm_sampling_params - self.unsloth_num_chunks = unsloth_num_chunks - if unsloth_grpo_mini_batch is not None: - if self.generation_batch_size >= unsloth_grpo_mini_batch: - self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch - else: - raise ValueError( - f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, " - f"which is self.per_device_train_batch_size * gradient_accumulation_steps." - ) - self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier - self.max_seq_length = max_seq_length - -pass - -class _UnslothXPOTrainer(OnlineDPOTrainer): - """""" - - _tag_names = ["trl", "xpo"] - _name = "XPO" - _paper = { - "title": "Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF", - "id": "2405.21046", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{jung2024binary, - title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}}, - author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin}, - year = 2024, - eprint = {arXiv:2405.21046} - }"""), - } - - def __init__( - self, - model: Union[PreTrainedModel, nn.Module] = None, - ref_model: Union[PreTrainedModel, nn.Module] = None, - reward_funcs: Optional[nn.Module] = None, - judge: Optional[BasePairwiseJudge] = None, - args: Optional[XPOConfig] = None, - data_collator: Optional[Callable] = None, - train_dataset: Optional[Union[Dataset, IterableDataset]] = None, - eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, - reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, - peft_config: Optional[dict] = None, - compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, - callbacks: Optional[list[TrainerCallback]] = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - # Deprecated parameters - reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None, - ) -> None: - super().__init__( - model=model, - ref_model=ref_model, - judge=judge, - reward_funcs=reward_funcs, - reward_model=reward_model, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=processing_class, - reward_processing_classes=reward_processing_classes, - peft_config=peft_config, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) - - self._alpha = self.args.alpha - - # Overwrite the stats dictionary to include XPO specific statistics - self.stats = { - # Remove "non_score_reward", "rlhf_reward", "scores" - # Add "loss/dpo", "loss/xpo" - "loss/dpo": [], - "loss/xpo": [], - "objective/kl": [], - "objective/entropy": [], - "rewards/chosen": [], - "rewards/rejected": [], - "rewards/accuracies": [], - "rewards/margins": [], - "logps/chosen": [], - "logps/rejected": [], - # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token" - "val/model_contain_eos_token": [], - "val/ref_contain_eos_token": [], - "alpha": [], - "beta": [], - } - if self.reward_funcs is not None: - if len(self.reward_funcs) != 1: - raise ValueError("XPOTrainer only supports one reward function/model.") - self.reward_funcs = self.reward_funcs[0] - self.stats["objective/model_scores"] = [] - self.stats["objective/ref_scores"] = [] - self.stats["objective/scores_margin"] = [] - - @property - def alpha(self): - if isinstance(self._alpha, list): - epoch = self.state.epoch - return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1] - else: - return self._alpha - - def _generate_completions(self, prompts, model): - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen: - model_output = unwrapped_policy_model_for_gen.generate( - input_ids=prompts["input_ids"], - attention_mask=prompts["attention_mask"], - generation_config=self.generation_config, - ) - - actual_model_for_ref_generation: torch.nn.Module - if self.ref_model is None: - unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model) - - if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel): - actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model() - else: - actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic - else: - actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model) - - with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen: - ref_output = final_ref_model_for_gen.generate( - input_ids=prompts["input_ids"], - attention_mask=prompts["attention_mask"], - generation_config=self.generation_config, - ) - - return model_output, ref_output - - def _process_completions(self, model_output, ref_output, prompts): - context_length = prompts["input_ids"].shape[1] - - # Process model completions - model_completion_ids = model_output[:, context_length:] - model_completion_ids, model_completion_mask = truncate_right( - model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id - ) - model_data = { - "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), - "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), - "raw": prompts["raw"], - } - - # Process reference model completions - ref_completion_ids = ref_output[:, context_length:] - ref_completion_ids, ref_completion_mask = truncate_right( - ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id - ) - ref_data = { - "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1), - "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1), - "raw": prompts["raw"], - } - - return model_data, ref_data - - def _compute_rewards(self, model_data, ref_data, context_length): - with torch.no_grad(): - _, model_scores, _ = get_reward( - self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length - ) - _, ref_scores, _ = get_reward( - self.reward_funcs, ref_data["input_ids"], self.processing_class.pad_token_id, context_length - ) - - # Apply EOS penalty if needed - if self.args.missing_eos_penalty is not None: - model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) - ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) - model_scores[~model_contain_eos] -= self.args.missing_eos_penalty - ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty - - return model_scores, ref_scores - - def _compute_judge(self, model_data, ref_data, context_length): - prompts = model_data["raw"] - model_data_completions = self.processing_class.batch_decode( - model_data["input_ids"][:, context_length:], skip_special_tokens=True - ) - model_data_completions = [completion.strip() for completion in model_data_completions] - - ref_data_completions = self.processing_class.batch_decode( - ref_data["input_ids"][:, context_length:], skip_special_tokens=True - ) - ref_data_completions = [completion.strip() for completion in ref_data_completions] - - if is_conversational({"prompt": prompts[0]}): - model_data_completions = [ - [{"role": "assistant", "content": completion}] for completion in model_data_completions - ] - environment = jinja2.Environment() - template = environment.from_string(SIMPLE_CHAT_TEMPLATE) - prompts = [template.render(messages=message) for message in prompts] - model_data_completions = [template.render(messages=completion) for completion in model_data_completions] - - ref_data_completions = [ - [{"role": "assistant", "content": completion}] for completion in ref_data_completions - ] - ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions] - - ranks_of_first_completion = self.judge.judge( - prompts, - list(zip(model_data_completions, ref_data_completions)), - ) - # convert ranks to a True/False mask: - # when rank == 0, it means the first completion is the best - # when rank == 1, it means the second completion is the best - return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device) - - def _compute_logprobs(self, model, model_data, ref_data, context_length): - def compute_logprobs_for_data(m, data): - output = m(data["input_ids"], attention_mask=data["attention_mask"]) - logits = output.logits[:, context_length - 1 : -1] - token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) - return token_logprobs - - # Compute logprobs for model completions - model_logprobs_model_data = compute_logprobs_for_data(model, model_data) - # Compute logprobs for model on reference completions (for XPO loss) - model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data) - - # Compute logprobs for reference model completions - with torch.no_grad(): - if self.ref_model is None: - with model.disable_adapter(): - ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) - ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data) - else: - ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) - ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data) - - # Mask padding tokens - model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 - ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0 - model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) - model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0) - ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0) - ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) - - return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data - - def _compute_losses( - self, - model_logprobs_model_data, - model_logprobs_ref_data, - ref_logprobs_ref_data, - ref_logprobs_model_data, - chosen_mask, - ): - # Compute log probs - model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) - model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1) - ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1) - ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) - - chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) - chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) - chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs - - rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) - rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) - rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs - - # Compute logits as the difference between chosen and rejected log ratios - logits = chosen_log_ratios - rejected_log_ratios - - if self.args.loss_type == "sigmoid": - dpo_losses = -F.logsigmoid(self.beta * logits) - elif self.args.loss_type == "ipo": - dpo_losses = (logits - 1 / (2 * self.beta)) ** 2 - else: - raise NotImplementedError(f"invalid loss type {self.args.loss_type}") - - # Compute XPO specific loss - xpo_losses = self.alpha * model_logprobs_ref_data_sum - - # Total loss - loss = (dpo_losses + xpo_losses).mean() - - return loss, dpo_losses, xpo_losses - - def _log_statistics( - self, - model_data, - ref_data, - model_logprobs_model_data, - model_logprobs_ref_data, - ref_logprobs_ref_data, - ref_logprobs_model_data, - chosen_mask, - dpo_losses, - xpo_losses, - context_length, - model_scores=None, - ref_scores=None, - ): - # Helper function to gather and compute mean - def gather_mean(tensor): - return self.accelerator.gather_for_metrics(tensor).mean().item() - - # Log losses - self.stats["loss/dpo"].append(gather_mean(dpo_losses)) - self.stats["loss/xpo"].append(gather_mean(xpo_losses)) - - # Log scores - if self.reward_funcs is not None: - self.stats["objective/model_scores"].append(gather_mean(model_scores)) - self.stats["objective/ref_scores"].append(gather_mean(ref_scores)) - self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores)) - - # Log logprobs - model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) - model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1) - ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1) - ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) - - chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) - chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) - chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs - - rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) - rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) - rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs - - self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean())) - self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean())) - - # Log rewards - # Compute various statistics - chosen_rewards = chosen_log_ratios * self.beta - rejected_rewards = rejected_log_ratios * self.beta - self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean())) - self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean())) - - # Calculate KL divergence for model and ref data - kl_model_data = model_logprobs_model_data - ref_logprobs_model_data - kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data - mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2 - self.stats["objective/kl"].append(gather_mean(mean_kl)) - - # Calculate entropy for model and ref data - entropy_model_data = -model_logprobs_model_data.sum(1) - entropy_ref_data = -model_logprobs_ref_data.sum(1) - mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2 - self.stats["objective/entropy"].append(gather_mean(mean_entropy)) - - # Calculate margins - margin = chosen_rewards - rejected_rewards - self.stats["rewards/margins"].append(gather_mean(margin.mean())) - - # Calculate accuracy - accuracy = (margin > 0).float() - self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean())) - - # Log EOS token statistics - model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) - ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) - self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) - self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float())) - - # Log alpha and beta - self.stats["alpha"].append(self.alpha) - self.stats["beta"].append(self.beta) - - def training_step( - self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None - ) -> torch.Tensor: - model.train() - - # Apply chat template and tokenize the input - batch_size = len(next(iter(inputs.values()))) - prompts = inputs["prompt"] - inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] - inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] - inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] - inputs = self.data_collator(inputs) - - # need the prompt_ only - inputs = self._prepare_inputs(inputs) - context_length = inputs["prompt_input_ids"].shape[1] - prompts = { - "input_ids": inputs["prompt_input_ids"], - "attention_mask": inputs["prompt_attention_mask"], - "raw": prompts, - } - del inputs - - # Sample completions from both the model and the reference model - model_output, ref_output = self._generate_completions(prompts, model) - - # Process model completions - model_data, ref_data = self._process_completions(model_output, ref_output, prompts) - - # Compute rewards - if self.reward_funcs is not None: - model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length) - chosen_mask = model_scores >= ref_scores - else: - model_scores, ref_scores = None, None - chosen_mask = self._compute_judge(model_data, ref_data, context_length) - - # Compute logprobs - model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = ( - self._compute_logprobs(model, model_data, ref_data, context_length) - ) - - # Compute loss - loss, dpo_losses, xpo_losses = self._compute_losses( - model_logprobs_model_data, - model_logprobs_ref_data, - ref_logprobs_ref_data, - ref_logprobs_model_data, - chosen_mask, - ) - - # Log everything - self._log_statistics( - model_data, - ref_data, - model_logprobs_model_data.detach(), - model_logprobs_ref_data.detach(), - ref_logprobs_ref_data, - ref_logprobs_model_data, - chosen_mask, - dpo_losses.detach(), - xpo_losses.detach(), - context_length, - model_scores, - ref_scores, - ) - - if ( - self.args.torch_empty_cache_steps is not None - and self.state.global_step % self.args.torch_empty_cache_steps == 0 - ): - empty_cache() - - kwargs = {} - # For LOMO optimizers you need to explicitly use the learning rate - if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: - kwargs["learning_rate"] = self._get_learning_rate() - - if self.args.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu parallel training - - self.accelerator.backward(loss, **kwargs) - - return loss.detach() / self.args.gradient_accumulation_steps -class UnslothXPOTrainer(_UnslothXPOTrainer): - """ - - Trainer for Exploratory Preference Optimization (XPO). - - It is implemented as a subclass of [`OnlineDPOTrainer`]. - - Args: - model ([`~transformers.PreTrainedModel`]): - The model to train, preferably an `AutoModelForCausalLM`. - ref_model ([`PreTrainedModelWrapper`]): - Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation - and loss. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized. - reward_funcs ([`~transformers.PreTrainedModel`]): - The reward model to score completions with, preferably an - [`~transformers.AutoModelForSequenceClassification`]. - judge ([`BasePairwiseJudge`]): - The judge to use for pairwise comparison of model completions. - args ([`XPOConfig`]): - The XPO config arguments to use for training. - data_collator ([`~transformers.DataCollator`]): - The data collator to use for training. If None is specified, the default data collator - ([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the - sequences in the batch, given a dataset of paired sequences. - train_dataset ([`~datasets.Dataset`]): - The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): - The dataset to use for evaluation. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - Processing class used to process the data. If provided, will be used to automatically process the inputs - for the model, and it will be saved along the model to make it easier to rerun an interrupted training or - reuse the fine-tuned model. - peft_config (`dict`): - The peft config to use for training. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - - reward_model: - - - - This parameter is deprecated and will be removed in version 0.25.0. Use `reward_funcs` instead. - - - - """ - def __init__( - self, - model = None, - ref_model = None, - reward_funcs = None, - judge = None, - args = None, - data_collator = None, - train_dataset = None, - eval_dataset = None, - processing_class = None, - reward_processing_classes = None, - peft_config = None, - compute_metrics = None, - callbacks = None, - preprocess_logits_for_metrics = None, - reward_model = None, - **kwargs - ): - if args is None: args = UnslothXPOConfig() - use_bf16 = getattr(args, 'bf16', False) - if type(use_bf16) is not bool: use_bf16 = False - use_fp16 = getattr(args, 'fp16', False) - if type(use_fp16) is not bool: use_fp16 = False - force_float32 = False - full_finetuning = os.environ.get('UNSLOTH_ENABLE_FULL_FINETUNING', '0') == '1' - if not full_finetuning and (os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1'): - print('Unsloth: Switching to float32 training since model cannot work with float16') - force_float32 = True - mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') - dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None) - if dtype is None: dtype = model.get_input_embeddings().weight.dtype - from unsloth_zoo.utils import _get_dtype - dtype = _get_dtype(dtype) - float16 = dtype == torch.float16 - if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') - if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') - if force_float32: - # Forced float32 training - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': - # Mixed precision training - args.fp16 = float16 - args.bf16 = not float16 - os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'fp16' if float16 else 'bf16' - # args.mixed_precision is a new argument which needs to be set now - elif mixed_precision_dtype == 'bfloat16': - # Both False since bfloat16 full finetuning doesn't do any autocasting. - args.fp16 = False - args.bf16 = False - os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' - if hasattr(args, 'mixed_precision'): args.mixed_precision = 'no' - # args.mixed_precision is a new argument which needs to be set now - - if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': - args.eval_strategy = 'steps' - if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 - ga_steps = getattr(args, 'gradient_accumulation_steps', None) - if ga_steps is not None and ga_steps > 1: - from transformers import __version__ as transformers_version - if Version(transformers_version) <= Version('4.45.2'): - print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' - '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') - if getattr(args, 'eval_strategy', 'no') != 'no': - eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) - if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size - if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps - fp16_full_eval = getattr(args, 'fp16_full_eval', False) - if type(fp16_full_eval) is not bool: fp16_full_eval = False - bf16_full_eval = getattr(args, 'bf16_full_eval', False) - if type(bf16_full_eval) is not bool: bf16_full_eval = False - if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True - if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False - if force_float32: - args.bf16_full_eval = False - args.fp16_full_eval = False - elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': - args.bf16_full_eval = True - args.fp16_full_eval = False - elif not bf16_full_eval and not fp16_full_eval: - args.bf16_full_eval = args.bf16 - args.fp16_full_eval = args.fp16 - _output_logits = False - if locals().get('compute_metrics', None) is not None: _output_logits = True - if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True - if _output_logits: - os.environ['UNSLOTH_RETURN_LOGITS'] = '1' - if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): - pass - else: - model_max_seq_length = getattr(model, 'max_seq_length', None) - args_max_seq_length = getattr(args, 'max_seq_length', None) - if args_max_seq_length is None and model_max_seq_length is not None: - max_seq_length = model.max_seq_length - if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length - elif args_max_seq_length is not None and model_max_seq_length is not None: - if args_max_seq_length > model_max_seq_length: - print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but ' - 'the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.') - args.max_seq_length = model_max_seq_length - if model is not None and hasattr(model, 'for_training'): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' - if 'processing_class' in locals(): - if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' - if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' - __tokenizer = processing_class if 'processing_class' in locals() else tokenizer - from unsloth_zoo.vision_utils import UnslothVisionDataCollator - if not isinstance(data_collator, UnslothVisionDataCollator): - if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: - data_collator = DataCollatorForSeq2Seq( - __tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False - if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' - if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} - if not isinstance(data_collator, UnslothVisionDataCollator): - if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): - if isinstance(data_collator, DataCollatorForSeq2Seq): - data_collator = DataCollatorForSeq2Seq( - __tokenizer.tokenizer, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - else: - data_collator = TransformersDataCollatorForLanguageModeling( - __tokenizer.tokenizer, - mlm = False, - mlm_probability = 0.0, - pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None), - ) - other_metrics = [] - - from unsloth_zoo.logging_utils import PatchRLStatistics - PatchRLStatistics('xpo_trainer', other_metrics) - - # [TODO] Fix up DataParallel multiplying batch sizes - # [TODO] DDP works, but DP seems to not work? [TODO] - if getattr(args, "parallel_mode", None) == ParallelMode.NOT_DISTRIBUTED and args.n_gpu > 1: - if getattr(args, "_n_gpu", 1) != 1: - args._n_gpu = 1 - if "model" in locals() and hasattr(model, "for_training"): - model.for_training(use_gradient_checkpointing=getattr(args, 'gradient_checkpointing', True)) - super().__init__( - model = model, - ref_model = ref_model, - reward_funcs = reward_funcs, - judge = judge, - args = args, - data_collator = data_collator, - train_dataset = train_dataset, - eval_dataset = eval_dataset, - processing_class = processing_class, - reward_processing_classes = reward_processing_classes, - peft_config = peft_config, - compute_metrics = compute_metrics, - callbacks = callbacks, - preprocess_logits_for_metrics = preprocess_logits_for_metrics, - reward_model = reward_model,**kwargs) - if "model" in locals() and hasattr(model, "for_inference"): - model.for_inference() - if hasattr(self, 'neftune_hook_handle'): - self.neftune_hook_handle.remove() - if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle - if getattr(args, 'neftune_noise_alpha', None) is not None: - model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha - pass - if hasattr(self, 'accelerator'): - scaler = self.accelerator.scaler - current_model = model - while hasattr(current_model, 'model'): - current_model.accelerator_scaler = scaler - current_model = current_model.model - current_model.accelerator_scaler = scaler - pass - if hasattr(self, 'train'): - self.train = MethodType(prepare_for_training_mode(self.__class__.train), self) - pass - if hasattr(self, 'llm') and self.llm is not None and hasattr(self.llm, 'get_tokenizer'): - _vllm_tok = self.llm.get_tokenizer() - _pc = getattr(self, 'processing_class', None) or getattr(self, 'tokenizer', None) - if _vllm_tok is not None and _pc is not None and getattr(_pc, 'chat_template', None) is not None and getattr(_vllm_tok, 'chat_template', None) is None: - _vllm_tok.chat_template = _pc.chat_template - pass - -pass diff --git a/unsloth_compiled_cache/moe_utils.py b/unsloth_compiled_cache/moe_utils.py deleted file mode 100644 index b444c2f..0000000 --- a/unsloth_compiled_cache/moe_utils.py +++ /dev/null @@ -1,1251 +0,0 @@ -# Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published -# by the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -import torch -import torch.nn.functional as F -import os -import shutil -from typing import Optional, Tuple -from torch.autograd import Function -from .utils import logger - -# Get compile location -UNSLOTH_COMPILE_LOCATION = os.environ.get( - "UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache" -) - - -def install_to_cache(source_path, destination_filename=None): - """ - Copies a file to the unsloth_compiled_cache directory - to ensure it is available for compiled modules. - """ - if not os.path.exists(UNSLOTH_COMPILE_LOCATION): - try: - os.makedirs(UNSLOTH_COMPILE_LOCATION) - except: - pass - - current_file = os.path.abspath(source_path) - if destination_filename is None: - destination_filename = os.path.basename(current_file) - - destination = os.path.abspath(os.path.join(UNSLOTH_COMPILE_LOCATION, destination_filename)) - - # If source and dest are different, copy. - if current_file != destination: - try: - shutil.copy(current_file, destination) - except Exception: - pass - - -install_to_cache(__file__, "moe_utils.py") - -# ============================================================================ -# Grouped MM wrapper -# ============================================================================ -# Simple wrapper around torch._grouped_mm that ensures contiguous inputs. -# Native backward works correctly - no custom autograd needed. -# ============================================================================ - - -def _grouped_mm_with_backward_fix( - inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor -) -> torch.Tensor: - """ - Grouped matmul with working backward pass. - - Uses native torch._grouped_mm with contiguous inputs for correct gradients. - """ - return torch._grouped_mm(inputs, weight, offs=offsets) - - -# Global flag to check if grouped GEMM is available -_GROUPED_GEMM_AVAILABLE = None -_TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm") - -# Check if GPU supports torch._grouped_mm (verified via runtime check) -_TORCH_GROUPED_MM_SUPPORTED = None - - -def _check_torch_grouped_mm_supported(): - """ - Check if torch._grouped_mm is actually supported on the current GPU. - We check for existence and verify with a dummy call. - A runtime probe is the only reliable check. - """ - global _TORCH_GROUPED_MM_SUPPORTED - if _TORCH_GROUPED_MM_SUPPORTED is not None: return _TORCH_GROUPED_MM_SUPPORTED - - if not _TORCH_GROUPED_MM_AVAILABLE: - _TORCH_GROUPED_MM_SUPPORTED = False - return False - - if not torch.cuda.is_available(): - _TORCH_GROUPED_MM_SUPPORTED = False - return False - - try: - # Attempt a dummy grouped_mm call to verify support. - # This handles cases where the symbol exists but hardware is unsupported (e.g. < H100). - # It also allows support on newer hardware or backports without code changes. - device = torch.cuda.current_device() - dtype = torch.float16 - - # Minimal dummy data: 1 expert, 1 token, dim 8 (safe alignment) - x = torch.ones((1, 8), device=device, dtype=dtype) - w = torch.ones((1, 8, 8), device=device, dtype=dtype) - offs = torch.tensor([1], device=device, dtype=torch.int32) - - torch._grouped_mm(x, w, offs=offs) - del x, w, offs - _TORCH_GROUPED_MM_SUPPORTED = True - except Exception: - _TORCH_GROUPED_MM_SUPPORTED = False - - return _TORCH_GROUPED_MM_SUPPORTED - - -_TRITON_ALLOCATOR_INITIALIZED = False -_PERSISTENT_BUFFER = None - - -def _init_triton_allocator(): - """ - Initialize a persistent Triton allocator to avoid memory allocation overhead per call. - This significantly reduces GPU utilization fluctuation. - """ - global _TRITON_ALLOCATOR_INITIALIZED, _PERSISTENT_BUFFER - if _TRITON_ALLOCATOR_INITIALIZED: return - - try: - import triton - - # Create a persistent buffer that grows as needed - # This avoids allocating new memory on every kernel call - - def persistent_alloc_fn(size: int, alignment: int, stream): - global _PERSISTENT_BUFFER - # Round up size to avoid frequent reallocations - # Round to nearest 128 bytes for alignment - rounded_size = ((size + 128 - 1) // 128) * 128 - - if ( - _PERSISTENT_BUFFER is None - or _PERSISTENT_BUFFER.numel() * _PERSISTENT_BUFFER.element_size() - < rounded_size - ): - # Allocate with small headroom (10%) to reduce reallocations - # Use ByteTensor (uint8) for raw byte storage - _PERSISTENT_BUFFER = torch.empty( - int(rounded_size * 1.1), device="cuda", dtype=torch.uint8 - ) - _PERSISTENT_BUFFER.__hibernate__ = {"type": "ignore"} - return _PERSISTENT_BUFFER - - triton.set_allocator(persistent_alloc_fn) - triton._unsloth_allocator_set = True - _TRITON_ALLOCATOR_INITIALIZED = True - except Exception: - pass - - -def _check_grouped_gemm_available(): - """Check if Unsloth grouped GEMM kernels are available.""" - if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False - - global _GROUPED_GEMM_AVAILABLE - if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE - - try: - from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm, supports_tma - _GROUPED_GEMM_AVAILABLE = True - _init_triton_allocator() - except (ImportError, ModuleNotFoundError): - _GROUPED_GEMM_AVAILABLE = False - return _GROUPED_GEMM_AVAILABLE - - -from functools import lru_cache - - -@lru_cache(maxsize=1) -def select_moe_backend(): - """ - Selects the MoE backend based on UNSLOTH_MOE_BACKEND environment variable and availability. - Choices: "grouped_mm", "unsloth_triton", "native_torch". - Default if unspecified: "grouped_mm". - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - requested = os.environ.get("UNSLOTH_MOE_BACKEND") - if requested: - if requested == "grouped_mm" and _check_torch_grouped_mm_supported(): - return "grouped_mm" - if requested == "unsloth_triton" and _check_grouped_gemm_available(): - return "unsloth_triton" - if requested == "native_torch": - return "native_torch" - logger.info(f"Unsloth: '{requested}' backend requested but is not available. Falling back to next available.") - - if _check_torch_grouped_mm_supported(): - logger.info("Unsloth: Using MoE backend 'grouped_mm'") - return "grouped_mm" - if _check_grouped_gemm_available(): - logger.info("Unsloth: Using MoE backend 'unsloth_triton'") - return "unsloth_triton" - return "native_torch" - - -def forward_moe_backend( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - """ - Dispatch MoE forward to the selected backend. - Centralizes backend selection to keep model-specific patches minimal. - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - backend = select_moe_backend() - if backend == "grouped_mm": - return forward_native_grouped_mm(self, hidden_states, top_k_index, top_k_weights) - if backend == "unsloth_triton": - return forward_triton_grouped_gemm(self, hidden_states, top_k_index, top_k_weights) - return forward_native_moe_loop(self, hidden_states, top_k_index, top_k_weights) - - -@torch.no_grad() -def _get_routing_indices(selected_experts, num_experts): - """ - Compute token→expert mapping for grouped GEMM. - Uses bincount instead of histc to avoid float conversion overhead. - - Returns: - token_counts_by_expert: (num_experts,) token counts per expert - gather_indices: (total_tokens,) indices for gathering tokens in expert order - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - flat_experts = selected_experts.view(-1) - - # bincount is faster than histc since it doesn't require float conversion - token_counts_by_expert = torch.bincount(flat_experts, minlength=num_experts).to(torch.int32) - - # argsort with stable=True preserves order within each expert - gather_indices = flat_experts.argsort(stable=True) - - return token_counts_by_expert, gather_indices - - -def _silu_and_mul(x): - """Fused SiLU activation and element-wise multiply for gate/up projections.""" - gate, up = x.chunk(2, dim=-1) - return F.silu(gate) * up - - -# ============================================================================ -# Separated LoRA Helper Functions -# ============================================================================ - - -def _has_lora_adapters(param) -> bool: - """Check if parameter has active LoRA adapters (PEFT ParamWrapper).""" - # Check if this is a PEFT LoRA wrapper - if not hasattr(param, "lora_A") or not hasattr(param, "lora_B"): - return False - if hasattr(param, "disable_adapters") and param.disable_adapters: - return False - if hasattr(param, "merged") and param.merged: - return False - return len(param.lora_A) > 0 - - -def _extract_lora_from_wrapper( - wrapper, adapter_name: str = "default", experts_module=None -) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]: - """ - Extract LoRA weights from PEFT ParamWrapper for MoE separated computation. - - PEFT ParamWrapper for 3D parameters creates: - - lora_A: nn.Linear(in_dim, E*R) -> weight: (E*R, in_dim) - - lora_B: nn.Linear(E*R, out_dim) -> weight: (out_dim, E*R) - - For grouped_mm: X @ first_weight @ second_weight - - STANDARD FORMAT (Qwen3-MoE): weights stored as (E, out_dim, in_dim) for F.linear - gate_up_proj: (E, 2*I, H) - input X is (N, H), output is (N, 2*I) - down_proj: (E, H, I) - input X is (N, I), output is (N, H) - - For gate_up with (E, 2*I, H): - lora_A: (E*R, H), lora_B: (2*I, E*R) - Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I) - first_weight from lora_A: (E*R, H) -> (E, H, R) after view/permute - second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I) after view/permute - - TRANSPOSED FORMAT (Qwen3-VL-MoE): weights stored as (E, in_dim, out_dim) for grouped_mm - gate_up_proj: (E, H, 2*I) - input X is (N, H), output is (N, 2*I) - down_proj: (E, I, H) - input X is (N, I), output is (N, H) - - For gate_up with (E, H, 2*I): - lora_A: (E*R, H), lora_B: (2*I, E*R) - Input X (N, H) needs: X @ (E, H, R) @ (E, R, 2*I) -> (N, 2*I) - first_weight from lora_A: (E*R, H) -> (E, H, R) - second_weight from lora_B: (2*I, E*R) -> (E, R, 2*I) - - Returns: - (first_weight, second_weight, scaling, num_experts) or None - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - try: - if not hasattr(wrapper, "lora_A") or not hasattr(wrapper, "lora_B"): - return None - - if hasattr(wrapper, "disable_adapters") and wrapper.disable_adapters: - return None - if hasattr(wrapper, "merged") and wrapper.merged: - return None - - if not wrapper.lora_A: - return None - - if adapter_name not in wrapper.lora_A: - adapter_name = list(wrapper.lora_A.keys())[0] - - lora_A_module = wrapper.lora_A[adapter_name] - lora_B_module = wrapper.lora_B[adapter_name] - - weight_A = lora_A_module.weight # (E*R, dim1) - weight_B = lora_B_module.weight # (dim2, E*R) - scaling = wrapper.scaling[adapter_name] - num_experts = getattr(wrapper, "num_experts", 1) - - # GET EXPERTS MODULE TO CHECK FOR REGISTERED EXTRACTOR - if experts_module is None: - experts_module = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None - - # Check for model-specific LoRA extractor attached to the experts module - extractor_fn = getattr(experts_module, "_unsloth_lora_extractor_fn", None) - - if extractor_fn is not None: - return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts) - - # DEFAULT BEHAVIOR (Standard Format / Non-MoE) - if num_experts > 1: - total_rank = weight_A.shape[0] - rank_per_expert = total_rank // num_experts - dim1 = weight_A.shape[1] - dim2 = weight_B.shape[0] - - # STANDARD FORMAT (Qwen3-MoE / GLM4): - # Base weights are (E, out_dim, in_dim) for F.linear. - # LoRA weights follow PEFT: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R). - # We need X @ (E, in_dim, R) @ (E, R, out_dim). - - # first_weight: (E, in_dim, R) - from lora_A - # second_weight: (E, R, out_dim) - from lora_B - first_weight = weight_A.view(num_experts, rank_per_expert, dim1) - first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, dim1, R) - - # second_weight (B): (E, R, out_dim) - second_weight = weight_B.view(dim2, num_experts, rank_per_expert) - second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, dim2) - else: - # Non-MoE case: return weights for X @ A.T @ B.T - first_weight = weight_A.T # (dim1, R) - second_weight = weight_B.T # (R, dim2) - - return first_weight, second_weight, scaling, num_experts - except Exception: - return None - - -def _extract_lora_weights( - param, adapter_name: str = "default", num_experts: int = None, experts_module=None -) -> Optional[Tuple[torch.Tensor, torch.Tensor, float]]: - """ - Extract LoRA A and B weights from PEFT ParamWrapper. - - This is a compatibility wrapper around _extract_lora_from_wrapper. - Use _extract_lora_from_wrapper directly for new code. - - Returns: - (first_weight, second_weight, scaling) for (X @ first) @ second - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - # Set num_experts on param if provided, so _extract_lora_from_wrapper can use it - if num_experts is not None and not hasattr(param, "num_experts"): - param.num_experts = num_experts - - result = _extract_lora_from_wrapper(param, adapter_name, experts_module=experts_module) - if result is None: - return None - # Return first 3 elements (first_weight, second_weight, scaling) without num_experts - return result[0], result[1], result[2] - - -def _get_base_weight(param): - """Get base weight from potentially wrapped parameter or module.""" - # This Unsloth Zoo code section is licensed under AGPL3 - - # Recursively unwrap PEFT layers - while hasattr(param, "base_layer"): - param = param.base_layer - - if hasattr(param, "get_param"): - return param.get_param() - - # Handle Modules (Linear, etc.) - if hasattr(param, "weight"): - return param.weight - - return param - - -def _get_lora_wrapper_for_param(experts_module, param_name): - """ - Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj). - Uses the explicit key stored in __dict__ if available. - Does NOT lazily setup wrappers as that requires traversing logic not present here. - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - if hasattr(experts_module, f"{param_name}_lora_wrapper"): - return getattr(experts_module, f"{param_name}_lora_wrapper") - - # Check simple attributes if it's directly wrapped - if hasattr(experts_module, param_name): - attr = getattr(experts_module, param_name) - if hasattr(attr, "lora_A"): # Is a ParamWrapper - return attr - - return None - - -def native_moe_grouped_mm( - inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor -) -> torch.Tensor: - """ - Native implementation using grouped_mm with backward fix. - - Uses custom autograd function to avoid PyTorch's grouped_mm backward stride bug. - """ - return _grouped_mm_with_backward_fix(inputs, weight, offsets) - - -def _apply_lora_grouped_mm( - inputs: torch.Tensor, - lora_B: torch.Tensor, - lora_A: torch.Tensor, - offsets: torch.Tensor, - scaling: float, - grouped_mm_func=native_moe_grouped_mm, -) -> torch.Tensor: - """ - Apply LoRA using grouped GEMM: result = ((X @ B) @ A) * scaling - - Args: - inputs: (total_tokens, in_dim) - lora_B: (num_experts, in_dim, rank) - First projection - lora_A: (num_experts, rank, out_dim) - Second projection - offsets: Grouped GEMM offsets - scaling: LoRA scaling factor - grouped_mm_func: Function to use for grouped GEMM (default: native_moe_grouped_mm) - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - # 1. First Matmul (X @ B) - # lora_B is (E, in_dim, R) - # Native needs (E, in_dim, R) -> No Transpose - lora_intermediate = grouped_mm_func(inputs, lora_B.contiguous(), offsets) - - # 2. Second Matmul (result @ A) - # lora_A is (E, R, out_dim) - # Native needs (E, R, out_dim) -> No Transpose - lora_delta = grouped_mm_func(lora_intermediate, lora_A.contiguous(), offsets) - - return lora_delta * scaling - - -def _should_use_separated_lora() -> bool: - """ - Check if separated LoRA approach should be used (default: True). - Set UNSLOTH_MOE_LORA_MERGED=1 to use merged approach instead. - """ - return os.environ.get("UNSLOTH_MOE_LORA_MERGED", "0") != "1" - - -# ============================================================================ -# Model-specific Weight Preprocessing Hooks -# ============================================================================ -# Each model can register its own preprocessing function for weight transposition. -# This allows the generic backend to work with different model weight layouts. - -_WEIGHT_PREPROCESSORS = {} - - -def register_weight_preprocessor(model_type: str, preprocessor_fn): - """ - Register a weight preprocessor for a specific model type. - - Args: - model_type: Model identifier (e.g., "qwen3_moe", "qwen3_vl_moe") - preprocessor_fn: Function(weight, proj_type, hidden_dim) -> processed_weight - proj_type is "gate_up" or "down" - """ - _WEIGHT_PREPROCESSORS[model_type] = preprocessor_fn - - -def get_weight_preprocessor(model_type: str): - """Get registered weight preprocessor for model type.""" - return _WEIGHT_PREPROCESSORS.get(model_type) - - -def preprocess_weight( - weight: torch.Tensor, proj_type: str, hidden_dim: int, model_type=None -): - """ - Preprocess weight tensor for grouped_mm compatibility. - - Uses model-specific preprocessor if registered, otherwise uses default logic. - - Args: - weight: Weight tensor (E, dim1, dim2) or similar - proj_type: "gate_up" or "down" - hidden_dim: Hidden dimension for shape inference - model_type: Optional model type to use specific preprocessor - - Returns: - Weight tensor in (E, in_dim, out_dim) format for grouped_mm - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - if model_type and model_type in _WEIGHT_PREPROCESSORS: - return _WEIGHT_PREPROCESSORS[model_type](weight, proj_type, hidden_dim) - - # Default preprocessing: check if transposition is needed - if proj_type == "gate_up": - # For gate_up, we need (E, hidden_dim, 2*intermediate) - if weight.shape[1] == hidden_dim: - return weight - else: - return weight.transpose(-2, -1) - else: # down - # For down, we need (E, intermediate, hidden_dim) - if weight.shape[2] == hidden_dim: - return weight - else: - return weight.transpose(-2, -1) - - -# ============================================================================ -# Generic MoE Detection and ParamWrapper Patching -# ============================================================================ - - -def _is_moe_experts_module(module) -> bool: - """ - Check if module is an MoE experts layer (generic, not model-specific). - - Detects modules with stacked expert weights as 3D nn.Parameter: - - gate_up_proj/down_proj pattern (Qwen3-MoE, Qwen3-VL-MoE, etc.) - - w1/w2/w3 pattern (older MoE models) - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - import torch.nn as nn - - # Check for gate_up_proj pattern - if hasattr(module, "gate_up_proj"): - param = module.gate_up_proj - if isinstance(param, nn.Parameter) and param.ndim == 3: - return True - - # Check for w1/w2 pattern (separate gate/up projections) - if hasattr(module, "w1") and hasattr(module, "w2"): - w1 = module.w1 - if isinstance(w1, nn.Parameter) and w1.ndim == 3: - return True - - return False - - -# Aliases for compatibility with gpt_oss.py -_get_moe_lora_weights = _extract_lora_from_wrapper - - -# Store original ParamWrapper.forward for fallback -_original_param_wrapper_forward = None - - -def _patched_param_wrapper_forward( - self, x: torch.Tensor, *args, **kwargs -) -> torch.Tensor: - """ - Patched ParamWrapper.forward for MoE separated LoRA. - - For MoE expert modules: - - Bypasses PEFTs _activate_lora parametrization context - - Stores LoRA data by parameter_name for forward_native_grouped_mm to use - - For non-MoE modules: - - Falls back to original PEFT forward - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - # CRITICAL: Use self.base_layer for forward call (immediate parent) - # NOT self.get_base_layer() which recursively traverses to deepest layer! - # The wrapper chain must be preserved: down_proj -> gate_up_proj -> Qwen3MoeExperts - immediate_base_layer = self.base_layer - - # For storing LoRA data, we DO need the actual experts module - # Use get_base_layer() to find it (recursive traversal is correct here) - experts_module = self.get_base_layer() - - use_separated = _should_use_separated_lora() - param_name = getattr(self, "parameter_name", None) - - # Check if this is an MoE experts module that should use separated LoRA - if ( - use_separated - and param_name in ("gate_up_proj", "down_proj") - and _is_moe_experts_module(experts_module) - ): - # MoE experts: bypass PEFT's _activate_lora, use separated computation - - # Check adapter state - if self.disable_adapters: - if self.merged: - self.unmerge() - return immediate_base_layer(x, *args, **kwargs) - - if self.merged: - return immediate_base_layer(x, *args, **kwargs) - - # Ensure wrapper.num_experts is set for LoRA weight reshaping - if not hasattr(self, "num_experts"): - if hasattr(experts_module, "num_experts"): - self.num_experts = experts_module.num_experts - elif hasattr(experts_module, param_name): - p = getattr(experts_module, param_name) - if hasattr(p, "shape") and len(p.shape) >= 1: - self.num_experts = p.shape[0] - - # Extract LoRA for this specific parameter - lora_data = _extract_lora_from_wrapper(self) - - if lora_data is not None and param_name: - # Store LoRA data on the EXPERTS MODULE (not base_layer) - # e.g., _unsloth_lora_gate_up_proj or _unsloth_lora_down_proj - lora_attr = f"_unsloth_lora_{param_name}" - setattr(experts_module, lora_attr, lora_data) - - try: - # Call IMMEDIATE base_layer to preserve wrapper chain - # (down_proj wrapper calls gate_up_proj wrapper calls Qwen3MoeExperts) - result = immediate_base_layer(x, *args, **kwargs) - finally: - # Clean up - if param_name: - lora_attr = f"_unsloth_lora_{param_name}" - if hasattr(experts_module, lora_attr): - delattr(experts_module, lora_attr) - - return result - - # Non-MoE: use original PEFT forward with _activate_lora - return _original_param_wrapper_forward(self, x, *args, **kwargs) - - -def patch_param_wrapper_for_moe(): - """ - Patch PEFT's ParamWrapper.forward to use separated LoRA for MoE. - - This should be called after PEFT is imported. - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - global _original_param_wrapper_forward - - try: - from peft.tuners.lora.layer import ParamWrapper - - # Store original forward - if _original_param_wrapper_forward is None: - _original_param_wrapper_forward = ParamWrapper.forward - - # Patch with our version - ParamWrapper.forward = _patched_param_wrapper_forward - - return True - except ImportError: - return False - - -def forward_native_grouped_mm( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - """ - Native Pytorch grouped GEMM MoE forward pass. - Uses torch._grouped_mm which is significantly faster than loop and works without Triton dependencies. - Requires torch._grouped_mm support (verified via runtime check). - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - # Runtime safety check - defense in depth - if not _check_torch_grouped_mm_supported(): - major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) - raise RuntimeError( - f"torch._grouped_mm is not supported on this device (Compute Capability {major}.{minor}). " - f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend." - ) - - is_2d_input = hidden_states.dim() == 2 - if is_2d_input: - sequence_length, hidden_dim = hidden_states.shape - batch_size = 1 - else: - batch_size, sequence_length, hidden_dim = hidden_states.shape - - hidden_states = hidden_states.view(-1, hidden_dim) - - # 1. Calculate routing - flat_top_k = top_k_index.view(-1) - num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int() - - # 2. Sort indices to group tokens by expert - sorted_indices = torch.argsort(flat_top_k, stable=True) - token_indices = sorted_indices // top_k_index.shape[-1] - - # 3. Permute Input - # We need to gather inputs. Since we may have expanded top_k, we use token_indices to map back to original input - permuted_input = hidden_states[token_indices] - - # 4. Prepare Grouped MM arguments - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - - # ======================================================================== - # Gate + Up projection with optional separated LoRA (DEFAULT) - # ======================================================================== - use_separated_lora = _should_use_separated_lora() - gate_up_lora = None - - # Check for injected LoRA data from patched ParamWrapper (preferred path) - if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: - gate_up_lora = self._unsloth_lora_gate_up_proj[ - :3 - ] # (first_weight, second_weight, scaling) - # Fallback: check parameter directly (for older wrapping patterns) - elif ( - use_separated_lora - and hasattr(self, "gate_up_proj") - and _has_lora_adapters(self.gate_up_proj) - ): - gate_up_lora = _extract_lora_weights( - self.gate_up_proj, num_experts=self.num_experts, experts_module=self - ) - - if hasattr(self, "gate_up_proj"): - # Get base weights (raw, without LoRA) - gate_up_base = _get_base_weight(self.gate_up_proj) - - # Get model type for preprocessing (if registered) - model_type = getattr(self, "_unsloth_model_type", None) - - # Handle different weight shapes using preprocessor - # torch._grouped_mm backward requires weights to be contiguous; preprocessing may return a transposed view. - w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type) - # Base forward: X @ W - mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets) - - # Add separated LoRA contribution: + ((X @ first) @ second) * scaling - # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling) - if gate_up_lora is not None: - first_weight, second_weight, scaling = gate_up_lora - - # Cast to input dtype (LoRA weights are float32, input may be bfloat16) - # Ensure contiguous for grouped_mm alignment requirements - first_weight = first_weight.to(permuted_input.dtype).contiguous() - second_weight = second_weight.to(permuted_input.dtype).contiguous() - - # Step 1: permuted_input @ first_weight - try: - lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets) - lora_out = lora_out.contiguous() - except RuntimeError as e: - raise e - - # Step 2: result @ second_weight - # Handle unaligned O dimension or other grouped_mm failures - try: - if second_weight.shape[-1] % 8 != 0: - pad_size = 8 - (second_weight.shape[-1] % 8) - second_weight_padded = F.pad( - second_weight, (0, pad_size) - ).contiguous() - lora_delta = _grouped_mm_with_backward_fix( - lora_out, second_weight_padded, offsets - ) - lora_delta = lora_delta[:, :-pad_size] - else: - lora_delta = _grouped_mm_with_backward_fix( - lora_out, second_weight, offsets - ) - except RuntimeError: - # Fallback to manual loop if grouped_mm fails (e.g. stride alignment) - lora_delta = torch.empty( - (lora_out.shape[0], second_weight.shape[-1]), - dtype=lora_out.dtype, - device=lora_out.device, - ) - cpu_offsets = offsets.cpu().tolist() - prev_offset = 0 - for i, end in enumerate(cpu_offsets): - if prev_offset < end: - lora_delta[prev_offset:end] = torch.matmul( - lora_out[prev_offset:end], second_weight[i] - ) - prev_offset = end - - # Add scaled LoRA contribution - mm1_out = mm1_out + lora_delta * scaling - - if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: - num_repeats = num_tokens_per_expert.to(self.gate_up_proj_bias.device) - bias_expanded = self.gate_up_proj_bias.repeat_interleave(num_repeats, dim=0) - mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype) - - if "GptOssExperts" in self.__class__.__name__: - gate = mm1_out[..., ::2] - up = mm1_out[..., 1::2] - else: - gate, up = mm1_out.chunk(2, dim=-1) - - elif hasattr(self, "w1") and hasattr(self, "w3"): - # Separate w1/w3 weights (older models) - w1_base = _get_base_weight(self.w1) - w3_base = _get_base_weight(self.w3) - - w1 = w1_base.transpose(-2, -1) - w3 = w3_base.transpose(-2, -1) - - gate = _grouped_mm_with_backward_fix(permuted_input, w1, offsets) - up = _grouped_mm_with_backward_fix(permuted_input, w3, offsets) - - # Add LoRA for w1 and w3 separately if present - if use_separated_lora: - if _has_lora_adapters(self.w1): - w1_lora = _extract_lora_weights(self.w1, experts_module=self) - if w1_lora is not None: - lora_A, lora_B, scaling = w1_lora - lora_A_t = lora_A.transpose(-2, -1) - lora_A_out = _grouped_mm_with_backward_fix( - permuted_input, lora_A_t, offsets - ) - lora_B_t = lora_B.transpose(-2, -1) - lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets) - gate = gate + lora_B_out * scaling - - if _has_lora_adapters(self.w3): - w3_lora = _extract_lora_weights(self.w3, experts_module=self) - if w3_lora is not None: - lora_A, lora_B, scaling = w3_lora - lora_A_t = lora_A.transpose(-2, -1) - lora_A_out = _grouped_mm_with_backward_fix( - permuted_input, lora_A_t, offsets - ) - lora_B_t = lora_B.transpose(-2, -1) - lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets) - up = up + lora_B_out * scaling - else: - raise AttributeError("MoE layer must have 'gate_up_proj' or 'w1'/'w3'.") - - # Activation - if "GptOssExperts" in self.__class__.__name__: - # Custom activation from GptOss - limit = getattr(self, "limit", 7.0) - alpha = getattr(self, "alpha", 1.702) - - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * torch.sigmoid(gate * alpha) - inter = (up + 1.0) * glu - else: - inter = F.silu(gate) * up - - # ======================================================================== - # Down projection with optional separated LoRA (DEFAULT) - # ======================================================================== - down_lora = None - - # Check for injected LoRA data from patched ParamWrapper (preferred path) - if getattr(self, "_unsloth_lora_down_proj", None) is not None: - down_lora = self._unsloth_lora_down_proj[ - :3 - ] # (first_weight, second_weight, scaling) - # Fallback: check parameter directly (for older wrapping patterns) - elif ( - use_separated_lora - and hasattr(self, "down_proj") - and _has_lora_adapters(self.down_proj) - ): - down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts, experts_module=self) - - if hasattr(self, "down_proj"): - # Get base weights - down_base = _get_base_weight(self.down_proj) - - # Get model type for preprocessing (if registered) - model_type = getattr(self, "_unsloth_model_type", None) - - # Handle different weight shapes using preprocessor - w2 = preprocess_weight(down_base, "down", hidden_dim, model_type) - - # Base forward - mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets) - - # Add separated LoRA contribution if present - # _extract_lora_from_wrapper returns (first_weight, second_weight, scaling) - if down_lora is not None: - first_weight, second_weight, scaling = down_lora - - # Cast to input dtype (LoRA weights are float32, input may be bfloat16) - first_weight = first_weight.to(inter.dtype).contiguous() - second_weight = second_weight.to(inter.dtype).contiguous() - - # Step 1: inter @ first_weight - lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets) - lora_out = lora_out.contiguous() - - # Step 2: result @ second_weight - try: - lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets) - except RuntimeError: - # Fallback to manual loop - lora_delta = torch.empty( - (lora_out.shape[0], second_weight.shape[-1]), - dtype=lora_out.dtype, - device=lora_out.device, - ) - cpu_offsets = offsets.cpu().tolist() - prev_offset = 0 - for i, end in enumerate(cpu_offsets): - if prev_offset < end: - lora_delta[prev_offset:end] = torch.matmul( - lora_out[prev_offset:end], second_weight[i] - ) - prev_offset = end - - # Add scaled LoRA contribution - mm2_out = mm2_out + lora_delta * scaling - - if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: - bias_expanded = self.down_proj_bias.repeat_interleave( - num_tokens_per_expert.to(self.down_proj_bias.device), dim=0 - ).to(mm2_out.device) - mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype) - - elif hasattr(self, "w2"): - w2_base = _get_base_weight(self.w2) - w2 = w2_base.transpose(-2, -1) - - # Base forward - mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets) - - # Add LoRA if present - if use_separated_lora and _has_lora_adapters(self.w2): - w2_lora = _extract_lora_weights(self.w2, experts_module=self) - if w2_lora is not None: - lora_A, lora_B, scaling = w2_lora - lora_A_t = lora_A.transpose(-2, -1).contiguous() - lora_A_out = _grouped_mm_with_backward_fix(inter, lora_A_t, offsets) - lora_B_t = lora_B.transpose(-2, -1).contiguous() - lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets) - mm2_out = mm2_out + lora_B_out * scaling - else: - raise AttributeError("MoE layer must have 'down_proj' or 'w2'.") - - # 5. Apply Routing Weights and Scatter Add (Reduce) - flat_weights = top_k_weights.view(-1) - permuted_weights = flat_weights[sorted_indices] - mm2_out = mm2_out * permuted_weights.unsqueeze(-1) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - final_hidden_states.index_add_(0, token_indices, mm2_out.to(hidden_states.dtype)) - - if is_2d_input: - return final_hidden_states - - return final_hidden_states.view(batch_size, sequence_length, hidden_dim) - - -def forward_triton_grouped_gemm( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - """ - Grouped GEMM MoE forward pass using Triton kernels. - Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin). - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - # Import grouped GEMM interface - from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm - - # Import autotune cache - from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels - - # Helper to check TMA support - assumes helper function or just check directly - # In original: it was a cached closure. Here we can use _supports_tma() directly - - # nonlocal _MODEL_DIMS_AND_CONFIGS # We need a way to store this! - # For now, let's attach it to self if possible, or use a global usage - # Attaching to self is cleaner: self._unsloth_moe_configs - - # Create expert mask and find which experts have tokens - - if not hasattr(self, "_unsloth_moe_configs"): - self._unsloth_moe_configs = None - - use_separated_lora = _should_use_separated_lora() - - - # Handle 3D inputs (batch_size, seq_len, hidden_dim) - is_3d = hidden_states.dim() == 3 - if is_3d: - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - num_tokens = batch_size * seq_len - # Also flatten top_k inputs if they are 3D - if top_k_index.dim() == 3: - top_k_index = top_k_index.view(-1, top_k_index.shape[-1]) - if top_k_weights.dim() == 3: - top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1]) - else: - num_tokens, hidden_dim = hidden_states.shape - - top_k = top_k_index.shape[1] - - # Cache model dimensions and kernel configs on first call - if self._unsloth_moe_configs is None: - intermediate_dim = self.gate_up_proj.shape[1] // 2 - - # Autotune first GEMM - gemm1_configs = get_or_autotune_moe_kernels( - num_experts=self.num_experts, - hidden_dim=hidden_dim, - intermediate_dim=intermediate_dim * 2, - top_k=top_k, - dtype=hidden_states.dtype, - ) - - # Autotune second GEMM - gemm2_configs = get_or_autotune_moe_kernels( - num_experts=self.num_experts, - hidden_dim=intermediate_dim, - intermediate_dim=hidden_dim, # Output dim for 2nd GEMM is hidden_dim - top_k=top_k, - dtype=hidden_states.dtype, - ) - - self._unsloth_moe_configs = (intermediate_dim, gemm1_configs, gemm2_configs) - - # Clear autotuning memory overhead - torch.cuda.empty_cache() - - # Unpack cached configs - intermediate_dim, gemm1_configs, gemm2_configs = self._unsloth_moe_configs - - # Unpack specific kernel configs - fwd_config_1, bwd_dX_config_1, bwd_dW_config_1 = gemm1_configs - fwd_config_2, bwd_dX_config_2, bwd_dW_config_2 = gemm2_configs - - # Compute routing indices for grouped GEMM - token_counts_by_expert, gather_indices = _get_routing_indices( - top_k_index, self.num_experts - ) - offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32) - - if self.gate_up_proj.shape[-1] == hidden_dim: - w1 = self.gate_up_proj - else: - w1 = self.gate_up_proj.transpose(-2, -1).contiguous() - - # First grouped GEMM: gate_up projection - first_gemm_output = grouped_gemm( - X=hidden_states, - W=w1, - m_sizes=token_counts_by_expert, - topk=top_k, - gather_indices=gather_indices, - permute_x=True, - permute_y=False, - autotune=False, # We use cached configs - kernel_config_fwd=fwd_config_1, - kernel_config_bwd_dX=bwd_dX_config_1, - kernel_config_bwd_dW=bwd_dW_config_1, - is_first_gemm=True, - ) - - # Apply SiLU activation and multiply gate with up - intermediate = _silu_and_mul(first_gemm_output) - - # Grouped GEMM 2: down projection - - # Grouped GEMM 2: down projection - # Prepare LoRA data - down_lora = None - if getattr(self, "_unsloth_lora_down_proj", None) is not None: - down_lora = self._unsloth_lora_down_proj[:3] - elif ( - use_separated_lora - and hasattr(self, "down_proj") - and _has_lora_adapters(self.down_proj) - ): - down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts) - - if self.down_proj.shape[-1] == intermediate.shape[-1]: - w2 = self.down_proj - else: - w2 = self.down_proj.transpose(-2, -1).contiguous() - - second_gemm_output = grouped_gemm( - X=intermediate, - W=w2, - m_sizes=token_counts_by_expert, - topk=top_k, - gather_indices=gather_indices, - permute_x=False, - permute_y=True, - autotune=False, # We use cached configs - kernel_config_fwd=fwd_config_2, - kernel_config_bwd_dX=bwd_dX_config_2, - kernel_config_bwd_dW=bwd_dW_config_2, - is_first_gemm=False, - ) - - # Add separated LoRA contribution for Down - if down_lora is not None: - first_weight, second_weight, scaling = down_lora - - # Intermediate is already permuted from step 1. - # Offsets are same. - - first_weight = first_weight.to(intermediate.dtype) - second_weight = second_weight.to(intermediate.dtype) - - lora_delta = _apply_lora_grouped_mm( - intermediate, - first_weight, - second_weight, - offsets, - scaling, - grouped_mm_func=native_moe_grouped_mm - ) - - second_gemm_output = second_gemm_output + lora_delta - - # Apply routing weights and sum across top_k experts - # Output shape: (num_tokens, top_k, hidden_dim) -> (num_tokens, hidden_dim) - # Ensure top_k_weights matches dtype (can be float32 from softmax) - top_k_weights_casted = top_k_weights.to(hidden_states.dtype) - final_hidden_states = ( - second_gemm_output.view(num_tokens, top_k, hidden_dim) - * top_k_weights_casted[..., None] - ) - final_hidden_states = final_hidden_states.sum(dim=1) - - if is_3d: - final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim) - - return final_hidden_states - - -@torch.compiler.disable -def forward_native_moe_loop( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - """ - Loop-based MoE forward pass. Loops over experts that have tokens routed to them. - Explicitly disabled for torch.compile to prevent graph breaks/recompilation issues with dynamic control flow. - """ - # This Unsloth Zoo code section is licensed under AGPL3 - final_hidden_states = torch.zeros_like(hidden_states) - - # Create expert mask and find which experts have tokens - with torch.no_grad(): - expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts) - expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, n_tokens) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - # Only loop over experts that actually have tokens routed to them - for expert_idx_t in expert_hit: - expert_idx = expert_idx_t.item() - - # Find which tokens are routed to this expert - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - - # Gather only the tokens for this expert - current_state = hidden_states[token_idx] - - # Compute gate_up projection for this expert only - # Handle 'gate_up_proj' or 'w1'/'w3' - if hasattr(self, "gate_up_proj"): - gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk( - 2, dim=-1 - ) - else: - gate = F.linear(current_state, self.w1[expert_idx]) - up = F.linear(current_state, self.w3[expert_idx]) - - current_hidden_states = self.act_fn(gate) * up - - # Compute down projection for this expert only - if hasattr(self, "down_proj"): - current_hidden_states = F.linear( - current_hidden_states, self.down_proj[expert_idx] - ) - else: - current_hidden_states = F.linear(current_hidden_states, self.w2[expert_idx]) - - # Apply routing weights - current_hidden_states = ( - current_hidden_states * top_k_weights[token_idx, top_k_pos, None] - ) - - # Scatter back to final output - final_hidden_states.index_add_( - 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) - ) - - return final_hidden_states