From 902010f6100564846f133ef5282f8a8e5aa7cf3b Mon Sep 17 00:00:00 2001
From: Wiktor Kwapisiewicz <wiktor@metacode.biz>
Date: Wed, 15 Jan 2025 17:17:02 +0100
Subject: [PATCH] Allow setting hash algorithm to use for signing requests of
 SSH agent (#449)

---
 russh/src/auth.rs              |  8 ++++++--
 russh/src/client/mod.rs        |  5 +++--
 russh/src/keys/agent/client.rs | 23 +++++++++++------------
 russh/src/keys/mod.rs          |  6 +++---
 4 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/russh/src/auth.rs b/russh/src/auth.rs
index 1cdf2579..cf53070e 100644
--- a/russh/src/auth.rs
+++ b/russh/src/auth.rs
@@ -18,7 +18,7 @@ use std::str::FromStr;
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use ssh_key::{Certificate, PrivateKey};
+use ssh_key::{Certificate, HashAlg, PrivateKey};
 use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite};
 
@@ -154,6 +154,7 @@ pub trait Signer: Sized {
     async fn auth_publickey_sign(
         &mut self,
         key: &ssh_key::PublicKey,
+        hash_alg: Option<HashAlg>,
         to_sign: CryptoVec,
     ) -> Result<CryptoVec, Self::Error>;
 }
@@ -175,9 +176,12 @@ impl<R: AsyncRead + AsyncWrite + Unpin + Send + 'static> Signer
     async fn auth_publickey_sign(
         &mut self,
         key: &ssh_key::PublicKey,
+        hash_alg: Option<HashAlg>,
         to_sign: CryptoVec,
     ) -> Result<CryptoVec, Self::Error> {
-        self.sign_request(key, to_sign).await.map_err(Into::into)
+        self.sign_request(key, hash_alg, to_sign)
+            .await
+            .map_err(Into::into)
     }
 }
 
diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs
index f50830d8..a5b6d1ed 100644
--- a/russh/src/client/mod.rs
+++ b/russh/src/client/mod.rs
@@ -47,7 +47,7 @@ use kex::ClientKex;
 use log::{debug, error, trace};
 use russh_util::time::Instant;
 use ssh_encoding::Decode;
-use ssh_key::{Certificate, PrivateKey, PublicKey};
+use ssh_key::{Certificate, HashAlg, PrivateKey, PublicKey};
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
 use tokio::pin;
 use tokio::sync::mpsc::{
@@ -401,6 +401,7 @@ impl<H: Handler> Handle<H> {
         &mut self,
         user: U,
         key: ssh_key::PublicKey,
+        hash_alg: Option<HashAlg>,
         signer: &mut S,
     ) -> Result<AuthResult, S::Error> {
         let user = user.into();
@@ -423,7 +424,7 @@ impl<H: Handler> Handle<H> {
                     proceed_with_methods: remaining_methods,
                 }) => return Ok(AuthResult::Failure { remaining_methods }),
                 Some(Reply::SignRequest { key, data }) => {
-                    let data = signer.auth_publickey_sign(&key, data).await;
+                    let data = signer.auth_publickey_sign(&key, hash_alg, data).await;
                     let data = match data {
                         Ok(data) => data,
                         Err(e) => return Err(e),
diff --git a/russh/src/keys/agent/client.rs b/russh/src/keys/agent/client.rs
index 7969df8a..de9fedb7 100644
--- a/russh/src/keys/agent/client.rs
+++ b/russh/src/keys/agent/client.rs
@@ -4,7 +4,7 @@ use byteorder::{BigEndian, ByteOrder};
 use bytes::Bytes;
 use log::debug;
 use ssh_encoding::{Decode, Encode, Reader};
-use ssh_key::{Algorithm, HashAlg, PrivateKey, PublicKey, Signature};
+use ssh_key::{HashAlg, PrivateKey, PublicKey, Signature};
 use tokio;
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 
@@ -284,10 +284,11 @@ impl<S: AgentStream + Unpin> AgentClient<S> {
     pub async fn sign_request(
         &mut self,
         public: &PublicKey,
+        hash_alg: Option<HashAlg>,
         mut data: CryptoVec,
     ) -> Result<CryptoVec, Error> {
         debug!("sign_request: {:?}", data);
-        let hash = self.prepare_sign_request(public, &data)?;
+        let hash = self.prepare_sign_request(public, hash_alg, &data)?;
 
         self.read_response().await?;
 
@@ -307,6 +308,7 @@ impl<S: AgentStream + Unpin> AgentClient<S> {
     fn prepare_sign_request(
         &mut self,
         public: &ssh_key::PublicKey,
+        hash_alg: Option<HashAlg>,
         data: &[u8],
     ) -> Result<u32, Error> {
         self.buf.clear();
@@ -315,14 +317,9 @@ impl<S: AgentStream + Unpin> AgentClient<S> {
         public.key_data().encoded()?.encode(&mut self.buf)?;
         data.encode(&mut self.buf)?;
         debug!("public = {:?}", public);
-        let hash = match public.algorithm() {
-            Algorithm::Rsa {
-                hash: Some(HashAlg::Sha256),
-            } => 2,
-            Algorithm::Rsa {
-                hash: Some(HashAlg::Sha512),
-            } => 4,
-            Algorithm::Rsa { hash: None } => 0,
+        let hash = match hash_alg {
+            Some(HashAlg::Sha256) => 2,
+            Some(HashAlg::Sha512) => 4,
             _ => 0,
         };
         hash.encode(&mut self.buf)?;
@@ -352,10 +349,11 @@ impl<S: AgentStream + Unpin> AgentClient<S> {
     pub fn sign_request_base64(
         mut self,
         public: &ssh_key::PublicKey,
+        hash_alg: Option<HashAlg>,
         data: &[u8],
     ) -> impl futures::Future<Output = (Self, Result<String, Error>)> {
         debug!("sign_request: {:?}", data);
-        let r = self.prepare_sign_request(public, data);
+        let r = self.prepare_sign_request(public, hash_alg, data);
         async move {
             if let Err(e) = r {
                 return (self, Err(e));
@@ -380,11 +378,12 @@ impl<S: AgentStream + Unpin> AgentClient<S> {
     pub async fn sign_request_signature(
         &mut self,
         public: &ssh_key::PublicKey,
+        hash_alg: Option<HashAlg>,
         data: &[u8],
     ) -> Result<Signature, Error> {
         debug!("sign_request: {:?}", data);
 
-        self.prepare_sign_request(public, data)?;
+        self.prepare_sign_request(public, hash_alg, data)?;
         self.read_response().await?;
 
         match self.buf.split_first() {
diff --git a/russh/src/keys/mod.rs b/russh/src/keys/mod.rs
index b92f9599..f5889e61 100644
--- a/russh/src/keys/mod.rs
+++ b/russh/src/keys/mod.rs
@@ -45,7 +45,7 @@
 //!        client.add_identity(&key, &[agent::Constraint::KeyLifetime { seconds: 60 }]).await?;
 //!        client.request_identities().await?;
 //!        let buf = b"signed message";
-//!        let sig = client.sign_request(&public, russh_cryptovec::CryptoVec::from_slice(&buf[..])).await.unwrap();
+//!        let sig = client.sign_request(&public, None, russh_cryptovec::CryptoVec::from_slice(&buf[..])).await.unwrap();
 //!        // Here, `sig` is encoded in a format usable internally by the SSH protocol.
 //!        Ok::<(), Error>(())
 //!    }).unwrap()
@@ -849,7 +849,7 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux
         client.request_identities().await?;
         let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla");
         let len = buf.len();
-        let buf = client.sign_request(public, buf).await.unwrap();
+        let buf = client.sign_request(public, None, buf).await.unwrap();
         let (a, b) = buf.split_at(len);
 
         match key.public_key().key_data() {
@@ -935,7 +935,7 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux
             client.request_identities().await.unwrap();
             let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla");
             let len = buf.len();
-            let buf = client.sign_request(public, buf).await.unwrap();
+            let buf = client.sign_request(public, None, buf).await.unwrap();
             let (a, b) = buf.split_at(len);
             if let ssh_key::public::KeyData::Ed25519 { .. } = public.key_data() {
                 let sig = &b[b.len() - 64..];