Got the CPU-based backend working; trying a Python/GPU solution next because it will probably be faster.
This commit is contained in:
201
src-tauri/src/transcription.rs
Normal file
201
src-tauri/src/transcription.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
use whisper_rs::{WhisperContext, WhisperContextParameters, FullParams, SamplingStrategy};
|
||||
|
||||
/// Full output of one transcription run: the flat word list, the
/// sentence/phrase-level segments, and the language used for decoding.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct TranscriptionResult {
    /// Every recognized word across all segments, in order of appearance.
    pub words: Vec<Word>,
    /// Whisper output segments; each also carries its own word list.
    pub segments: Vec<Segment>,
    /// Language code. NOTE(review): this is the *requested* language (or
    /// "en" when none was given), not an auto-detected one — see
    /// `transcribe_audio`.
    pub language: String,
}
|
||||
|
||||
/// A single recognized word (whisper token) with timing and confidence.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct Word {
    /// The word text, trimmed of surrounding whitespace.
    pub word: String,
    /// Start time in seconds from the beginning of the audio.
    pub start: f64,
    /// End time in seconds from the beginning of the audio.
    pub end: f64,
    /// Token probability reported by Whisper, in [0, 1].
    pub confidence: f64,
}
|
||||
|
||||
/// One Whisper output segment: a contiguous span of speech with its text
/// and the words it contains.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct Segment {
    /// Zero-based index of the segment in output order.
    pub id: usize,
    /// Segment start time in seconds.
    pub start: f64,
    /// Segment end time in seconds.
    pub end: f64,
    /// Full segment text, trimmed.
    pub text: String,
    /// The words belonging to this segment, in order.
    pub words: Vec<Word>,
}
|
||||
|
||||
/// Extract audio from a video/audio file to a 16kHz mono WAV using ffmpeg
///
/// Returns `Err` with a human-readable message when ffmpeg cannot be
/// spawned (e.g. not installed / not on PATH) or exits with a non-zero
/// status.
fn extract_to_wav(input_path: &str, output_path: &str) -> Result<(), String> {
    // -y: overwrite output; -vn: drop any video stream;
    // -ar 16000 -ac 1: resample to the 16 kHz mono PCM Whisper expects.
    let ffmpeg_args = [
        "-y", "-i", input_path, "-vn", "-ar", "16000", "-ac", "1", "-f", "wav", output_path,
    ];

    let status = Command::new("ffmpeg")
        .args(ffmpeg_args)
        .status()
        .map_err(|e| format!("Failed to run ffmpeg: {}", e))?;

    if status.success() {
        Ok(())
    } else {
        Err(format!("ffmpeg exited with code: {:?}", status.code()))
    }
}
|
||||
|
||||
/// Transcribe audio file using whisper-rs (real Whisper.cpp inference)
///
/// `file_path` may be any audio/video file ffmpeg can decode; it is first
/// re-encoded to a temporary 16 kHz mono WAV. `model_name` selects the ggml
/// model, downloaded on demand via `ensure_model_downloaded`. `language`
/// optionally forces the decode language; when `None`, Whisper's default is
/// used and the result reports "en".
///
/// Returns a `TranscriptionResult` with per-token "words" (NOTE(review):
/// Whisper tokens can be sub-word pieces, so each `Word` may be a fragment
/// of a dictionary word — confirm whether the caller merges them) or an
/// error string describing the failing step.
pub fn transcribe_audio(
    file_path: &str,
    model_name: &str,
    language: Option<&str>,
) -> Result<TranscriptionResult, String> {
    // Ensure model is downloaded
    let model_path = ensure_model_downloaded(model_name)?;

    // Extract audio to temp 16kHz mono WAV.
    // `tmp_wav` must stay alive for the whole function: dropping the
    // NamedTempFile handle deletes the file on disk.
    let tmp_wav = tempfile::Builder::new()
        .suffix(".wav")
        .tempfile()
        .map_err(|e| format!("Failed to create temp file: {}", e))?;
    let wav_path = tmp_wav.path().to_string_lossy().to_string();

    extract_to_wav(file_path, &wav_path)?;

    // Read WAV as f32 samples
    let mut reader = hound::WavReader::open(&wav_path)
        .map_err(|e| format!("Failed to read WAV: {}", e))?;
    let spec = reader.spec();
    let samples: Vec<f32> = match spec.sample_format {
        // Integer PCM: normalize i16 samples into [-1.0, 1.0).
        // NOTE(review): assumes 16-bit samples; ffmpeg's default WAV
        // encoder emits pcm_s16le, but this would misread other bit
        // depths — confirm against the extract_to_wav invocation.
        hound::SampleFormat::Int => reader
            .samples::<i16>()
            .map(|s| s.map(|v| v as f32 / 32768.0).map_err(|e| format!("{}", e)))
            .collect::<Result<Vec<f32>, _>>()?,
        // Float WAVs are already in the expected range; pass through.
        hound::SampleFormat::Float => reader
            .samples::<f32>()
            .map(|s| s.map_err(|e| format!("{}", e)))
            .collect::<Result<Vec<f32>, _>>()?,
    };

    // Load Whisper model and transcribe
    let ctx_params = WhisperContextParameters::default();
    let ctx = WhisperContext::new_with_params(&model_path, ctx_params)
        .map_err(|e| format!("Failed to load model: {:?}", e))?;
    let mut state = ctx.create_state()
        .map_err(|e| format!("Failed to create state: {:?}", e))?;

    // Greedy decoding; silence whisper.cpp's console output and enable
    // per-token timestamps so we can build word-level timing below.
    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
    params.set_print_special(false);
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_timestamps(false);
    params.set_token_timestamps(true);
    params.set_single_segment(false);
    if let Some(lang) = language {
        params.set_language(Some(lang));
    }

    // Runs the full (blocking) inference over all samples.
    state.full(params, &samples)
        .map_err(|e| format!("Transcription failed: {:?}", e))?;

    // Extract word-level results using the 0.16.0 iterator API
    let mut all_words: Vec<Word> = Vec::new();
    let mut segments: Vec<Segment> = Vec::new();
    // Not true language detection: report the requested language, or "en".
    let detected_language = language.unwrap_or("en").to_string();

    for (seg_idx, segment) in state.as_iter().enumerate() {
        let seg_text = segment.to_str_lossy()
            .map_err(|e| format!("Segment text error: {:?}", e))?;
        // Whisper timestamps are in 10 ms units; /100.0 converts to seconds.
        let seg_t0 = segment.start_timestamp() as f64 / 100.0;
        let seg_t1 = segment.end_timestamp() as f64 / 100.0;

        let mut seg_words: Vec<Word> = Vec::new();

        for tok_i in 0..segment.n_tokens() {
            if let Some(token) = segment.get_token(tok_i) {
                // Tokens with invalid UTF-8 conversion are silently skipped.
                let token_text = match token.to_str_lossy() {
                    Ok(t) => t.into_owned(),
                    Err(_) => continue,
                };
                let token_data = token.token_data();

                // Skip special tokens
                // (empty, "[_BEG_]"-style markers, "<|...|>" control tokens)
                let trimmed = token_text.trim();
                if trimmed.is_empty() || trimmed.starts_with('[') || trimmed.starts_with('<') {
                    continue;
                }

                let word = Word {
                    word: trimmed.to_string(),
                    // Token timestamps, again in 10 ms units.
                    start: token_data.t0 as f64 / 100.0,
                    end: token_data.t1 as f64 / 100.0,
                    confidence: token_data.p as f64,
                };
                all_words.push(word.clone());
                seg_words.push(word);
            }
        }

        segments.push(Segment {
            id: seg_idx,
            start: seg_t0,
            end: seg_t1,
            text: seg_text.trim().to_string(),
            words: seg_words,
        });
    }

    Ok(TranscriptionResult {
        words: all_words,
        segments,
        language: detected_language,
    })
}
|
||||
|
||||
/// Download and cache Whisper model
|
||||
pub fn ensure_model_downloaded(model_name: &str) -> Result<String, String> {
|
||||
// Get app data directory for storing models
|
||||
let app_data_dir = dirs::data_dir()
|
||||
.ok_or("Could not find app data directory")?
|
||||
.join("TalkEdit")
|
||||
.join("models");
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
fs::create_dir_all(&app_data_dir)
|
||||
.map_err(|e| format!("Failed to create models directory: {}", e))?;
|
||||
|
||||
let model_path = app_data_dir.join(format!("ggml-{}.bin", model_name));
|
||||
|
||||
// Check if model already exists
|
||||
if model_path.exists() {
|
||||
return Ok(model_path.to_string_lossy().to_string());
|
||||
}
|
||||
|
||||
// Only download smaller models automatically
|
||||
let allowed_models = ["tiny", "base", "small"];
|
||||
if !allowed_models.contains(&model_name) {
|
||||
return Err(format!("Model '{}' is not available for automatic download. Only tiny, base, and small models are supported.", model_name));
|
||||
}
|
||||
|
||||
println!("Downloading Whisper model: {}...", model_name);
|
||||
|
||||
// Download the model from ggerganov's whisper.cpp repo
|
||||
let url = format!("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-{}.bin", model_name);
|
||||
let response = ureq::get(&url)
|
||||
.call()
|
||||
.map_err(|e| format!("Failed to download model: {}", e))?;
|
||||
|
||||
let len = response
|
||||
.header("content-length")
|
||||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
println!("Model size: {} bytes", len);
|
||||
|
||||
let mut reader = response.into_reader();
|
||||
let mut file = fs::File::create(&model_path)
|
||||
.map_err(|e| format!("Failed to create model file: {}", e))?;
|
||||
|
||||
std::io::copy(&mut reader, &mut file)
|
||||
.map_err(|e| format!("Failed to write model file: {}", e))?;
|
||||
|
||||
println!("Model downloaded successfully: {}", model_path.display());
|
||||
|
||||
Ok(model_path.to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user