From e5cc566be82bbebb78f2c27ee13f3a5fafa4a0c8 Mon Sep 17 00:00:00 2001 From: DanConwayDev Date: Fri, 25 Jul 2025 22:26:59 +0100 Subject: refactor: move push helpers to lib to enable forthcoming ngit sync cmd --- src/lib/push.rs | 310 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 310 insertions(+) create mode 100644 src/lib/push.rs (limited to 'src/lib/push.rs') diff --git a/src/lib/push.rs b/src/lib/push.rs new file mode 100644 index 0000000..0d0ec93 --- /dev/null +++ b/src/lib/push.rs @@ -0,0 +1,310 @@ +use std::{ + sync::{Arc, Mutex}, + time::Instant, +}; + +use anyhow::{Result, anyhow}; +use auth_git2::GitAuthenticator; +use console::Term; + +use crate::{ + cli_interactor::count_lines_per_msg_vec, + git::{ + Repo, + nostr_url::{CloneUrl, NostrUrlDecoded}, + oid_to_shorthand_string, + }, + utils::{ + Direction, get_short_git_server_name, get_write_protocols_to_try, join_with_and, + set_protocol_preference, + }, +}; + +pub fn push_to_remote( + git_repo: &Repo, + git_server_url: &str, + decoded_nostr_url: &NostrUrlDecoded, + remote_refspecs: &[String], + term: &Term, + is_grasp_server: bool, +) -> Result<()> { + let server_url = git_server_url.parse::()?; + let protocols_to_attempt = + get_write_protocols_to_try(git_repo, &server_url, decoded_nostr_url, is_grasp_server); + + let mut failed_protocols = vec![]; + let mut success = false; + + for protocol in &protocols_to_attempt { + term.write_line(format!("push: {} over {protocol}...", server_url.short_name(),).as_str())?; + + let formatted_url = server_url.format_as(protocol, &decoded_nostr_url.user)?; + + if let Err(error) = push_to_remote_url(git_repo, &formatted_url, remote_refspecs, term) { + term.write_line( + format!("push: {formatted_url} failed over {protocol}: {error}").as_str(), + )?; + failed_protocols.push(protocol); + } else { + success = true; + if !failed_protocols.is_empty() { + term.write_line(format!("push: succeeded over {protocol}").as_str())?; + let _ = set_protocol_preference(git_repo, protocol, &server_url, &Direction::Push); + } + break; + } + } + if success { + Ok(()) + } else { + let error = anyhow!( + "{} failed over {}{}", + server_url.short_name(), + join_with_and(&failed_protocols), + if decoded_nostr_url.protocol.is_some() { + " and nostr url contains protocol override so no other protocols were attempted" + } else { + "" + }, + ); + term.write_line(format!("push: {error}").as_str())?; + Err(error) + } +} + +pub fn push_to_remote_url( + git_repo: &Repo, + git_server_url: &str, + remote_refspecs: &[String], + term: &Term, +) -> Result<()> { + let git_config = git_repo.git_repo.config()?; + let mut git_server_remote = git_repo.git_repo.remote_anonymous(git_server_url)?; + let auth = GitAuthenticator::default(); + let mut push_options = git2::PushOptions::new(); + let mut remote_callbacks = git2::RemoteCallbacks::new(); + let push_reporter = Arc::new(Mutex::new(PushReporter::new(term))); + + remote_callbacks.credentials(auth.credentials(&git_config)); + + remote_callbacks.push_update_reference({ + let push_reporter = Arc::clone(&push_reporter); + move |name, error| { + let mut reporter = push_reporter.lock().unwrap(); + if let Some(error) = error { + let existing_lines = reporter.count_all_existing_lines(); + reporter.update_reference_errors.push(format!( + "WARNING: {} failed to push {name} error: {error}", + get_short_git_server_name(git_repo, git_server_url), + )); + reporter.write_all(existing_lines); + } + Ok(()) + } + }); + + remote_callbacks.push_negotiation({ + let push_reporter = Arc::clone(&push_reporter); + move |updates| { + let mut reporter = push_reporter.lock().unwrap(); + let existing_lines = reporter.count_all_existing_lines(); + + for update in updates { + let dst_refname = update + .dst_refname() + .unwrap_or("") + .replace("refs/heads/", "") + .replace("refs/tags/", "tags/"); + let msg = if update.dst().is_zero() { + format!("push: - [delete] {dst_refname}") + } else if update.src().is_zero() { + if update.dst_refname().unwrap_or("").contains("refs/tags") { + format!("push: * [new tag] {dst_refname}") + } else { + format!("push: * [new branch] {dst_refname}") + } + } else { + let force = remote_refspecs + .iter() + .any(|r| r.contains(&dst_refname) && r.contains('+')); + format!( + "push: {} {}..{} {} -> {dst_refname}", + if force { "+" } else { " " }, + oid_to_shorthand_string(update.src()).unwrap(), + oid_to_shorthand_string(update.dst()).unwrap(), + update + .src_refname() + .unwrap_or("") + .replace("refs/heads/", "") + .replace("refs/tags/", "tags/"), + ) + }; + // other possibilities will result in push to fail but better reporting is + // needed: + // deleting a non-existant branch: + // ! [remote rejected] -> (not found) + // adding a branch that already exists: + // ! [remote rejected] -> (already exists) + // pushing without non-fast-forward without force: + // ! [rejected] -> (non-fast-forward) + reporter.negotiation.push(msg); + } + reporter.write_all(existing_lines); + Ok(()) + } + }); + + remote_callbacks.push_transfer_progress({ + let push_reporter = Arc::clone(&push_reporter); + #[allow(clippy::cast_precision_loss)] + move |current, total, bytes| { + let mut reporter = push_reporter.lock().unwrap(); + reporter.process_transfer_progress_update(current, total, bytes); + } + }); + + remote_callbacks.sideband_progress({ + let push_reporter = Arc::clone(&push_reporter); + move |data| { + let mut reporter = push_reporter.lock().unwrap(); + reporter.process_remote_msg(data); + true + } + }); + push_options.remote_callbacks(remote_callbacks); + git_server_remote.push(remote_refspecs, Some(&mut push_options))?; + let _ = git_server_remote.disconnect(); + Ok(()) +} + +#[allow(clippy::cast_precision_loss)] +#[allow(clippy::float_cmp)] +#[allow(clippy::needless_pass_by_value)] +fn report_on_transfer_progress( + current: usize, + total: usize, + bytes: usize, + start_time: &Instant, + end_time: Option<&Instant>, +) -> Option { + if total == 0 { + return None; + } + let percentage = ((current as f64 / total as f64) * 100.0) + // always round down because 100% complete is misleading when its not complete + .floor(); + let (size, unit) = if bytes as f64 >= (1024.0 * 1024.0) { + (bytes as f64 / (1024.0 * 1024.0), "MiB") + } else { + (bytes as f64 / 1024.0, "KiB") + }; + let speed = { + let duration = if let Some(end_time) = end_time { + (*end_time - *start_time).as_millis() as f64 + } else { + start_time.elapsed().as_millis() as f64 + }; + + if duration > 0.0 { + (bytes as f64 / (1024.0 * 1024.0)) / (duration / 1000.0) // Convert bytes to MiB and milliseconds to seconds + } else { + 0.0 + } + }; + + Some(format!( + "push: Writing objects: {percentage}% ({current}/{total}) {size:.2} {unit} | {speed:.2} MiB/s{}", + if current == total { ", done." } else { "" }, + )) +} + +pub struct PushReporter<'a> { + remote_msgs: Vec, + negotiation: Vec, + transfer_progress_msgs: Vec, + update_reference_errors: Vec, + term: &'a console::Term, + start_time: Option, + end_time: Option, +} +impl<'a> PushReporter<'a> { + fn new(term: &'a console::Term) -> Self { + Self { + remote_msgs: vec![], + negotiation: vec![], + transfer_progress_msgs: vec![], + update_reference_errors: vec![], + term, + start_time: None, + end_time: None, + } + } + fn write_all(&self, lines_to_clear: usize) { + let _ = self.term.clear_last_lines(lines_to_clear); + for msg in &self.remote_msgs { + let _ = self.term.write_line(format!("remote: {msg}").as_str()); + } + for msg in &self.negotiation { + let _ = self.term.write_line(msg); + } + for msg in &self.transfer_progress_msgs { + let _ = self.term.write_line(msg); + } + for msg in &self.update_reference_errors { + let _ = self.term.write_line(msg); + } + } + + fn count_all_existing_lines(&self) -> usize { + let width = self.term.size().1; + count_lines_per_msg_vec(width, &self.remote_msgs, "remote: ".len()) + + count_lines_per_msg_vec(width, &self.negotiation, 0) + + count_lines_per_msg_vec(width, &self.transfer_progress_msgs, 0) + + count_lines_per_msg_vec(width, &self.update_reference_errors, 0) + } + fn process_remote_msg(&mut self, data: &[u8]) { + if let Ok(data) = str::from_utf8(data) { + let data = data + .split(['\n', '\r']) + .map(str::trim) + .filter(|line| !line.trim().is_empty()) + .collect::>(); + for data in data { + let existing_lines = self.count_all_existing_lines(); + let msg = data.to_string(); + if let Some(last) = self.remote_msgs.last() { + if (last.contains('%') && !last.contains("100%")) + || last == &msg.replace(", done.", "") + { + self.remote_msgs.pop(); + self.remote_msgs.push(msg); + } else { + self.remote_msgs.push(msg); + } + } else { + self.remote_msgs.push(msg); + } + self.write_all(existing_lines); + } + } + } + fn process_transfer_progress_update(&mut self, current: usize, total: usize, bytes: usize) { + if self.start_time.is_none() { + self.start_time = Some(Instant::now()); + } + if let Some(report) = report_on_transfer_progress( + current, + total, + bytes, + &self.start_time.unwrap(), + self.end_time.as_ref(), + ) { + let existing_lines = self.count_all_existing_lines(); + if report.contains("100%") { + self.end_time = Some(Instant::now()); + } + self.transfer_progress_msgs = vec![report]; + self.write_all(existing_lines); + } + } +} -- cgit v1.2.3