Files
medical_ai_scribe/medical_diarizer.py
2026-03-05 01:20:13 +01:00

87 lines
2.9 KiB
Python

import os
import sys
import time
import torch
import gc
from pyannote.audio import Pipeline
from faster_whisper import WhisperModel
import librosa
def run_diarization_and_transcription(audio_file, hf_token):
    """Diarize speakers and transcribe *audio_file*, emitting progress markers.

    Runs pyannote speaker diarization (forced to CPU), then faster-whisper
    transcription, and labels each transcribed segment with the speaker whose
    diarization turn overlaps it the most. Progress is reported on stdout as
    ``[STATUS] PROGRESS:<pct>`` lines for an external progress bar.

    Parameters
    ----------
    audio_file : str
        Path to the audio file to process.
    hf_token : str | None
        Hugging Face access token for the gated pyannote pipeline.

    Returns
    -------
    str | None
        Timestamped, speaker-labelled transcript lines joined by newlines,
        or ``None`` if transcription fails. A diarization failure is
        non-fatal: segments are then labelled with the generic "VOIX".
    """
    duration = librosa.get_duration(path=audio_file)
    print("[STATUS] PROGRESS:1")

    # 1. DIARIZATION
    print("[PHASE 1/3] Diarisation (Analyse des voix)...")
    try:
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", token=hf_token
        )
        pipeline.to(torch.device("cpu"))
        raw_result = pipeline(audio_file)
        # Some pyannote versions wrap the Annotation in a result object.
        diarization = getattr(raw_result, "annotation", raw_result)
        # Free the diarization pipeline before loading the Whisper model.
        del pipeline
        gc.collect()
    except Exception as e:
        print(f"[ERROR] Diarisation : {e}")
        diarization = None
    print("[STATUS] PROGRESS:30")

    # 2. TRANSCRIPTION
    print("[PHASE 2/3] Transcription (Modèle Large-v3-Turbo)...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "int8_float16" if device == "cuda" else "int8"
    try:
        model = WhisperModel(
            "large-v3-turbo",
            device=device,
            compute_type=compute_type,
            cpu_threads=16,
        )
        segments, _ = model.transcribe(
            audio_file, beam_size=5, language="fr", word_timestamps=True
        )
        whisper_segments = []
        for segment in segments:
            # Dynamic progress: map segment end time onto the 30-95% range.
            # Guard against a zero-length file (division by zero) and clamp
            # at 95 in case segment.end overshoots the reported duration.
            if duration > 0:
                pct = 30 + (65 * (segment.end / duration))
                print(f"[STATUS] PROGRESS:{min(95, int(pct))}")
            speaker = "VOIX"
            # Explicit None check: an *empty* Annotation is falsy, but an
            # empty (successful) diarization should still be honored.
            if diarization is not None:
                # Pick the speaker whose turn overlaps this segment most.
                max_overlap = 0
                for turn, _, speaker_id in diarization.itertracks(yield_label=True):
                    overlap = min(segment.end, turn.end) - max(segment.start, turn.start)
                    if overlap > max_overlap:
                        max_overlap = overlap
                        speaker = speaker_id
            timestamp = f"[{time.strftime('%H:%M:%S', time.gmtime(segment.start))}]"
            line = f"{timestamp} {speaker}: {segment.text.strip()}"
            print(line)
            whisper_segments.append(line)
        # Release the model (and any CUDA memory) before returning.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"[ERROR] Transcription : {e}")
        return None
    print("[STATUS] PROGRESS:100")
    return "\n".join(whisper_segments)
if __name__ == "__main__":
hf_token = os.getenv("HF_TOKEN")
if len(sys.argv) > 1:
audio_file = sys.argv[1]
result = run_diarization_and_transcription(audio_file, hf_token)
if result:
output_file = audio_file.rsplit('.', 1)[0] + "_diarized.txt"
with open(output_file, "w", encoding="utf-8") as f:
f.write(result)
print(f"\n[OK] Fini : {output_file}")