diff --git a/.github/workflows/semver.yml b/.github/workflows/semver.yml index 661a1019..90a8ca4f 100644 --- a/.github/workflows/semver.yml +++ b/.github/workflows/semver.yml @@ -1,4 +1,4 @@ -name: Rust +name: Semver check on: push: @@ -21,5 +21,6 @@ jobs: - name: Check semver compatibility (russh) uses: obi1kenobi/cargo-semver-checks-action@v2 + continue-on-error: true with: package: russh diff --git a/Cargo.toml b/Cargo.toml index 83190951..aee683da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,9 @@ russh-config = { path = "russh-config" } aes = "0.8" async-trait = "0.1" byteorder = "1.4" +bytes = "1.7" digest = "0.10" +delegate = "0.13" futures = "0.3" hmac = "0.12" log = "0.4" @@ -27,7 +29,7 @@ rand = "0.8" sha1 = { version = "0.10", features = ["oid"] } sha2 = { version = "0.10", features = ["oid"] } signature = "2.2" -ssh-encoding = "0.2" +ssh-encoding = { version = "0.2", features = ["bytes"] } ssh-key = { version = "0.6", features = [ "ed25519", "rsa", diff --git a/cryptovec/Cargo.toml b/cryptovec/Cargo.toml index fafe1c8d..9e8a528d 100644 --- a/cryptovec/Cargo.toml +++ b/cryptovec/Cargo.toml @@ -7,14 +7,18 @@ include = ["Cargo.toml", "src/lib.rs"] license = "Apache-2.0" name = "russh-cryptovec" repository = "https://github.com/warp-tech/russh" -version = "0.7.3" +version = "0.8.0-beta.1" rust-version = "1.60" [dependencies] libc = "0.2" +ssh-encoding = { workspace = true, optional = true } [target.'cfg(target_os = "windows")'.dependencies] winapi = {version = "0.3", features = ["basetsd", "minwindef", "memoryapi"]} [dev-dependencies] -wasm-bindgen-test = "0.3" \ No newline at end of file +wasm-bindgen-test = "0.3" + +[features] +ssh-encoding = ["dep:ssh-encoding"] diff --git a/cryptovec/src/cryptovec.rs b/cryptovec/src/cryptovec.rs index 1821ddab..8988cc29 100644 --- a/cryptovec/src/cryptovec.rs +++ b/cryptovec/src/cryptovec.rs @@ -1,6 +1,6 @@ use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; -use crate::platform::{self, memcpy, memset, mlock, munlock}; +use crate::platform::{self, memset, mlock, munlock}; /// A buffer which zeroes its memory on `.clear()`, `.resize()`, and /// reallocations, to avoid copying secrets around. @@ -246,38 +246,6 @@ impl CryptoVec { unsafe { *self.p.add(size) = s } } - /// Append a new u32, big endian-encoded, at the end of this CryptoVec. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 43554; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn push_u32_be(&mut self, s: u32) { - let s = s.to_be(); - let x: [u8; 4] = s.to_ne_bytes(); - self.extend(&x) - } - - /// Read a big endian-encoded u32 from this CryptoVec, with the - /// first byte at position `i`. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 99485710; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn read_u32_be(&self, i: usize) -> u32 { - assert!(i + 4 <= self.size); - let mut x: u32 = 0; - unsafe { - memcpy((&mut x) as *mut u32, self.p.add(i), 4); - } - u32::from_be(x) - } - /// Read `n_bytes` from `r`, and append them at the end of this /// `CryptoVec`. Returns the number of bytes read (and appended). pub fn read( diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index e2b3d54d..c1f4f778 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -26,3 +26,6 @@ pub use cryptovec::CryptoVec; // Platform-specific modules mod platform; + +#[cfg(feature = "ssh-encoding")] +mod ssh; diff --git a/cryptovec/src/platform/mod.rs b/cryptovec/src/platform/mod.rs index 78fbbffe..03ca1099 100644 --- a/cryptovec/src/platform/mod.rs +++ b/cryptovec/src/platform/mod.rs @@ -11,11 +11,11 @@ mod wasm; // Re-export functions based on the platform #[cfg(not(windows))] #[cfg(not(target_arch = "wasm32"))] -pub use unix::{memcpy, memset, mlock, munlock}; +pub use unix::{memset, mlock, munlock}; #[cfg(target_arch = "wasm32")] -pub use wasm::{memcpy, memset, mlock, munlock}; +pub use wasm::{memset, mlock, munlock}; #[cfg(windows)] -pub use windows::{memcpy, memset, mlock, munlock}; +pub use windows::{memset, mlock, munlock}; #[cfg(test)] mod tests { diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs index 0f7ed9e5..611a8bd6 100644 --- a/cryptovec/src/platform/unix.rs +++ b/cryptovec/src/platform/unix.rs @@ -24,9 +24,3 @@ pub fn memset(ptr: *mut u8, value: i32, size: usize) { libc::memset(ptr as *mut c_void, value, size); } } - -pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { - unsafe { - libc::memcpy(dest as *mut c_void, src as *const c_void, size); - } -} diff --git a/cryptovec/src/ssh.rs b/cryptovec/src/ssh.rs new file mode 100644 index 00000000..846dd793 --- /dev/null +++ b/cryptovec/src/ssh.rs @@ -0,0 +1,20 @@ +use ssh_encoding::{Reader, Result, Writer}; + +use crate::CryptoVec; + +impl Reader for CryptoVec { + fn read<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8]> { + (&self[..]).read(out) + } + + fn remaining_len(&self) -> usize { + self.len() + } +} + +impl Writer for CryptoVec { + fn write(&mut self, bytes: &[u8]) -> Result<()> { + self.extend(bytes); + Ok(()) + } +} diff --git a/pageant/Cargo.toml b/pageant/Cargo.toml index 47e21c9b..fdeb007f 100644 --- a/pageant/Cargo.toml +++ b/pageant/Cargo.toml @@ -14,8 +14,8 @@ futures = { workspace = true } thiserror = { workspace = true } rand = { workspace = true } tokio = { workspace = true, features = ["io-util", "rt"] } -bytes = "1.7" -delegate = "0.13" +bytes = { workspace = true } +delegate.workspace = true [target.'cfg(windows)'.dependencies] windows = { version = "0.58", features = [ diff --git a/russh-keys/Cargo.toml b/russh-keys/Cargo.toml index 5b832494..badd2ac4 100644 --- a/russh-keys/Cargo.toml +++ b/russh-keys/Cargo.toml @@ -15,6 +15,7 @@ rust-version = "1.65" aes = { workspace = true } async-trait = { workspace = true } bcrypt-pbkdf = "0.10" +bytes = { workspace = true } cbc = "0.1" ctr = "0.9" block-padding = { version = "0.3", features = ["std"] } @@ -41,7 +42,9 @@ pkcs8 = { version = "0.10", features = ["pkcs5", "encryption"] } rand = { workspace = true } rand_core = { version = "0.6.4", features = ["std"] } rsa = "0.9" -russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } +russh-cryptovec = { version = "0.8.0-beta.1", path = "../cryptovec", features = [ + "ssh-encoding", +] } russh-util = { version = "0.46.0", path = "../russh-util" } sec1 = { version = "0.7", features = ["pkcs8"] } serde = { version = "1.0", features = ["derive"] } @@ -53,13 +56,13 @@ ssh-encoding = { workspace = true } ssh-key = { workspace = true } thiserror = { workspace = true } typenum = "1.17" -yasna = { version = "0.5.0", features = ["bit-vec", "num-bigint"], optional = true } +yasna = { version = "0.5.0", features = [ + "bit-vec", + "num-bigint", +], optional = true } zeroize = "1.7" getrandom = { version = "0.2.15", features = ["js"] } -tokio = { workspace = true, features = [ - "io-util", - "time", -] } +tokio = { workspace = true, features = ["io-util", "time"] } [features] legacy-ed25519-pkcs8-parser = ["yasna"] diff --git a/russh-keys/src/agent/client.rs b/russh-keys/src/agent/client.rs index c707a356..e638b753 100644 --- a/russh-keys/src/agent/client.rs +++ b/russh-keys/src/agent/client.rs @@ -1,14 +1,15 @@ use core::str; use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; use log::debug; use russh_cryptovec::CryptoVec; +use ssh_encoding::{Decode, Encode, Reader}; use ssh_key::{Algorithm, HashAlg, PrivateKey, PublicKey, Signature}; use tokio; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use super::{msg, Constraint}; -use crate::encoding::{Encoding, Reader}; use crate::helpers::EncodedExt; use crate::{key, Error}; @@ -156,24 +157,24 @@ impl AgentClient { self.buf.push(msg::ADD_ID_CONSTRAINED) } - self.buf.extend(key.key_data().encoded()?.as_slice()); - self.buf.extend_ssh_string(&[]); // comment field + key.key_data().encode(&mut self.buf)?; + "".encode(&mut self.buf)?; // comment field if !constraints.is_empty() { for cons in constraints { match *cons { Constraint::KeyLifetime { seconds } => { - self.buf.push(msg::CONSTRAIN_LIFETIME); - self.buf.push_u32_be(seconds); + msg::CONSTRAIN_LIFETIME.encode(&mut self.buf)?; + seconds.encode(&mut self.buf)?; } Constraint::Confirm => self.buf.push(msg::CONSTRAIN_CONFIRM), Constraint::Extensions { ref name, ref details, } => { - self.buf.push(msg::CONSTRAIN_EXTENSION); - self.buf.extend_ssh_string(name); - self.buf.extend_ssh_string(details); + msg::CONSTRAIN_EXTENSION.encode(&mut self.buf)?; + name.encode(&mut self.buf)?; + details.encode(&mut self.buf)?; } } } @@ -200,24 +201,24 @@ impl AgentClient { } else { self.buf.push(msg::ADD_SMARTCARD_KEY_CONSTRAINED) } - self.buf.extend_ssh_string(id.as_bytes()); - self.buf.extend_ssh_string(pin); + id.encode(&mut self.buf)?; + pin.encode(&mut self.buf)?; if !constraints.is_empty() { - self.buf.push_u32_be(constraints.len() as u32); + (constraints.len() as u32).encode(&mut self.buf)?; for cons in constraints { match *cons { Constraint::KeyLifetime { seconds } => { - self.buf.push(msg::CONSTRAIN_LIFETIME); - self.buf.push_u32_be(seconds) + msg::CONSTRAIN_LIFETIME.encode(&mut self.buf)?; + seconds.encode(&mut self.buf)?; } Constraint::Confirm => self.buf.push(msg::CONSTRAIN_CONFIRM), Constraint::Extensions { ref name, ref details, } => { - self.buf.push(msg::CONSTRAIN_EXTENSION); - self.buf.extend_ssh_string(name); - self.buf.extend_ssh_string(details); + msg::CONSTRAIN_EXTENSION.encode(&mut self.buf)?; + name.encode(&mut self.buf)?; + details.encode(&mut self.buf)?; } } } @@ -233,7 +234,7 @@ impl AgentClient { self.buf.clear(); self.buf.resize(4); self.buf.push(msg::LOCK); - self.buf.extend_ssh_string(passphrase); + passphrase.encode(&mut self.buf)?; let len = self.buf.len() - 4; BigEndian::write_u32(&mut self.buf[..], len as u32); self.read_response().await?; @@ -244,8 +245,8 @@ impl AgentClient { pub async fn unlock(&mut self, passphrase: &[u8]) -> Result<(), Error> { self.buf.clear(); self.buf.resize(4); - self.buf.push(msg::UNLOCK); - self.buf.extend_ssh_string(passphrase); + msg::UNLOCK.encode(&mut self.buf)?; + passphrase.encode(&mut self.buf)?; let len = self.buf.len() - 4; #[allow(clippy::indexing_slicing)] // static length BigEndian::write_u32(&mut self.buf[..], len as u32); @@ -258,7 +259,7 @@ impl AgentClient { pub async fn request_identities(&mut self) -> Result, Error> { self.buf.clear(); self.buf.resize(4); - self.buf.push(msg::REQUEST_IDENTITIES); + msg::REQUEST_IDENTITIES.encode(&mut self.buf)?; let len = self.buf.len() - 4; BigEndian::write_u32(&mut self.buf[..], len as u32); @@ -267,13 +268,12 @@ impl AgentClient { let mut keys = Vec::new(); #[allow(clippy::indexing_slicing)] // static length - if self.buf[0] == msg::IDENTITIES_ANSWER { - let mut r = self.buf.reader(1); - let n = r.read_u32()?; + if let Some((&msg::IDENTITIES_ANSWER, mut r)) = self.buf.split_first() { + let n = u32::decode(&mut r)?; for _ in 0..n { - let key_blob = r.read_string()?; - let _comment = r.read_string()?; - keys.push(key::parse_public_key(key_blob)?); + let key_blob = Bytes::decode(&mut r)?; + let _comment = String::decode(&mut r)?; + keys.push(key::parse_public_key(&key_blob)?); } } @@ -291,14 +291,16 @@ impl AgentClient { self.read_response().await?; - if self.buf.first() == Some(&msg::SIGN_RESPONSE) { - self.write_signature(hash, &mut data)?; - Ok(data) - } else if self.buf.first() == Some(&msg::FAILURE) { - Err(Error::AgentFailure) - } else { - debug!("self.buf = {:?}", &self.buf[..]); - Ok(data) + match self.buf.split_first() { + Some((&msg::SIGN_RESPONSE, mut r)) => { + self.write_signature(&mut r, hash, &mut data)?; + Ok(data) + } + Some((&msg::FAILURE, _)) => Err(Error::AgentFailure), + _ => { + debug!("self.buf = {:?}", &self.buf[..]); + Err(Error::AgentProtocolError) + } } } @@ -309,9 +311,9 @@ impl AgentClient { ) -> Result { self.buf.clear(); self.buf.resize(4); - self.buf.push(msg::SIGN_REQUEST); - key_blob(public, &mut self.buf)?; - self.buf.extend_ssh_string(data); + msg::SIGN_REQUEST.encode(&mut self.buf)?; + public.key_data().encoded()?.encode(&mut self.buf)?; + data.encode(&mut self.buf)?; debug!("public = {:?}", public); let hash = match public.algorithm() { Algorithm::Rsa { @@ -323,21 +325,25 @@ impl AgentClient { Algorithm::Rsa { hash: None } => 0, _ => 0, }; - self.buf.push_u32_be(hash); + hash.encode(&mut self.buf)?; let len = self.buf.len() - 4; BigEndian::write_u32(&mut self.buf[..], len as u32); Ok(hash) } - fn write_signature(&self, hash: u32, data: &mut CryptoVec) -> Result<(), Error> { - let mut r = self.buf.reader(1); - let mut resp = r.read_string()?.reader(0); - let t = resp.read_string()?; - if (hash == 2 && t == b"rsa-sha2-256") || (hash == 4 && t == b"rsa-sha2-512") || hash == 0 { - let sig = resp.read_string()?; - data.push_u32_be((t.len() + sig.len() + 8) as u32); - data.extend_ssh_string(t); - data.extend_ssh_string(sig); + fn write_signature( + &self, + r: &mut R, + hash: u32, + data: &mut CryptoVec, + ) -> Result<(), Error> { + let mut resp = &Bytes::decode(r)?[..]; + let t = String::decode(&mut resp)?; + if (hash == 2 && t == "rsa-sha2-256") || (hash == 4 && t == "rsa-sha2-512") || hash == 0 { + let sig = Bytes::decode(&mut resp)?; + (t.len() + sig.len() + 8).encode(data)?; + t.encode(data)?; + sig.encode(data)?; } Ok(()) } @@ -381,17 +387,13 @@ impl AgentClient { self.prepare_sign_request(public, data)?; self.read_response().await?; - #[allow(clippy::indexing_slicing)] // length is checked - if !self.buf.is_empty() && self.buf[0] == msg::SIGN_RESPONSE { - let mut r = self.buf.reader(1); - let mut resp = r.read_string()?.reader(0); - let typ = String::from_utf8(resp.read_string()?.into())?; - let sig = resp.read_string()?; - let algo = Algorithm::new(&typ)?; - let sig = Signature::new(algo, sig.to_vec())?; - Ok(sig) - } else { - Err(Error::AgentProtocolError) + match self.buf.split_first() { + Some((&msg::SIGN_RESPONSE, mut r)) => { + let mut resp = &Bytes::decode(&mut r)?[..]; + let sig = Signature::decode(&mut resp)?; + Ok(sig) + } + _ => Err(Error::AgentProtocolError), } } @@ -400,7 +402,7 @@ impl AgentClient { self.buf.clear(); self.buf.resize(4); self.buf.push(msg::REMOVE_IDENTITY); - key_blob(public, &mut self.buf)?; + public.key_data().encoded()?.encode(&mut self.buf)?; let len = self.buf.len() - 4; BigEndian::write_u32(&mut self.buf[..], len as u32); self.read_response().await?; @@ -411,9 +413,9 @@ impl AgentClient { pub async fn remove_smartcard_key(&mut self, id: &str, pin: &[u8]) -> Result<(), Error> { self.buf.clear(); self.buf.resize(4); - self.buf.push(msg::REMOVE_SMARTCARD_KEY); - self.buf.extend_ssh_string(id.as_bytes()); - self.buf.extend_ssh_string(pin); + msg::REMOVE_SMARTCARD_KEY.encode(&mut self.buf)?; + id.encode(&mut self.buf)?; + pin.encode(&mut self.buf)?; let len = self.buf.len() - 4; BigEndian::write_u32(&mut self.buf[..], len as u32); self.read_response().await?; @@ -424,8 +426,8 @@ impl AgentClient { pub async fn remove_all_identities(&mut self) -> Result<(), Error> { self.buf.clear(); self.buf.resize(4); - self.buf.push(msg::REMOVE_ALL_IDENTITIES); - BigEndian::write_u32(&mut self.buf[..], 1); + msg::REMOVE_ALL_IDENTITIES.encode(&mut self.buf)?; + 1u32.encode(&mut self.buf)?; self.read_success().await?; Ok(()) } @@ -434,11 +436,11 @@ impl AgentClient { pub async fn extension(&mut self, typ: &[u8], ext: &[u8]) -> Result<(), Error> { self.buf.clear(); self.buf.resize(4); - self.buf.push(msg::EXTENSION); - self.buf.extend_ssh_string(typ); - self.buf.extend_ssh_string(ext); + msg::EXTENSION.encode(&mut self.buf)?; + typ.encode(&mut self.buf)?; + ext.encode(&mut self.buf)?; let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); + (len as u32).encode(&mut self.buf)?; self.read_response().await?; Ok(()) } @@ -447,21 +449,18 @@ impl AgentClient { pub async fn query_extension(&mut self, typ: &[u8], mut ext: CryptoVec) -> Result { self.buf.clear(); self.buf.resize(4); - self.buf.push(msg::EXTENSION); - self.buf.extend_ssh_string(typ); + msg::EXTENSION.encode(&mut self.buf)?; + typ.encode(&mut self.buf)?; let len = self.buf.len() - 4; - BigEndian::write_u32(&mut self.buf[..], len as u32); + (len as u32).encode(&mut self.buf)?; self.read_response().await?; - let mut r = self.buf.reader(1); - ext.extend(r.read_string()?); - - #[allow(clippy::indexing_slicing)] // length is checked - Ok(!self.buf.is_empty() && self.buf[0] == msg::SUCCESS) + match self.buf.split_first() { + Some((&msg::SUCCESS, mut r)) => { + ext.extend(&Bytes::decode(&mut r)?); + Ok(true) + } + _ => Ok(false), + } } } - -fn key_blob(public: &ssh_key::PublicKey, buf: &mut CryptoVec) -> Result<(), Error> { - buf.extend_ssh_string(public.key_data().encoded()?.as_slice()); - Ok(()) -} diff --git a/russh-keys/src/agent/server.rs b/russh-keys/src/agent/server.rs index 0dc3caa3..984310a1 100644 --- a/russh-keys/src/agent/server.rs +++ b/russh-keys/src/agent/server.rs @@ -5,18 +5,19 @@ use std::time::{Duration, SystemTime}; use async_trait::async_trait; use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; use futures::future::Future; use futures::stream::{Stream, StreamExt}; use russh_cryptovec::CryptoVec; +use ssh_encoding::{Decode, Encode, Reader}; use ssh_key::PrivateKey; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::time::sleep; use {std, tokio}; use super::{msg, Constraint}; -use crate::encoding::{Encoding, Position, Reader}; use crate::helpers::EncodedExt; -use crate::{add_signature, Error}; +use crate::Error; #[derive(Clone)] #[allow(clippy::type_complexity)] @@ -127,26 +128,30 @@ impl { + + match self.buf.split_first() { + Some((&11, _)) + if !is_locked && agentref.confirm_request(MessageType::RequestKeys).await => + { // request identities if let Ok(keys) = self.keys.0.read() { - writebuf.push(msg::IDENTITIES_ANSWER); - writebuf.push_u32_be(keys.len() as u32); + msg::IDENTITIES_ANSWER.encode(writebuf)?; + (keys.len() as u32).encode(writebuf)?; for (k, _) in keys.iter() { - writebuf.extend_ssh_string(k); - writebuf.extend_ssh_string(b""); + k.encode(writebuf)?; + "".encode(writebuf)?; } } else { - writebuf.push(msg::FAILURE) + msg::FAILURE.encode(writebuf)? } } - Ok(13) if !is_locked && agentref.confirm_request(MessageType::Sign).await => { + Some((&13, mut r)) + if !is_locked && agentref.confirm_request(MessageType::Sign).await => + { // sign request let agent = self.agent.take().ok_or(Error::AgentFailure)?; - let (agent, signed) = self.try_sign(agent, r, writebuf).await?; + let (agent, signed) = self.try_sign(agent, &mut r, writebuf).await?; self.agent = Some(agent); if signed { return Ok(()); @@ -155,22 +160,28 @@ impl { + Some((&17, mut r)) + if !is_locked && agentref.confirm_request(MessageType::AddKeys).await => + { // add identity - if let Ok(true) = self.add_key(r, false, writebuf).await { + if let Ok(true) = self.add_key(&mut r, false, writebuf).await { } else { writebuf.push(msg::FAILURE) } } - Ok(18) if !is_locked && agentref.confirm_request(MessageType::RemoveKeys).await => { + Some((&18, mut r)) + if !is_locked && agentref.confirm_request(MessageType::RemoveKeys).await => + { // remove identity - if let Ok(true) = self.remove_identity(r) { + if let Ok(true) = self.remove_identity(&mut r) { writebuf.push(msg::SUCCESS) } else { writebuf.push(msg::FAILURE) } } - Ok(19) if !is_locked && agentref.confirm_request(MessageType::RemoveAllKeys).await => { + Some((&19, _)) + if !is_locked && agentref.confirm_request(MessageType::RemoveAllKeys).await => + { // remove all identities if let Ok(mut keys) = self.keys.0.write() { keys.clear(); @@ -179,25 +190,31 @@ impl { + Some((&22, mut r)) + if !is_locked && agentref.confirm_request(MessageType::Lock).await => + { // lock - if let Ok(()) = self.lock(r) { + if let Ok(()) = self.lock(&mut r) { writebuf.push(msg::SUCCESS) } else { writebuf.push(msg::FAILURE) } } - Ok(23) if is_locked && agentref.confirm_request(MessageType::Unlock).await => { + Some((&23, mut r)) + if is_locked && agentref.confirm_request(MessageType::Unlock).await => + { // unlock - if let Ok(true) = self.unlock(r) { + if let Ok(true) = self.unlock(&mut r) { writebuf.push(msg::SUCCESS) } else { writebuf.push(msg::FAILURE) } } - Ok(25) if !is_locked && agentref.confirm_request(MessageType::AddKeys).await => { + Some((&25, mut r)) + if !is_locked && agentref.confirm_request(MessageType::AddKeys).await => + { // add identity constrained - if let Ok(true) = self.add_key(r, true, writebuf).await { + if let Ok(true) = self.add_key(&mut r, true, writebuf).await { } else { writebuf.push(msg::FAILURE) } @@ -212,17 +229,17 @@ impl Result<(), Error> { - let password = r.read_string()?; + fn lock(&self, r: &mut R) -> Result<(), Error> { + let password = Bytes::decode(r)?; let mut lock = self.lock.0.write().or(Err(Error::AgentFailure))?; - lock.extend(password); + lock.extend(&password); Ok(()) } - fn unlock(&self, mut r: Position) -> Result { - let password = r.read_string()?; + fn unlock(&self, r: &mut R) -> Result { + let password = Bytes::decode(r)?; let mut lock = self.lock.0.write().or(Err(Error::AgentFailure))?; - if &lock[..] == password { + if lock[..] == password { lock.clear(); Ok(true) } else { @@ -230,9 +247,9 @@ impl Result { + fn remove_identity(&self, r: &mut R) -> Result { if let Ok(mut keys) = self.keys.0.write() { - if keys.remove(r.read_string()?).is_some() { + if keys.remove(&Bytes::decode(r)?.to_vec()).is_some() { Ok(true) } else { Ok(false) @@ -242,20 +259,18 @@ impl( &self, - mut r: Position<'_>, + r: &mut R, constrained: bool, writebuf: &mut CryptoVec, ) -> Result { let (blob, key_pair) = { use ssh_encoding::Decode; - let private_key = ssh_key::private::PrivateKey::new( - ssh_key::private::KeypairData::decode(&mut r)?, - "", - )?; - let _comment = r.read_string()?; + let private_key = + ssh_key::private::PrivateKey::new(ssh_key::private::KeypairData::decode(r)?, "")?; + let _comment = String::decode(r)?; (private_key.public_key().key_data().encoded()?, private_key) }; @@ -264,9 +279,9 @@ impl( &self, agent: A, - mut r: Position<'_>, + r: &mut R, writebuf: &mut CryptoVec, ) -> Result<(A, bool), Error> { let mut needs_confirm = false; let key = { - let blob = r.read_string()?; + let blob = Bytes::decode(r)?; let k = self.keys.0.read().or(Err(Error::AgentFailure))?; - if let Some((key, _, constraints)) = k.get(blob) { + if let Some((key, _, constraints)) = k.get(&blob.to_vec()) { if constraints.iter().any(|c| *c == Constraint::Confirm) { needs_confirm = true; } @@ -325,9 +340,10 @@ impl &[u8]; -} - -impl> Bytes for A { - fn bytes(&self) -> &[u8] { - self.as_ref().as_bytes() - } -} - -/// Encode in the SSH format. -pub trait Encoding { - /// Push an SSH-encoded string to `self`. - fn extend_ssh_string(&mut self, s: &[u8]); - /// Push an SSH-encoded blank string of length `s` to `self`. - fn extend_ssh_string_blank(&mut self, s: usize) -> &mut [u8]; - /// Push an SSH-encoded multiple-precision integer. - fn extend_ssh_mpint(&mut self, s: &[u8]); - /// Push an SSH-encoded list. - fn extend_list>(&mut self, list: I); - /// Push an SSH-encoded empty list. - fn write_empty_list(&mut self); - /// Push an SSH-encoded value. - fn extend_ssh(&mut self, v: &T) { - v.write_ssh(self) - } - /// Push a nested SSH-encoded value. - fn extend_wrapped(&mut self, write: F) - where - F: FnOnce(&mut Self); -} - -/// Trait for writing value in SSH-encoded format. -pub trait SshWrite { - /// Write the value. - fn write_ssh(&self, encoder: &mut E); -} - -/// Encoding length of the given mpint. -#[allow(clippy::indexing_slicing)] -pub fn mpint_len(s: &[u8]) -> usize { - let mut i = 0; - while i < s.len() && s[i] == 0 { - i += 1 - } - (if s[i] & 0x80 != 0 { 5 } else { 4 }) + s.len() - i -} - -impl Encoding for Vec { - #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic - fn extend_ssh_string(&mut self, s: &[u8]) { - self.write_u32::(s.len() as u32).unwrap(); - self.extend(s); - } - - #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic - fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] { - self.write_u32::(len as u32).unwrap(); - let current = self.len(); - self.resize(current + len, 0u8); - #[allow(clippy::indexing_slicing)] // length is known - &mut self[current..] - } - - #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic - #[allow(clippy::indexing_slicing)] // length is known - fn extend_ssh_mpint(&mut self, s: &[u8]) { - // Skip initial 0s. - let mut i = 0; - while i < s.len() && s[i] == 0 { - i += 1 - } - // If the first non-zero is >= 128, write its length (u32, BE), followed by 0. - if s[i] & 0x80 != 0 { - self.write_u32::((s.len() - i + 1) as u32) - .unwrap(); - self.push(0) - } else { - self.write_u32::((s.len() - i) as u32).unwrap(); - } - self.extend(&s[i..]); - } - - #[allow(clippy::indexing_slicing)] // length is known - fn extend_list>(&mut self, list: I) { - let len0 = self.len(); - self.extend([0, 0, 0, 0]); - let mut first = true; - for i in list { - if !first { - self.push(b',') - } else { - first = false; - } - self.extend(i.bytes()) - } - let len = (self.len() - len0 - 4) as u32; - - BigEndian::write_u32(&mut self[len0..], len); - } - - fn write_empty_list(&mut self) { - self.extend([0, 0, 0, 0]); - } - - fn extend_wrapped(&mut self, write: F) - where - F: FnOnce(&mut Self), - { - let len_offset = self.len(); - #[allow(clippy::unwrap_used)] // writing into Vec<> can't panic - self.write_u32::(0).unwrap(); - let data_offset = self.len(); - write(self); - let data_len = self.len() - data_offset; - #[allow(clippy::indexing_slicing)] // length is known - BigEndian::write_u32(&mut self[len_offset..], data_len as u32); - } -} - -impl Encoding for CryptoVec { - fn extend_ssh_string(&mut self, s: &[u8]) { - self.push_u32_be(s.len() as u32); - self.extend(s); - } - - #[allow(clippy::indexing_slicing)] // length is known - fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] { - self.push_u32_be(len as u32); - let current = self.len(); - self.resize(current + len); - &mut self[current..] - } - - #[allow(clippy::indexing_slicing)] // length is known - fn extend_ssh_mpint(&mut self, s: &[u8]) { - // Skip initial 0s. - let mut i = 0; - while i < s.len() && s[i] == 0 { - i += 1 - } - // If the first non-zero is >= 128, write its length (u32, BE), followed by 0. - if s[i] & 0x80 != 0 { - self.push_u32_be((s.len() - i + 1) as u32); - self.push(0) - } else { - self.push_u32_be((s.len() - i) as u32); - } - self.extend(&s[i..]); - } - - fn extend_list>(&mut self, list: I) { - let len0 = self.len(); - self.extend(&[0, 0, 0, 0]); - let mut first = true; - for i in list { - if !first { - self.push(b',') - } else { - first = false; - } - self.extend(i.bytes()) - } - let len = (self.len() - len0 - 4) as u32; - - #[allow(clippy::indexing_slicing)] // length is known - BigEndian::write_u32(&mut self[len0..], len); - } - - fn write_empty_list(&mut self) { - self.extend(&[0, 0, 0, 0]); - } - - fn extend_wrapped(&mut self, write: F) - where - F: FnOnce(&mut Self), - { - let len_offset = self.len(); - self.push_u32_be(0); - let data_offset = self.len(); - write(self); - let data_len = self.len() - data_offset; - #[allow(clippy::indexing_slicing)] // length is known - BigEndian::write_u32(&mut self[len_offset..], data_len as u32); - } -} - -/// A cursor-like trait to read SSH-encoded things. -pub trait Reader { - /// Create an SSH reader for `self`. - fn reader(&self, starting_at: usize) -> Position; -} - -impl Reader for CryptoVec { - fn reader(&self, starting_at: usize) -> Position { - Position { - s: self, - position: starting_at, - } - } -} - -impl Reader for [u8] { - fn reader(&self, starting_at: usize) -> Position { - Position { - s: self, - position: starting_at, - } - } -} - -/// A cursor-like type to read SSH-encoded values. -#[derive(Debug)] -pub struct Position<'a> { - s: &'a [u8], - #[doc(hidden)] - pub position: usize, -} -impl<'a> Position<'a> { - /// Read one string from this reader. - pub fn read_string(&mut self) -> Result<&'a [u8], Error> { - let len = self.read_u32()? as usize; - if self.position + len <= self.s.len() { - #[allow(clippy::indexing_slicing)] // length is known - let result = &self.s[self.position..(self.position + len)]; - self.position += len; - Ok(result) - } else { - Err(Error::IndexOutOfBounds) - } - } - /// Read a `u32` from this reader. - pub fn read_u32(&mut self) -> Result { - if self.position + 4 <= self.s.len() { - #[allow(clippy::indexing_slicing)] // length is known - let u = BigEndian::read_u32(&self.s[self.position..]); - self.position += 4; - Ok(u) - } else { - Err(Error::IndexOutOfBounds) - } - } - /// Read a `u64` from this reader by combining two `u32` values. - pub fn read_u64(&mut self) -> Result { - let high = self.read_u32()? as u64; - let low = self.read_u32()? as u64; - Ok((high << 32) | low) - } - /// Read one byte from this reader. - pub fn read_byte(&mut self) -> Result { - if self.position < self.s.len() { - #[allow(clippy::indexing_slicing)] // length is known - let u = self.s[self.position]; - self.position += 1; - Ok(u) - } else { - Err(Error::IndexOutOfBounds) - } - } - - /// Read one byte from this reader. - pub fn read_mpint(&mut self) -> Result<&'a [u8], Error> { - let len = self.read_u32()? as usize; - if self.position + len <= self.s.len() { - #[allow(clippy::indexing_slicing)] // length was checked - let result = &self.s[self.position..(self.position + len)]; - self.position += len; - Ok(result) - } else { - Err(Error::IndexOutOfBounds) - } - } - - pub fn read_ssh>(&mut self) -> Result { - T::read_ssh(self) - } -} - -/// Trait for reading value in SSH-encoded format. -pub trait SshRead<'a>: Sized + 'a { - /// Read the value from a position. - fn read_ssh(pos: &mut Position<'a>) -> Result; -} - -impl<'a> ssh_encoding::Reader for Position<'a> { - fn read<'o>(&mut self, out: &'o mut [u8]) -> ssh_encoding::Result<&'o [u8]> { - out.copy_from_slice( - self.s - .get(self.position..(self.position + out.len())) - .ok_or(ssh_encoding::Error::Length)?, - ); - self.position += out.len(); - Ok(out) - } - - fn remaining_len(&self) -> usize { - self.s.len() - self.position - } -} diff --git a/russh-keys/src/helpers.rs b/russh-keys/src/helpers.rs index 3f4befba..5be441c2 100644 --- a/russh-keys/src/helpers.rs +++ b/russh-keys/src/helpers.rs @@ -1,6 +1,7 @@ use ssh_encoding::Encode; -pub(crate) trait EncodedExt { +#[doc(hidden)] +pub trait EncodedExt { fn encoded(&self) -> ssh_key::Result>; } @@ -11,3 +12,32 @@ impl EncodedExt for E { Ok(buf) } } + +pub struct NameList(pub Vec); + +impl NameList { + pub fn as_encoded_string(&self) -> String { + self.0.join(",") + } +} + +impl Encode for NameList { + fn encoded_len(&self) -> Result { + self.as_encoded_string().encoded_len() + } + + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error> { + self.as_encoded_string().encode(writer) + } +} + +#[macro_export] +#[doc(hidden)] +#[allow(clippy::crate_in_macro_def)] +macro_rules! map_err { + ($result:expr) => { + $result.map_err(|e| crate::Error::from(e)) + }; +} + +pub use map_err; diff --git a/russh-keys/src/key.rs b/russh-keys/src/key.rs index e584f399..efc0c88a 100644 --- a/russh-keys/src/key.rs +++ b/russh-keys/src/key.rs @@ -16,7 +16,6 @@ use ssh_encoding::Decode; use ssh_key::public::KeyData; use ssh_key::{Algorithm, EcdsaCurve, HashAlg, PublicKey}; -use crate::encoding::Reader; use crate::Error; pub trait PublicKeyExt { @@ -24,8 +23,8 @@ pub trait PublicKeyExt { } impl PublicKeyExt for PublicKey { - fn decode(bytes: &[u8]) -> Result { - let key = KeyData::decode(&mut bytes.reader(0))?; + fn decode(mut bytes: &[u8]) -> Result { + let key = KeyData::decode(&mut bytes)?; Ok(PublicKey::new(key, "")) } } @@ -37,9 +36,9 @@ pub trait Verify { } /// Parse a public key from a byte slice. -pub fn parse_public_key(p: &[u8]) -> Result { +pub fn parse_public_key(mut p: &[u8]) -> Result { use ssh_encoding::Decode; - Ok(ssh_key::public::KeyData::decode(&mut p.reader(0))?.into()) + Ok(ssh_key::public::KeyData::decode(&mut p)?.into()) } /// Obtain a cryptographic-safe random number generator. diff --git a/russh-keys/src/lib.rs b/russh-keys/src/lib.rs index a006053c..63ae5c14 100644 --- a/russh-keys/src/lib.rs +++ b/russh-keys/src/lib.rs @@ -71,18 +71,14 @@ use std::string::FromUtf8Error; use aes::cipher::block_padding::UnpadError; use aes::cipher::inout::PadError; use data_encoding::BASE64_MIME; -use encoding::Encoding; use helpers::EncodedExt; -use russh_cryptovec::CryptoVec; -use signature::Signer; -use ssh_key::Signature; use thiserror::Error; -pub mod encoding; pub mod key; mod format; -mod helpers; +#[doc(hidden)] +pub mod helpers; pub use format::*; pub use ssh_key::{self, Algorithm, Certificate, EcdsaCurve, HashAlg, PrivateKey, PublicKey}; @@ -281,27 +277,6 @@ fn is_base64_char(c: char) -> bool { || c == '=' } -#[doc(hidden)] -pub fn add_signature>( - signer: &S, - to_sign: &[u8], - output: &mut CryptoVec, -) -> Result<(), ssh_key::Error> { - let sig = signer.sign(to_sign); - output.extend_ssh_string(sig.encoded()?.as_slice()); - Ok(()) -} - -#[doc(hidden)] -pub fn add_self_signature>( - signer: &S, - buffer: &mut CryptoVec, -) -> Result<(), ssh_key::Error> { - let sig = signer.sign(buffer); - buffer.extend_ssh_string(sig.encoded()?.as_slice()); - Ok(()) -} - #[cfg(test)] mod test { diff --git a/russh/Cargo.toml b/russh/Cargo.toml index cfa55211..7f7f6de7 100644 --- a/russh/Cargo.toml +++ b/russh/Cargo.toml @@ -23,9 +23,11 @@ cbc = { version = "0.1" } async-trait = { workspace = true } bitflags = "2.0" byteorder = { workspace = true } +bytes = { workspace = true } chacha20 = "0.9" ctr = "0.9" curve25519-dalek = "4.1.3" +delegate.workspace = true digest = { workspace = true } elliptic-curve = { version = "0.13", features = ["ecdh"] } flate2 = { version = "1.0", optional = true } @@ -42,7 +44,7 @@ p521 = { version = "0.13", features = ["ecdh"] } poly1305 = "0.8" rand = { workspace = true } rand_core = { version = "0.6.4", features = ["getrandom"] } -russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } +russh-cryptovec = { version = "0.8.0-beta.1", path = "../cryptovec" } russh-keys = { version = "0.46.0", path = "../russh-keys" } sha1 = { workspace = true } sha2 = { workspace = true } diff --git a/russh/examples/echoserver.rs b/russh/examples/echoserver.rs index 4506b49d..18d96718 100644 --- a/russh/examples/echoserver.rs +++ b/russh/examples/echoserver.rs @@ -86,9 +86,7 @@ impl server::Handler for Server { _: &str, _key: &ssh_key::PublicKey, ) -> Result { - Ok(server::Auth::Reject { - proceed_with_methods: None, - }) + Ok(server::Auth::Accept) } async fn auth_openssh_certificate( @@ -113,7 +111,7 @@ impl server::Handler for Server { let data = CryptoVec::from(format!("Got data: {}\r\n", String::from_utf8_lossy(data))); self.post(data.clone()).await; - session.data(channel, data); + session.data(channel, data)?; Ok(()) } diff --git a/russh/examples/ratatui_app.rs b/russh/examples/ratatui_app.rs index e223f20c..d5531fd7 100644 --- a/russh/examples/ratatui_app.rs +++ b/russh/examples/ratatui_app.rs @@ -167,7 +167,7 @@ impl Handler for AppServer { // Pressing 'q' closes the connection. b"q" => { self.clients.lock().await.remove(&self.id); - session.close(channel); + session.close(channel)?; } // Pressing 'c' resets the counter for the app. // Only the client with the id sees the counter reset. diff --git a/russh/examples/ratatui_shared_app.rs b/russh/examples/ratatui_shared_app.rs index b38cb412..eedaee5e 100644 --- a/russh/examples/ratatui_shared_app.rs +++ b/russh/examples/ratatui_shared_app.rs @@ -168,7 +168,7 @@ impl Handler for AppServer { // Pressing 'q' closes the connection. b"q" => { self.clients.lock().await.remove(&self.id); - session.close(channel); + session.close(channel)?; } // Pressing 'c' resets the counter for the app. // Every client sees the counter reset. diff --git a/russh/examples/sftp_server.rs b/russh/examples/sftp_server.rs index f4a00d4d..8b4ec473 100644 --- a/russh/examples/sftp_server.rs +++ b/russh/examples/sftp_server.rs @@ -78,7 +78,7 @@ impl russh::server::Handler for SshSession { ) -> Result<(), Self::Error> { // After a client has sent an EOF, indicating that they don't want // to send more data in this session, the channel can be closed. - session.close(channel); + session.close(channel)?; Ok(()) } @@ -93,10 +93,10 @@ impl russh::server::Handler for SshSession { if name == "sftp" { let channel = self.get_channel(channel_id).await; let sftp = SftpSession::default(); - session.channel_success(channel_id); + session.channel_success(channel_id)?; russh_sftp::server::run(channel.into_stream(), sftp).await; } else { - session.channel_failure(channel_id); + session.channel_failure(channel_id)?; } Ok(()) diff --git a/russh/examples/test.rs b/russh/examples/test.rs index 6089766e..08a037c1 100644 --- a/russh/examples/test.rs +++ b/russh/examples/test.rs @@ -92,7 +92,7 @@ impl server::Handler for Server { { let mut clients = self.clients.lock().unwrap(); for ((_, _channel_id), ref mut channel) in clients.iter_mut() { - session.data(channel.id(), CryptoVec::from(data.to_vec())); + session.data(channel.id(), CryptoVec::from(data.to_vec()))?; } } Ok(()) diff --git a/russh/src/auth.rs b/russh/src/auth.rs index 0c9a9f8b..b405f828 100644 --- a/russh/src/auth.rs +++ b/russh/src/auth.rs @@ -17,11 +17,11 @@ use std::sync::Arc; use async_trait::async_trait; use bitflags::bitflags; +use russh_keys::helpers::NameList; use ssh_key::{Certificate, PrivateKey}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::keys::encoding; use crate::CryptoVec; bitflags! { @@ -101,27 +101,39 @@ pub enum Method { // Hostbased, } -impl encoding::Bytes for MethodSet { - fn bytes(&self) -> &'static [u8] { - match *self { - MethodSet::NONE => b"none", - MethodSet::PASSWORD => b"password", - MethodSet::PUBLICKEY => b"publickey", - MethodSet::HOSTBASED => b"hostbased", - MethodSet::KEYBOARD_INTERACTIVE => b"keyboard-interactive", - _ => b"", +impl From for &'static str { + fn from(value: MethodSet) -> Self { + match value { + MethodSet::NONE => "none", + MethodSet::PASSWORD => "password", + MethodSet::PUBLICKEY => "publickey", + MethodSet::HOSTBASED => "hostbased", + MethodSet::KEYBOARD_INTERACTIVE => "keyboard-interactive", + _ => "", } } } +impl From for String { + fn from(value: MethodSet) -> Self { + <&str>::from(value).to_string() + } +} + +impl From for NameList { + fn from(value: MethodSet) -> Self { + Self(value.into_iter().map(|x| x.into()).collect()) + } +} + impl MethodSet { - pub(crate) fn from_bytes(b: &[u8]) -> Option { + pub(crate) fn from_str(b: &str) -> Option { match b { - b"none" => Some(MethodSet::NONE), - b"password" => Some(MethodSet::PASSWORD), - b"publickey" => Some(MethodSet::PUBLICKEY), - b"hostbased" => Some(MethodSet::HOSTBASED), - b"keyboard-interactive" => Some(MethodSet::KEYBOARD_INTERACTIVE), + "none" => Some(MethodSet::NONE), + "password" => Some(MethodSet::PASSWORD), + "publickey" => Some(MethodSet::PUBLICKEY), + "hostbased" => Some(MethodSet::HOSTBASED), + "keyboard-interactive" => Some(MethodSet::KEYBOARD_INTERACTIVE), _ => None, } } diff --git a/russh/src/cert.rs b/russh/src/cert.rs index 79f49ed1..1c7f0bd2 100644 --- a/russh/src/cert.rs +++ b/russh/src/cert.rs @@ -11,9 +11,9 @@ pub(crate) enum PublicKeyOrCertificate { } impl PublicKeyOrCertificate { - pub fn decode(pubkey_algo: &[u8], buf: &[u8]) -> Result { + pub fn decode(pubkey_algo: &str, buf: &[u8]) -> Result { let mut reader = buf; - match Algorithm::new_certificate_ext(str::from_utf8(pubkey_algo)?) { + match Algorithm::new_certificate_ext(pubkey_algo) { Ok(Algorithm::Other(_)) | Err(ssh_key::Error::Encoding(_)) => { // Did not match a known cert algorithm Ok(PublicKeyOrCertificate::PublicKey( diff --git a/russh/src/cipher/mod.rs b/russh/src/cipher/mod.rs index 76804e84..5dc08a47 100644 --- a/russh/src/cipher/mod.rs +++ b/russh/src/cipher/mod.rs @@ -25,9 +25,11 @@ use aes::{Aes128, Aes192, Aes256}; use byteorder::{BigEndian, ByteOrder}; use cbc::CbcWrapper; use ctr::Ctr128BE; +use delegate::delegate; use des::TdesEde3; use log::debug; use once_cell::sync::Lazy; +use ssh_encoding::Encode; use tokio::io::{AsyncRead, AsyncReadExt}; use crate::mac::MacAlgorithm; @@ -143,6 +145,13 @@ impl AsRef for Name { } } +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + impl Borrow for &Name { fn borrow(&self) -> &str { self.0 @@ -209,7 +218,8 @@ pub(crate) trait SealingKey { // Maximum packet length: // https://tools.ietf.org/html/rfc4253#section-6.1 assert!(packet_length <= u32::MAX as usize); - buffer.buffer.push_u32_be(packet_length as u32); + #[allow(clippy::unwrap_used)] // length checked + (packet_length as u32).encode(&mut buffer.buffer).unwrap(); assert!(padding_length <= u8::MAX as usize); buffer.buffer.push(padding_length as u8); diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index a95bdef2..489fc524 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -16,12 +16,13 @@ use std::cell::RefCell; use std::convert::TryInto; use std::num::Wrapping; +use bytes::Bytes; use log::{debug, error, info, trace, warn}; -use russh_keys::add_self_signature; +use russh_keys::helpers::{map_err, EncodedExt}; +use ssh_encoding::{Decode, Encode}; use crate::cert::PublicKeyOrCertificate; use crate::client::{Handler, Msg, Prompt, Reply, Session}; -use crate::keys::encoding::{Encoding, Reader}; use crate::keys::key::parse_public_key; use crate::negotiation::{Named, Select}; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; @@ -103,7 +104,12 @@ impl Session { Ok(()) } else if buf.first() == Some(&msg::KEX_ECDH_REPLY) { // We've sent ECDH_INIT, waiting for ECDH_REPLY - let kex = kexdhdone.server_key_check(true, client, buf).await?; + + #[allow(clippy::indexing_slicing)] // length checked + let kex = kexdhdone + .server_key_check(true, client, &mut &buf[1..]) + .await?; + enc.rekey = Some(Kex::Keys(kex)); self.common .cipher @@ -125,7 +131,7 @@ impl Session { enc.last_rekey = russh_util::time::Instant::now(); // Ok, NEWKEYS received, now encrypted. - enc.flush_all_pending(); + enc.flush_all_pending()?; let mut pending = std::mem::take(&mut self.pending_reads); for p in pending.drain(..) { self.process_packet(client, &p).await?; @@ -175,206 +181,206 @@ impl Session { buf.first(), msg::SERVICE_ACCEPT ); - if buf.first() == Some(&msg::SERVICE_ACCEPT) { - let mut r = buf.reader(1); - if r.read_string().map_err(crate::Error::from)? == b"ssh-userauth" { - *accepted = true; - if let Some(ref meth) = self.common.auth_method { - let auth_request = match meth { - crate::auth::Method::KeyboardInteractive { submethods } => { - auth::AuthRequest { + match buf.split_first() { + Some((&msg::SERVICE_ACCEPT, mut r)) => { + if map_err!(Bytes::decode(&mut r))?.as_ref() == b"ssh-userauth" { + *accepted = true; + if let Some(ref meth) = self.common.auth_method { + let auth_request = match meth { + crate::auth::Method::KeyboardInteractive { submethods } => { + auth::AuthRequest { + methods: auth::MethodSet::all(), + partial_success: false, + current: Some( + auth::CurrentRequest::KeyboardInteractive { + submethods: submethods.to_string(), + }, + ), + rejection_count: 0, + } + } + _ => auth::AuthRequest { methods: auth::MethodSet::all(), partial_success: false, - current: Some( - auth::CurrentRequest::KeyboardInteractive { - submethods: submethods.to_string(), - }, - ), + current: None, rejection_count: 0, - } + }, + }; + let len = enc.write.len(); + #[allow(clippy::indexing_slicing)] // length checked + if enc.write_auth_request(&self.common.auth_user, meth)? { + debug!("enc: {:?}", &enc.write[len..]); + enc.state = EncryptedState::WaitingAuthRequest(auth_request) } - _ => auth::AuthRequest { - methods: auth::MethodSet::all(), - partial_success: false, - current: None, - rejection_count: 0, - }, - }; - let len = enc.write.len(); - #[allow(clippy::indexing_slicing)] // length checked - if enc.write_auth_request(&self.common.auth_user, meth)? { - debug!("enc: {:?}", &enc.write[len..]); - enc.state = EncryptedState::WaitingAuthRequest(auth_request) + } else { + debug!("no auth method") } - } else { - debug!("no auth method") } } - } else if buf.first() == Some(&msg::EXT_INFO) { - return self.handle_ext_info(client, buf); - } else { - debug!("unknown message: {:?}", buf); - return Err(crate::Error::Inconsistent.into()); + Some((&msg::EXT_INFO, r)) => { + return self.handle_ext_info(client, r); + } + other => { + debug!("unknown message: {other:?}"); + return Err(crate::Error::Inconsistent.into()); + } } } EncryptedState::WaitingAuthRequest(ref mut auth_request) => { - if buf.first() == Some(&msg::USERAUTH_SUCCESS) { - debug!("userauth_success"); - self.sender - .send(Reply::AuthSuccess) - .map_err(|_| crate::Error::SendError)?; - enc.state = EncryptedState::InitCompression; - enc.server_compression.init_decompress(&mut enc.decompress); - return Ok(()); - } else if buf.first() == Some(&msg::USERAUTH_BANNER) { - let mut r = buf.reader(1); - let banner = r.read_string().map_err(crate::Error::from)?; - return if let Ok(banner) = std::str::from_utf8(banner) { - client.auth_banner(banner, self).await - } else { - Ok(()) - }; - } else if buf.first() == Some(&msg::USERAUTH_FAILURE) { - debug!("userauth_failure"); - - let mut r = buf.reader(1); - let remaining_methods = r.read_string().map_err(crate::Error::from)?; - debug!( - "remaining methods {:?}", - std::str::from_utf8(remaining_methods) - ); - auth_request.methods = auth::MethodSet::empty(); - for method in remaining_methods.split(|&c| c == b',') { - if let Some(m) = auth::MethodSet::from_bytes(method) { - auth_request.methods |= m - } + match buf.split_first() { + Some((&msg::USERAUTH_SUCCESS, _)) => { + debug!("userauth_success"); + self.sender + .send(Reply::AuthSuccess) + .map_err(|_| crate::Error::SendError)?; + enc.state = EncryptedState::InitCompression; + enc.server_compression.init_decompress(&mut enc.decompress); + return Ok(()); } - let no_more_methods = auth_request.methods.is_empty(); - self.common.auth_method = None; - self.sender - .send(Reply::AuthFailure) - .map_err(|_| crate::Error::SendError)?; - - // If no other authentication method is allowed by the server, give up. - if no_more_methods { - return Err(crate::Error::NoAuthMethod.into()); + Some((&msg::USERAUTH_BANNER, mut r)) => { + let banner = map_err!(String::decode(&mut r))?; + client.auth_banner(&banner, self).await?; + return Ok(()); } - } else if buf.first() == Some(&msg::USERAUTH_INFO_REQUEST_OR_USERAUTH_PK_OK) { - if let Some(auth::CurrentRequest::PublicKey { - ref mut sent_pk_ok, .. - }) = auth_request.current - { - debug!("userauth_pk_ok"); - *sent_pk_ok = true; - } else if let Some(auth::CurrentRequest::KeyboardInteractive { .. }) = - auth_request.current - { - debug!("keyboard_interactive"); - let mut r = buf.reader(1); + Some((&msg::USERAUTH_FAILURE, mut r)) => { + debug!("userauth_failure"); - // read fields - let name = String::from_utf8_lossy( - r.read_string().map_err(crate::Error::from)?, - ) - .to_string(); + let remaining_methods = map_err!(String::decode(&mut r))?; + debug!("remaining methods {remaining_methods:?}",); + auth_request.methods = auth::MethodSet::empty(); + for method in remaining_methods.split(',') { + if let Some(m) = auth::MethodSet::from_str(method) { + auth_request.methods |= m + } + } + let no_more_methods = auth_request.methods.is_empty(); + self.common.auth_method = None; + self.sender + .send(Reply::AuthFailure) + .map_err(|_| crate::Error::SendError)?; - let instructions = String::from_utf8_lossy( - r.read_string().map_err(crate::Error::from)?, - ) - .to_string(); + // If no other authentication method is allowed by the server, give up. + if no_more_methods { + return Err(crate::Error::NoAuthMethod.into()); + } + } + Some((&msg::USERAUTH_INFO_REQUEST_OR_USERAUTH_PK_OK, mut r)) => { + if let Some(auth::CurrentRequest::PublicKey { + ref mut sent_pk_ok, + .. + }) = auth_request.current + { + debug!("userauth_pk_ok"); + *sent_pk_ok = true; + } else if let Some(auth::CurrentRequest::KeyboardInteractive { + .. + }) = auth_request.current + { + debug!("keyboard_interactive"); - let _lang = r.read_string().map_err(crate::Error::from)?; - let n_prompts = r.read_u32().map_err(crate::Error::from)?; + // read fields + let name = map_err!(String::decode(&mut r))?; - // read prompts - let mut prompts = Vec::with_capacity(n_prompts.try_into().unwrap_or(0)); - for _i in 0..n_prompts { - let prompt = String::from_utf8_lossy( - r.read_string().map_err(crate::Error::from)?, - ); + let instructions = map_err!(String::decode(&mut r))?; - let echo = r.read_byte().map_err(crate::Error::from)? != 0; - prompts.push(Prompt { - prompt: prompt.to_string(), - echo, - }); - } + let _lang = map_err!(String::decode(&mut r))?; + let n_prompts = map_err!(u32::decode(&mut r))?; - // send challenges to caller - self.sender - .send(Reply::AuthInfoRequest { - name, - instructions, - prompts, - }) - .map_err(|_| crate::Error::SendError)?; + // read prompts + let mut prompts = + Vec::with_capacity(n_prompts.try_into().unwrap_or(0)); + for _i in 0..n_prompts { + let prompt = map_err!(String::decode(&mut r))?; - // wait for response from handler - let responses = loop { - match self.receiver.recv().await { - Some(Msg::AuthInfoResponse { responses }) => break responses, - _ => {} + let echo = map_err!(u8::decode(&mut r))? != 0; + prompts.push(Prompt { + prompt: prompt.to_string(), + echo, + }); } - }; - // write responses - enc.client_send_auth_response(&responses)?; - return Ok(()); - } - - // continue with userauth_pk_ok - match self.common.auth_method.take() { - Some(auth_method @ auth::Method::PublicKey { .. }) => { - self.common.buffer.clear(); - enc.client_send_signature( - &self.common.auth_user, - &auth_method, - &mut self.common.buffer, - )? - } - Some(auth_method @ auth::Method::OpenSshCertificate { .. }) => { - self.common.buffer.clear(); - enc.client_send_signature( - &self.common.auth_user, - &auth_method, - &mut self.common.buffer, - )? - } - Some(auth::Method::FuturePublicKey { key }) => { - debug!("public key"); - self.common.buffer.clear(); - let i = enc.client_make_to_sign( - &self.common.auth_user, - &PublicKeyOrCertificate::PublicKey(key.clone()), - &mut self.common.buffer, - )?; - let len = self.common.buffer.len(); - let buf = - std::mem::replace(&mut self.common.buffer, CryptoVec::new()); + // send challenges to caller self.sender - .send(Reply::SignRequest { key, data: buf }) + .send(Reply::AuthInfoRequest { + name, + instructions, + prompts, + }) .map_err(|_| crate::Error::SendError)?; - self.common.buffer = loop { + + // wait for response from handler + let responses = loop { match self.receiver.recv().await { - Some(Msg::Signed { data }) => break data, + Some(Msg::AuthInfoResponse { responses }) => { + break responses + } _ => {} } }; - if self.common.buffer.len() != len { - // The buffer was modified. - push_packet!(enc.write, { - #[allow(clippy::indexing_slicing)] // length checked - enc.write.extend(&self.common.buffer[i..]); - }) + // write responses + enc.client_send_auth_response(&responses)?; + return Ok(()); + } + + // continue with userauth_pk_ok + match self.common.auth_method.take() { + Some(auth_method @ auth::Method::PublicKey { .. }) => { + self.common.buffer.clear(); + enc.client_send_signature( + &self.common.auth_user, + &auth_method, + &mut self.common.buffer, + )? + } + Some(auth_method @ auth::Method::OpenSshCertificate { .. }) => { + self.common.buffer.clear(); + enc.client_send_signature( + &self.common.auth_user, + &auth_method, + &mut self.common.buffer, + )? } + Some(auth::Method::FuturePublicKey { key }) => { + debug!("public key"); + self.common.buffer.clear(); + let i = enc.client_make_to_sign( + &self.common.auth_user, + &PublicKeyOrCertificate::PublicKey(key.clone()), + &mut self.common.buffer, + )?; + let len = self.common.buffer.len(); + let buf = std::mem::replace( + &mut self.common.buffer, + CryptoVec::new(), + ); + + self.sender + .send(Reply::SignRequest { key, data: buf }) + .map_err(|_| crate::Error::SendError)?; + self.common.buffer = loop { + match self.receiver.recv().await { + Some(Msg::Signed { data }) => break data, + _ => {} + } + }; + if self.common.buffer.len() != len { + // The buffer was modified. + push_packet!(enc.write, { + #[allow(clippy::indexing_slicing)] // length checked + enc.write.extend(&self.common.buffer[i..]); + }) + } + } + _ => {} } - _ => {} } - } else if buf.first() == Some(&msg::EXT_INFO) { - return self.handle_ext_info(client, buf); - } else { - debug!("unknown message: {:?}", buf); - return Err(crate::Error::Inconsistent.into()); + Some((&msg::EXT_INFO, r)) => { + return self.handle_ext_info(client, r); + } + other => { + debug!("unknown message: {other:?}"); + return Err(crate::Error::Inconsistent.into()); + } } } EncryptedState::InitCompression => unreachable!(), @@ -388,8 +394,8 @@ impl Session { } } - fn handle_ext_info(&mut self, _client: &mut H, buf: &[u8]) -> Result<(), H::Error> { - debug!("Received EXT_INFO: {:?}", buf); + fn handle_ext_info(&mut self, _client: &mut H, r: &[u8]) -> Result<(), H::Error> { + debug!("Received EXT_INFO: {:?}", r); Ok(()) } @@ -398,11 +404,10 @@ impl Session { client: &mut H, buf: &[u8], ) -> Result<(), H::Error> { - match buf.first() { - Some(&msg::CHANNEL_OPEN_CONFIRMATION) => { + match buf.split_first() { + Some((&msg::CHANNEL_OPEN_CONFIRMATION, mut reader)) => { debug!("channel_open_confirmation"); - let mut reader = buf.reader(1); - let msg = ChannelOpenConfirmation::parse(&mut reader)?; + let msg = map_err!(ChannelOpenConfirmation::decode(&mut reader))?; let local_id = ChannelId(msg.recipient_channel); if let Some(ref mut enc) = self.common.encrypted { @@ -437,38 +442,32 @@ impl Session { ) .await } - Some(&msg::CHANNEL_CLOSE) => { + Some((&msg::CHANNEL_CLOSE, mut r)) => { debug!("channel_close"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(ref mut enc) = self.common.encrypted { // The CHANNEL_CLOSE message must be sent to the server at this point or the session // will not be released. - enc.close(channel_num); + enc.close(channel_num)?; } self.channels.remove(&channel_num); client.channel_close(channel_num, self).await } - Some(&msg::CHANNEL_EOF) => { + Some((&msg::CHANNEL_EOF, mut r)) => { debug!("channel_eof"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::Eof); } client.channel_eof(channel_num, self).await } - Some(&msg::CHANNEL_OPEN_FAILURE) => { + Some((&msg::CHANNEL_OPEN_FAILURE, mut r)) => { debug!("channel_open_failure"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let reason_code = - ChannelOpenFailure::from_u32(r.read_u32().map_err(crate::Error::from)?) - .unwrap_or(ChannelOpenFailure::Unknown); - let descr = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let language = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let reason_code = ChannelOpenFailure::from_u32(map_err!(u32::decode(&mut r))?) + .unwrap_or(ChannelOpenFailure::Unknown); + let descr = map_err!(String::decode(&mut r))?; + let language = map_err!(String::decode(&mut r))?; if let Some(ref mut enc) = self.common.encrypted { enc.channels.remove(&channel_num); } @@ -480,17 +479,16 @@ impl Session { let _ = self.sender.send(Reply::ChannelOpenFailure); client - .channel_open_failure(channel_num, reason_code, descr, language, self) + .channel_open_failure(channel_num, reason_code, &descr, &language, self) .await } - Some(&msg::CHANNEL_DATA) => { + Some((&msg::CHANNEL_DATA, mut r)) => { trace!("channel_data"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let data = r.read_string().map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let data = map_err!(Bytes::decode(&mut r))?; let target = self.common.config.window_size; if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, data, target) { + if enc.adjust_window_size(channel_num, &data, target)? { let next_window = client.adjust_window(channel_num, self.target_window_size); if next_window > 0 { @@ -501,21 +499,20 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::Data { - data: CryptoVec::from_slice(data), + data: CryptoVec::from_slice(&data), }); } - client.data(channel_num, data, self).await + client.data(channel_num, &data, self).await } - Some(&msg::CHANNEL_EXTENDED_DATA) => { + Some((&msg::CHANNEL_EXTENDED_DATA, mut r)) => { debug!("channel_extended_data"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let extended_code = r.read_u32().map_err(crate::Error::from)?; - let data = r.read_string().map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let extended_code = map_err!(u32::decode(&mut r))?; + let data = map_err!(Bytes::decode(&mut r))?; let target = self.common.config.window_size; if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, data, target) { + if enc.adjust_window_size(channel_num, &data, target)? { let next_window = client.adjust_window(channel_num, self.target_window_size); if next_window > 0 { @@ -527,51 +524,42 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::ExtendedData { ext: extended_code, - data: CryptoVec::from_slice(data), + data: CryptoVec::from_slice(&data), }); } client - .extended_data(channel_num, extended_code, data, self) + .extended_data(channel_num, extended_code, &data, self) .await } - Some(&msg::CHANNEL_REQUEST) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let req = r.read_string().map_err(crate::Error::from)?; - debug!( - "channel_request: {:?} {:?}", - channel_num, - std::str::from_utf8(req) - ); - match req { - b"xon-xoff" => { - r.read_byte().map_err(crate::Error::from)?; // should be 0. - let client_can_do = r.read_byte().map_err(crate::Error::from)? != 0; + Some((&msg::CHANNEL_REQUEST, mut r)) => { + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let req = map_err!(String::decode(&mut r))?; + debug!("channel_request: {channel_num:?} {req:?}",); + match req.as_str() { + "xon-xoff" => { + map_err!(u8::decode(&mut r))?; // should be 0. + let client_can_do = map_err!(u8::decode(&mut r))? != 0; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::XonXoff { client_can_do }); } client.xon_xoff(channel_num, client_can_do, self).await } - b"exit-status" => { - r.read_byte().map_err(crate::Error::from)?; // should be 0. - let exit_status = r.read_u32().map_err(crate::Error::from)?; + "exit-status" => { + map_err!(u8::decode(&mut r))?; // should be 0. + let exit_status = map_err!(u32::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::ExitStatus { exit_status }); } client.exit_status(channel_num, exit_status, self).await } - b"exit-signal" => { - r.read_byte().map_err(crate::Error::from)?; // should be 0. + "exit-signal" => { + map_err!(u8::decode(&mut r))?; // should be 0. let signal_name = - Sig::from_name(r.read_string().map_err(crate::Error::from)?)?; - let core_dumped = r.read_byte().map_err(crate::Error::from)? != 0; - let error_message = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let lang_tag = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + Sig::from_name(map_err!(String::decode(&mut r))?.as_str()); + let core_dumped = map_err!(u8::decode(&mut r))? != 0; + let error_message = map_err!(String::decode(&mut r))?; + let lang_tag = map_err!(String::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::ExitSignal { signal_name: signal_name.clone(), @@ -585,24 +573,21 @@ impl Session { channel_num, signal_name, core_dumped, - error_message, - lang_tag, + &error_message, + &lang_tag, self, ) .await } - b"keepalive@openssh.com" => { - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + "keepalive@openssh.com" => { + let wants_reply = map_err!(u8::decode(&mut r))?; if wants_reply == 1 { if let Some(ref mut enc) = self.common.encrypted { - trace!( - "Received channel keep alive message: {:?}", - std::str::from_utf8(req), - ); + trace!("Received channel keep alive message: {req:?}",); self.common.wants_reply = false; push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_SUCCESS); - enc.write.push_u32_be(channel_num.0) + map_err!(msg::CHANNEL_SUCCESS.encode(&mut enc.write))?; + map_err!(channel_num.encode(&mut enc.write))?; }); } } else { @@ -611,30 +596,25 @@ impl Session { Ok(()) } _ => { - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + let wants_reply = map_err!(u8::decode(&mut r))?; if wants_reply == 1 { if let Some(ref mut enc) = self.common.encrypted { self.common.wants_reply = false; push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_FAILURE); - enc.write.push_u32_be(channel_num.0) + map_err!(msg::CHANNEL_FAILURE.encode(&mut enc.write))?; + map_err!(channel_num.encode(&mut enc.write))?; }) } } - info!( - "Unknown channel request {:?} {:?}", - std::str::from_utf8(req), - wants_reply - ); + info!("Unknown channel request {req:?} {wants_reply:?}",); Ok(()) } } } - Some(&msg::CHANNEL_WINDOW_ADJUST) => { + Some((&msg::CHANNEL_WINDOW_ADJUST, mut r)) => { debug!("channel_window_adjust"); - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let amount = r.read_u32().map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(&mut r))?; + let amount = map_err!(u32::decode(&mut r))?; let mut new_size = 0; debug!("amount: {:?}", amount); if let Some(ref mut enc) = self.common.encrypted { @@ -647,7 +627,7 @@ impl Session { } if let Some(ref mut enc) = self.common.encrypted { - new_size -= enc.flush_pending(channel_num) as u32; + new_size -= enc.flush_pending(channel_num)? as u32; } if let Some(chan) = self.channels.get(&channel_num) { *chan.window_size().lock().await = new_size; @@ -656,52 +636,42 @@ impl Session { } client.window_adjusted(channel_num, new_size, self).await } - Some(&msg::GLOBAL_REQUEST) => { - let mut r = buf.reader(1); - let req = r.read_string().map_err(crate::Error::from)?; - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + Some((&msg::GLOBAL_REQUEST, mut r)) => { + let req = map_err!(String::decode(&mut r))?; + let wants_reply = map_err!(u8::decode(&mut r))?; if let Some(ref mut enc) = self.common.encrypted { - if req.starts_with(b"keepalive") { + if req.starts_with("keepalive") { if wants_reply == 1 { - trace!( - "Received keep alive message: {:?}", - std::str::from_utf8(req), - ); + trace!("Received keep alive message: {req:?}",); self.common.wants_reply = false; push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)); } else { warn!("Received keepalive without reply request!"); } - } else if req == b"hostkeys-00@openssh.com" { + } else if req == "hostkeys-00@openssh.com" { let mut keys = vec![]; loop { - match r.read_string() { + match Bytes::decode(&mut r) { Ok(key) => { - let key2 = <&[u8]>::clone(&key); - let key = parse_public_key(key).map_err(crate::Error::from); + let key = map_err!(parse_public_key(&key)); match key { Ok(key) => keys.push(key), - Err(err) => { + Err(ref err) => { debug!( - "failed to parse announced host key {:?}: {:?}", - key2, err + "failed to parse announced host key {key:?}: {err:?}", ) } } } - Err(russh_keys::Error::IndexOutOfBounds) => break, + Err(ssh_encoding::Error::Length) => break, x => { - x.map_err(crate::Error::from)?; + map_err!(x)?; } } } return client.openssh_ext_host_keys_announced(keys, self).await; } else { - warn!( - "Unhandled global request: {:?} {:?}", - std::str::from_utf8(req), - wants_reply - ); + warn!("Unhandled global request: {req:?} {wants_reply:?}",); self.common.wants_reply = false; push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE)) } @@ -709,24 +679,21 @@ impl Session { self.common.received_data = false; Ok(()) } - Some(&msg::CHANNEL_SUCCESS) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + Some((&msg::CHANNEL_SUCCESS, mut r)) => { + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::Success); } client.channel_success(channel_num, self).await } - Some(&msg::CHANNEL_FAILURE) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + Some((&msg::CHANNEL_FAILURE, mut r)) => { + let channel_num = map_err!(ChannelId::decode(&mut r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::Failure); } client.channel_failure(channel_num, self).await } - Some(&msg::CHANNEL_OPEN) => { - let mut r = buf.reader(1); + Some((&msg::CHANNEL_OPEN, mut r)) => { let msg = OpenChannelMessage::parse(&mut r)?; if let Some(ref mut enc) = self.common.encrypted { @@ -747,23 +714,24 @@ impl Session { let confirm = || { debug!("confirming channel: {:?}", msg); - msg.confirm( + map_err!(msg.confirm( &mut enc.write, id.0, channel.sender_window_size, channel.sender_maximum_packet_size, - ); + ))?; enc.channels.insert(id, channel); + Ok(()) }; match &msg.typ { ChannelType::Session => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client.server_channel_open_session(channel, self).await? } ChannelType::DirectTcpip(d) => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_direct_tcpip( @@ -780,7 +748,7 @@ impl Session { originator_address, originator_port, } => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_x11( @@ -792,7 +760,7 @@ impl Session { .await? } ChannelType::ForwardedTcpIp(d) => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_forwarded_tcpip( @@ -806,7 +774,7 @@ impl Session { .await? } ChannelType::ForwardedStreamLocal(d) => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_forwarded_streamlocal( @@ -817,7 +785,7 @@ impl Session { .await?; } ChannelType::AgentForward => { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client .server_channel_open_agent_forward(channel, self) @@ -825,12 +793,12 @@ impl Session { } ChannelType::Unknown { typ } => { if client.should_accept_unknown_server_channel(id, typ).await { - confirm(); + confirm()?; let channel = self.accept_server_initiated_channel(id, &msg); client.server_channel_open_unknown(channel, self).await?; } else { - debug!("unknown channel type: {}", String::from_utf8_lossy(typ)); - msg.unknown_type(&mut enc.write); + debug!("unknown channel type: {typ}"); + msg.unknown_type(&mut enc.write)?; } } }; @@ -839,19 +807,18 @@ impl Session { Err(crate::Error::Inconsistent.into()) } } - Some(&msg::REQUEST_SUCCESS) => { + Some((&msg::REQUEST_SUCCESS, mut r)) => { trace!("Global Request Success"); match self.open_global_requests.pop_front() { Some(GlobalRequestResponse::Keepalive) => { // ignore keepalives } Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { - let result = if buf.len() == 1 { + let result = if r.is_empty() { // If a specific port was requested, the reply has no data Some(0) } else { - let mut r = buf.reader(1); - match r.read_u32() { + match u32::decode(&mut r) { Ok(port) => Some(port), Err(e) => { error!("Error parsing port for TcpIpForward request: {e:?}"); @@ -876,7 +843,7 @@ impl Session { } Ok(()) } - Some(&msg::REQUEST_FAILURE) => { + Some((&msg::REQUEST_FAILURE, _)) => { trace!("global request failure"); match self.open_global_requests.pop_front() { Some(GlobalRequestResponse::Keepalive) => { @@ -977,63 +944,61 @@ impl Encrypted { match *auth_method { auth::Method::None => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"none"); + user.encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "none".encode(&mut self.write)?; true } auth::Method::Password { ref password } => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"password"); - self.write.push(0); - self.write.extend_ssh_string(password.as_bytes()); + user.encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "password".encode(&mut self.write)?; + 0u8.encode(&mut self.write)?; + password.encode(&mut self.write)?; true } auth::Method::PublicKey { ref key } => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"publickey"); + user.encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "publickey".encode(&mut self.write)?; self.write.push(0); // This is a probe debug!("write_auth_request: key - {:?}", key.algorithm()); - self.write - .extend_ssh_string(key.algorithm().as_str().as_bytes()); - self.write - .extend_ssh_string(key.public_key().to_bytes()?.as_slice()); + key.algorithm().as_str().encode(&mut self.write)?; + key.public_key().to_bytes()?.encode(&mut self.write)?; true } auth::Method::OpenSshCertificate { ref cert, .. } => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"publickey"); + user.as_bytes().encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "publickey".encode(&mut self.write)?; self.write.push(0); // This is a probe debug!("write_auth_request: cert - {:?}", cert.algorithm()); - self.write - .extend_ssh_string(cert.algorithm().to_certificate_type().as_bytes()); - self.write.extend_ssh_string(cert.to_bytes()?.as_slice()); + cert.algorithm() + .to_certificate_type() + .encode(&mut self.write)?; + cert.to_bytes()?.as_slice().encode(&mut self.write)?; true } auth::Method::FuturePublicKey { ref key, .. } => { - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"publickey"); + user.as_bytes().encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "publickey".encode(&mut self.write)?; self.write.push(0); // This is a probe - self.write - .extend_ssh_string(key.algorithm().as_str().as_bytes()); + key.algorithm().as_str().encode(&mut self.write)?; - self.write.extend_ssh_string(key.to_bytes()?.as_slice()); + key.to_bytes()?.as_slice().encode(&mut self.write)?; true } auth::Method::KeyboardInteractive { ref submethods } => { debug!("Keyboard Iinteractive"); - self.write.extend_ssh_string(user.as_bytes()); - self.write.extend_ssh_string(b"ssh-connection"); - self.write.extend_ssh_string(b"keyboard-interactive"); - self.write.extend_ssh_string(b""); // lang tag is deprecated. Should be empty - self.write.extend_ssh_string(submethods.as_bytes()); + user.as_bytes().encode(&mut self.write)?; + "ssh-connection".encode(&mut self.write)?; + "keyboard-interactive".encode(&mut self.write)?; + "".encode(&mut self.write)?; // lang tag is deprecated. Should be empty + submethods.as_bytes().encode(&mut self.write)?; true } } @@ -1047,23 +1012,23 @@ impl Encrypted { buffer: &mut CryptoVec, ) -> Result { buffer.clear(); - buffer.extend_ssh_string(self.session_id.as_ref()); + self.session_id.as_ref().encode(buffer)?; let i0 = buffer.len(); buffer.push(msg::USERAUTH_REQUEST); - buffer.extend_ssh_string(user.as_bytes()); - buffer.extend_ssh_string(b"ssh-connection"); - buffer.extend_ssh_string(b"publickey"); - buffer.push(1); + user.encode(buffer)?; + "ssh-connection".encode(buffer)?; + "publickey".encode(buffer)?; + 1u8.encode(buffer)?; match key { PublicKeyOrCertificate::Certificate(cert) => { - buffer.extend_ssh_string(cert.name().as_ref().as_bytes()); - buffer.extend_ssh_string(cert.to_bytes()?.as_slice()); + cert.name().as_ref().encode(buffer)?; + cert.to_bytes()?.encode(buffer)?; } PublicKeyOrCertificate::PublicKey(key) => { - buffer.extend_ssh_string(key.name().as_ref().as_bytes()); - buffer.extend_ssh_string(key.to_bytes()?.as_slice()); + key.name().as_ref().encode(buffer)?; + key.to_bytes()?.encode(buffer)?; } } Ok(i0) @@ -1082,9 +1047,10 @@ impl Encrypted { &PublicKeyOrCertificate::PublicKey(key.public_key().clone()), buffer, )?; - // Extend with self-signature. - add_self_signature(&**key, buffer)?; + // Extend with self-signature. + let signature = signature::Signer::try_sign(&**key, buffer)?; + signature.encoded()?.encode(&mut *buffer)?; push_packet!(self.write, { #[allow(clippy::indexing_slicing)] // length checked @@ -1097,8 +1063,10 @@ impl Encrypted { &PublicKeyOrCertificate::Certificate(cert.clone()), buffer, )?; + // Extend with self-signature. - add_self_signature(&**key, buffer)?; + let signature = signature::Signer::try_sign(&**key, buffer)?; + signature.encoded()?.encode(&mut *buffer)?; push_packet!(self.write, { #[allow(clippy::indexing_slicing)] // length checked @@ -1112,12 +1080,11 @@ impl Encrypted { fn client_send_auth_response(&mut self, responses: &[String]) -> Result<(), crate::Error> { push_packet!(self.write, { - self.write.push(msg::USERAUTH_INFO_RESPONSE); - self.write - .push_u32_be(responses.len().try_into().unwrap_or(0)); // number of responses + msg::USERAUTH_INFO_RESPONSE.encode(&mut self.write)?; + (responses.len().try_into().unwrap_or(0) as u32).encode(&mut self.write)?; // number of responses for r in responses { - self.write.extend_ssh_string(r.as_bytes()); // write the reponses + r.encode(&mut self.write)?; // write the reponses } }); Ok(()) diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index 8988ae72..07b942e9 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -42,11 +42,13 @@ use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; +use bytes::Bytes; use futures::task::{Context, Poll}; use futures::Future; use log::{debug, error, info, trace}; -use russh_keys::encoding::Encoding; +use russh_keys::map_err; use signature::Verifier; +use ssh_encoding::{Decode, Encode, Reader}; use ssh_key::{Certificate, PrivateKey, PublicKey, Signature}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::pin; @@ -57,7 +59,6 @@ use tokio::sync::{oneshot, Mutex}; use crate::channels::{Channel, ChannelMsg, ChannelRef}; use crate::cipher::{self, clear, CipherPair, OpeningKey}; -use crate::keys::encoding::Reader; use crate::keys::key::parse_public_key; use crate::session::{ CommonSession, EncryptedState, Exchange, GlobalRequestResponse, Kex, KexDhDone, KexInit, @@ -713,9 +714,9 @@ pub async fn connect( addrs: A, handler: H, ) -> Result, H::Error> { - let socket = tokio::net::TcpStream::connect(addrs) - .await - .map_err(crate::Error::from)?; + use russh_keys::map_err; + + let socket = map_err!(tokio::net::TcpStream::connect(addrs).await)?; connect_stream(config, socket, handler).await } @@ -735,10 +736,7 @@ where // Writing SSH id. let mut write_buffer = SSHBuffer::new(); write_buffer.send_ssh_id(&config.as_ref().client_id); - stream - .write_all(&write_buffer.buffer) - .await - .map_err(crate::Error::from)?; + map_err!(stream.write_all(&write_buffer.buffer).await)?; // Reading SSH id and allocating a session if correct. let mut stream = SshRead::new(stream); @@ -844,7 +842,7 @@ impl Session { trace!("disconnected"); self.receiver.close(); self.inbound_channel_receiver.close(); - stream_write.shutdown().await.map_err(crate::Error::from)?; + map_err!(stream_write.shutdown().await)?; match result { Ok(v) => { handler @@ -881,11 +879,12 @@ impl Session { self.flush()?; if !self.common.write_buffer.buffer.is_empty() { debug!("writing {:?} bytes", self.common.write_buffer.buffer.len()); - stream_write - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; - stream_write.flush().await.map_err(crate::Error::from)?; + map_err!( + stream_write + .write_all(&self.common.write_buffer.buffer) + .await + )?; + map_err!(stream_write.flush().await)?; } self.common.write_buffer.buffer.clear(); let mut decomp = CryptoVec::new(); @@ -941,7 +940,7 @@ impl Session { if !buf.is_empty() { #[allow(clippy::indexing_slicing)] // length checked if buf[0] == crate::msg::DISCONNECT { - result = self.process_disconnect(buf); + result = self.process_disconnect(&buf[1..]); } else { self.common.received_data = true; reply(self, handler, kex_done_signal, &mut buffer.seqn, buf).await?; @@ -957,7 +956,7 @@ impl Session { return Err(crate::Error::KeepaliveTimeout.into()); } self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); - self.send_keepalive(true); + self.send_keepalive(true)?; sent_keepalive = true; } () = &mut inactivity_timer => { @@ -1003,11 +1002,12 @@ impl Session { "writing to stream: {:?} bytes", self.common.write_buffer.buffer.len() ); - stream_write - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; - stream_write.flush().await.map_err(crate::Error::from)?; + map_err!( + stream_write + .write_all(&self.common.write_buffer.buffer) + .await + )?; + map_err!(stream_write.flush().await)?; } self.common.write_buffer.buffer.clear(); if let Some(ref mut enc) = self.common.encrypted { @@ -1048,18 +1048,13 @@ impl Session { fn process_disconnect + Send>( &mut self, - buf: &[u8], + mut r: &[u8], ) -> Result { self.common.disconnected = true; - let mut reader = buf.reader(1); - let reason_code = reader.read_u32().map_err(crate::Error::from)?.try_into()?; - let message = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); - let lang_tag = std::str::from_utf8(reader.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); + let reason_code = map_err!(u32::decode(&mut r))?.try_into()?; + let message = map_err!(String::decode(&mut r))?; + let lang_tag = map_err!(String::decode(&mut r))?; Ok(RemoteDisconnectInfo { reason_code, @@ -1113,31 +1108,31 @@ impl Session { reply_channel, address, port, - } => self.tcpip_forward(reply_channel, &address, port), + } => self.tcpip_forward(reply_channel, &address, port)?, Msg::CancelTcpIpForward { reply_channel, address, port, - } => self.cancel_tcpip_forward(reply_channel, &address, port), + } => self.cancel_tcpip_forward(reply_channel, &address, port)?, Msg::StreamLocalForward { reply_channel, socket_path, - } => self.streamlocal_forward(reply_channel, &socket_path), + } => self.streamlocal_forward(reply_channel, &socket_path)?, Msg::CancelStreamLocalForward { reply_channel, socket_path, - } => self.cancel_streamlocal_forward(reply_channel, &socket_path), + } => self.cancel_streamlocal_forward(reply_channel, &socket_path)?, Msg::Disconnect { reason, description, language_tag, - } => self.disconnect(reason, &description, &language_tag), - Msg::Channel(id, ChannelMsg::Data { data }) => self.data(id, data), + } => self.disconnect(reason, &description, &language_tag)?, + Msg::Channel(id, ChannelMsg::Data { data }) => self.data(id, data)?, Msg::Channel(id, ChannelMsg::Eof) => { - self.eof(id); + self.eof(id)?; } Msg::Channel(id, ChannelMsg::ExtendedData { data, ext }) => { - self.extended_data(id, ext, data); + self.extended_data(id, ext, data)?; } Msg::Channel( id, @@ -1159,7 +1154,7 @@ impl Session { pix_width, pix_height, &terminal_modes, - ), + )?, Msg::Channel( id, ChannelMsg::WindowChange { @@ -1168,7 +1163,7 @@ impl Session { pix_width, pix_height, }, - ) => self.window_change(id, col_width, row_height, pix_width, pix_height), + ) => self.window_change(id, col_width, row_height, pix_width, pix_height)?, Msg::Channel( id, ChannelMsg::RequestX11 { @@ -1185,7 +1180,7 @@ impl Session { &x11_authentication_protocol, &x11_authentication_cookie, x11_screen_number, - ), + )?, Msg::Channel( id, ChannelMsg::SetEnv { @@ -1193,9 +1188,9 @@ impl Session { variable_name, variable_value, }, - ) => self.set_env(id, want_reply, &variable_name, &variable_value), + ) => self.set_env(id, want_reply, &variable_name, &variable_value)?, Msg::Channel(id, ChannelMsg::RequestShell { want_reply }) => { - self.request_shell(want_reply, id) + self.request_shell(want_reply, id)? } Msg::Channel( id, @@ -1203,15 +1198,15 @@ impl Session { want_reply, command, }, - ) => self.exec(id, want_reply, &command), - Msg::Channel(id, ChannelMsg::Signal { signal }) => self.signal(id, signal), + ) => self.exec(id, want_reply, &command)?, + Msg::Channel(id, ChannelMsg::Signal { signal }) => self.signal(id, signal)?, Msg::Channel(id, ChannelMsg::RequestSubsystem { want_reply, name }) => { - self.request_subsystem(want_reply, id, &name) + self.request_subsystem(want_reply, id, &name)? } Msg::Channel(id, ChannelMsg::AgentForward { want_reply }) => { - self.agent_forward(id, want_reply) + self.agent_forward(id, want_reply)? } - Msg::Channel(id, ChannelMsg::Close) => self.close(id), + Msg::Channel(id, ChannelMsg::Close) => self.close(id)?, msg => { // should be unreachable, since the receiver only gets // messages from methods implemented within russh @@ -1295,15 +1290,14 @@ thread_local! { } impl KexDhDone { - async fn server_key_check( + async fn server_key_check( mut self, rekey: bool, handler: &mut H, - buf: &[u8], + r: &mut R, ) -> Result { - let mut reader = buf.reader(1); - let pubkey = reader.read_string().map_err(crate::Error::from)?; // server public key. - let pubkey = parse_public_key(pubkey).map_err(crate::Error::from)?; + let pubkey = map_err!(Bytes::decode(r))?; // server public key. + let pubkey = map_err!(parse_public_key(&pubkey))?; debug!("server_public_Key: {:?}", pubkey); if !rekey { let check = handler.check_server_key(&pubkey).await?; @@ -1315,16 +1309,16 @@ impl KexDhDone { let mut buffer = buffer.borrow_mut(); buffer.clear(); let hash = { - let server_ephemeral = reader.read_string().map_err(crate::Error::from)?; - self.exchange.server_ephemeral.extend(server_ephemeral); - let signature = reader.read_string().map_err(crate::Error::from)?; + let server_ephemeral = map_err!(Bytes::decode(r))?; + self.exchange.server_ephemeral.extend(&server_ephemeral); + let signature = map_err!(Bytes::decode(r))?; self.kex .compute_shared_secret(&self.exchange.server_ephemeral)?; debug!("kexdhdone.exchange = {:?}", self.exchange); let mut pubkey_vec = CryptoVec::new(); - pubkey_vec.extend_ssh_string(&pubkey.to_bytes().map_err(crate::Error::from)?); + map_err!(map_err!(pubkey.to_bytes())?.encode(&mut pubkey_vec))?; let hash = self.kex @@ -1332,11 +1326,12 @@ impl KexDhDone { debug!("exchange hash: {:?}", hash); let signature = { - let mut sig_reader = signature.reader(0); - let sig_type = sig_reader.read_string().map_err(crate::Error::from)?; + let mut r = &signature[..]; + let sig_type = map_err!(String::decode(&mut r))?; debug!("sig_type: {:?}", sig_type); - sig_reader.read_string().map_err(crate::Error::from)? + map_err!(Bytes::decode(&mut r))? }; + debug!("signature: {:?}", signature); let signature = Signature::new(pubkey.algorithm(), signature).map_err(|e| { debug!("signature ctor failed: {e:?}"); @@ -1418,7 +1413,12 @@ async fn reply( Ok(()) } else if buf.first() == Some(&msg::KEX_ECDH_REPLY) { // We've sent ECDH_INIT, waiting for ECDH_REPLY - let kex = kexdhdone.server_key_check(false, handler, buf).await?; + + #[allow(clippy::indexing_slicing)] // length checked + let kex = kexdhdone + .server_key_check(false, handler, &mut &buf[1..]) + .await?; + session.common.strict_kex = session.common.strict_kex || kex.names.strict_kex; session.common.kex = Some(Kex::Keys(kex)); session @@ -1656,7 +1656,7 @@ pub trait Handler: Sized + Send { async fn should_accept_unknown_server_channel( &mut self, id: ChannelId, - channel_type: &[u8], + channel_type: &str, ) -> bool { false } diff --git a/russh/src/client/session.rs b/russh/src/client/session.rs index 26f8a761..ccea4479 100644 --- a/russh/src/client/session.rs +++ b/russh/src/client/session.rs @@ -1,8 +1,9 @@ use log::error; +use russh_keys::map_err; +use ssh_encoding::Encode; use tokio::sync::oneshot; use crate::client::Session; -use crate::keys::encoding::Encoding; use crate::session::EncryptedState; use crate::{msg, ChannelId, CryptoVec, Disconnect, Pty, Sig}; @@ -13,7 +14,7 @@ impl Session { write_suffix: F, ) -> Result where - F: FnOnce(&mut CryptoVec), + F: FnOnce(&mut CryptoVec) -> Result<(), crate::Error>, { let result = if let Some(ref mut enc) = self.common.encrypted { match enc.state { @@ -23,21 +24,27 @@ impl Session { self.common.config.maximum_packet_size, ); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_OPEN); - enc.write.extend_ssh_string(kind); + msg::CHANNEL_OPEN.encode(&mut enc.write)?; + kind.encode(&mut enc.write)?; // sender channel id. - enc.write.push_u32_be(sender_channel.0); + sender_channel.encode(&mut enc.write)?; // window. - enc.write - .push_u32_be(self.common.config.as_ref().window_size); + self.common + .config + .as_ref() + .window_size + .encode(&mut enc.write)?; // max packet size. - enc.write - .push_u32_be(self.common.config.as_ref().maximum_packet_size); + self.common + .config + .as_ref() + .maximum_packet_size + .encode(&mut enc.write)?; - write_suffix(&mut enc.write); + write_suffix(&mut enc.write)?; }); sender_channel } @@ -50,7 +57,7 @@ impl Session { } pub fn channel_open_session(&mut self) -> Result { - self.channel_open_generic(b"session", |_| ()) + self.channel_open_generic(b"session", |_| Ok(())) } pub fn channel_open_x11( @@ -59,8 +66,9 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"x11", |write| { - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + map_err!(originator_address.encode(write))?; + map_err!(originator_port.encode(write))?; // sender channel id. + Ok(()) }) } @@ -72,10 +80,11 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"direct-tcpip", |write| { - write.extend_ssh_string(host_to_connect.as_bytes()); - write.push_u32_be(port_to_connect); // sender channel id. - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + host_to_connect.encode(write)?; + port_to_connect.encode(write)?; // sender channel id. + originator_address.encode(write)?; + originator_port.encode(write)?; // sender channel id. + Ok(()) }) } @@ -84,9 +93,10 @@ impl Session { socket_path: &str, ) -> Result { self.channel_open_generic(b"direct-streamlocal@openssh.com", |write| { - write.extend_ssh_string(socket_path.as_bytes()); - write.extend_ssh_string("".as_bytes()); // reserved - write.push_u32_be(0); // reserved + socket_path.encode(write)?; + "".encode(write)?; // reserved + 0u32.encode(write)?; // reserved + Ok(()) }) } @@ -101,32 +111,33 @@ impl Session { pix_width: u32, pix_height: u32, terminal_modes: &[(Pty, u32)], - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + map_err!(msg::CHANNEL_REQUEST.encode(&mut enc.write))?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"pty-req"); - enc.write.push(want_reply as u8); + channel.recipient_channel.encode(&mut enc.write)?; + "pty-req".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; - enc.write.extend_ssh_string(term.as_bytes()); - enc.write.push_u32_be(col_width); - enc.write.push_u32_be(row_height); - enc.write.push_u32_be(pix_width); - enc.write.push_u32_be(pix_height); + term.encode(&mut enc.write)?; + col_width.encode(&mut enc.write)?; + row_height.encode(&mut enc.write)?; + pix_width.encode(&mut enc.write)?; + pix_height.encode(&mut enc.write)?; - enc.write.push_u32_be((1 + 5 * terminal_modes.len()) as u32); + ((1 + 5 * terminal_modes.len()) as u32).encode(&mut enc.write)?; for &(code, value) in terminal_modes { - enc.write.push(code as u8); - enc.write.push_u32_be(value) + (code as u8).encode(&mut enc.write)?; + value.encode(&mut enc.write)?; } // 0 code (to terminate the list) - enc.write.push(0); + 0u8.encode(&mut enc.write)?; }); } } + Ok(()) } pub fn request_x11( @@ -137,24 +148,23 @@ impl Session { x11_authentication_protocol: &str, x11_authentication_cookie: &str, x11_screen_number: u32, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"x11-req"); + channel.recipient_channel.encode(&mut enc.write)?; + "x11-req".encode(&mut enc.write)?; enc.write.push(want_reply as u8); enc.write.push(single_connection as u8); - enc.write - .extend_ssh_string(x11_authentication_protocol.as_bytes()); - enc.write - .extend_ssh_string(x11_authentication_cookie.as_bytes()); - enc.write.push_u32_be(x11_screen_number); + x11_authentication_protocol.encode(&mut enc.write)?; + x11_authentication_cookie.encode(&mut enc.write)?; + x11_screen_number.encode(&mut enc.write)?; }); } } + Ok(()) } pub fn set_env( @@ -163,80 +173,99 @@ impl Session { want_reply: bool, variable_name: &str, variable_value: &str, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"env"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(variable_name.as_bytes()); - enc.write.extend_ssh_string(variable_value.as_bytes()); + channel.recipient_channel.encode(&mut enc.write)?; + "env".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + variable_name.encode(&mut enc.write)?; + variable_value.encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn request_shell(&mut self, want_reply: bool, channel: ChannelId) { + pub fn request_shell( + &mut self, + want_reply: bool, + channel: ChannelId, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"shell"); - enc.write.push(want_reply as u8); + channel.recipient_channel.encode(&mut enc.write)?; + "shell".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn exec(&mut self, channel: ChannelId, want_reply: bool, command: &[u8]) { + pub fn exec( + &mut self, + channel: ChannelId, + want_reply: bool, + command: &[u8], + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"exec"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(command); + channel.recipient_channel.encode(&mut enc.write)?; + "exec".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + command.encode(&mut enc.write)?; }); - return; + return Ok(()); } } error!("exec"); + Ok(()) } - pub fn signal(&mut self, channel: ChannelId, signal: Sig) { + pub fn signal(&mut self, channel: ChannelId, signal: Sig) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"signal"); - enc.write.push(0); - enc.write.extend_ssh_string(signal.name().as_bytes()); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + channel.recipient_channel.encode(&mut enc.write)?; + "signal".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + signal.name().encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn request_subsystem(&mut self, want_reply: bool, channel: ChannelId, name: &str) { + pub fn request_subsystem( + &mut self, + want_reply: bool, + channel: ChannelId, + name: &str, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"subsystem"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(name.as_bytes()); + channel.recipient_channel.encode(&mut enc.write)?; + "subsystem".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + name.encode(&mut enc.write)?; }); } } + Ok(()) } pub fn window_change( @@ -246,22 +275,23 @@ impl Session { row_height: u32, pix_width: u32, pix_height: u32, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"window-change"); - enc.write.push(0); // this packet never wants reply - enc.write.push_u32_be(col_width); - enc.write.push_u32_be(row_height); - enc.write.push_u32_be(pix_width); - enc.write.push_u32_be(pix_height); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + + channel.recipient_channel.encode(&mut enc.write)?; + "window-change".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + col_width.encode(&mut enc.write)?; + row_height.encode(&mut enc.write)?; + pix_width.encode(&mut enc.write)?; + pix_height.encode(&mut enc.write)?; }); } } + Ok(()) } /// Requests a TCP/IP forwarding from the server @@ -273,7 +303,7 @@ impl Session { reply_channel: Option>>, address: &str, port: u32, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { let want_reply = reply_channel.is_some(); if let Some(reply_channel) = reply_channel { @@ -282,13 +312,14 @@ impl Session { ); } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"tcpip-forward"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; }); } + Ok(()) } /// Requests cancellation of TCP/IP forwarding from the server @@ -300,7 +331,7 @@ impl Session { reply_channel: Option>, address: &str, port: u32, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { let want_reply = reply_channel.is_some(); if let Some(reply_channel) = reply_channel { @@ -309,13 +340,14 @@ impl Session { ); } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"cancel-tcpip-forward"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "cancel-tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; }); } + Ok(()) } /// Requests a UDS forwarding from the server, `socket path` being the server side socket path. @@ -326,7 +358,7 @@ impl Session { &mut self, reply_channel: Option>, socket_path: &str, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { let want_reply = reply_channel.is_some(); if let Some(reply_channel) = reply_channel { @@ -335,13 +367,13 @@ impl Session { ); } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write - .extend_ssh_string(b"streamlocal-forward@openssh.com"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(socket_path.as_bytes()); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "streamlocal-forward@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + socket_path.encode(&mut enc.write)?; }); } + Ok(()) } /// Requests cancellation of UDS forwarding from the server @@ -352,7 +384,7 @@ impl Session { &mut self, reply_channel: Option>, socket_path: &str, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { let want_reply = reply_channel.is_some(); if let Some(reply_channel) = reply_channel { @@ -361,28 +393,29 @@ impl Session { ); } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write - .extend_ssh_string(b"cancel-streamlocal-forward@openssh.com"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(socket_path.as_bytes()); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "cancel-streamlocal-forward@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + socket_path.encode(&mut enc.write)?; }); } + Ok(()) } - pub fn send_keepalive(&mut self, want_reply: bool) { + pub fn send_keepalive(&mut self, want_reply: bool) -> Result<(), crate::Error> { self.open_global_requests .push_back(crate::session::GlobalRequestResponse::Keepalive); if let Some(ref mut enc) = self.common.encrypted { push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"keepalive@openssh.com"); - enc.write.push(want_reply as u8); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "keepalive@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; }); } + Ok(()) } - pub fn data(&mut self, channel: ChannelId, data: CryptoVec) { + pub fn data(&mut self, channel: ChannelId, data: CryptoVec) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { enc.data(channel, data) } else { @@ -390,7 +423,7 @@ impl Session { } } - pub fn eof(&mut self, channel: ChannelId) { + pub fn eof(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { enc.eof(channel) } else { @@ -398,7 +431,7 @@ impl Session { } } - pub fn close(&mut self, channel: ChannelId) { + pub fn close(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { enc.close(channel) } else { @@ -406,7 +439,12 @@ impl Session { } } - pub fn extended_data(&mut self, channel: ChannelId, ext: u32, data: CryptoVec) { + pub fn extended_data( + &mut self, + channel: ChannelId, + ext: u32, + data: CryptoVec, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { enc.extended_data(channel, ext, data) } else { @@ -414,21 +452,31 @@ impl Session { } } - pub fn agent_forward(&mut self, channel: ChannelId, want_reply: bool) { + pub fn agent_forward( + &mut self, + channel: ChannelId, + want_reply: bool, + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"auth-agent-req@openssh.com"); - enc.write.push(want_reply as u8); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + channel.recipient_channel.encode(&mut enc.write)?; + "auth-agent-req@openssh.com".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; }); } } + Ok(()) } - pub fn disconnect(&mut self, reason: Disconnect, description: &str, language_tag: &str) { - self.common.disconnect(reason, description, language_tag); + pub fn disconnect( + &mut self, + reason: Disconnect, + description: &str, + language_tag: &str, + ) -> Result<(), crate::Error> { + self.common.disconnect(reason, description, language_tag) } pub fn has_pending_data(&self, channel: ChannelId) -> bool { diff --git a/russh/src/compression.rs b/russh/src/compression.rs index 6d739bf6..d6eec087 100644 --- a/russh/src/compression.rs +++ b/russh/src/compression.rs @@ -1,5 +1,8 @@ use std::convert::TryFrom; +use delegate::delegate; +use ssh_encoding::Encode; + #[derive(Debug, Clone)] pub enum Compression { None, @@ -29,6 +32,13 @@ impl AsRef for Name { } } +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + impl TryFrom<&str> for Name { type Error = (); fn try_from(s: &str) -> Result { diff --git a/russh/src/kex/curve25519.rs b/russh/src/kex/curve25519.rs index c26267b4..26c25b4b 100644 --- a/russh/src/kex/curve25519.rs +++ b/russh/src/kex/curve25519.rs @@ -3,9 +3,10 @@ use curve25519_dalek::constants::ED25519_BASEPOINT_TABLE; use curve25519_dalek::montgomery::MontgomeryPoint; use curve25519_dalek::scalar::Scalar; use log::debug; +use ssh_encoding::Encode; use super::{compute_keys, KexAlgorithm, KexType}; -use crate::keys::encoding::Encoding; +use crate::kex::encode_mpint; use crate::mac::{self}; use crate::session::Exchange; use crate::{cipher, msg, CryptoVec}; @@ -94,8 +95,8 @@ impl KexAlgorithm for Curve25519Kex { client_ephemeral.clear(); client_ephemeral.extend(&client_pubkey.0); - buf.push(msg::KEX_ECDH_INIT); - buf.extend_ssh_string(&client_pubkey.0); + msg::KEX_ECDH_INIT.encode(buf)?; + client_pubkey.0.encode(buf)?; self.local_secret = Some(client_secret); Ok(()) @@ -118,17 +119,17 @@ impl KexAlgorithm for Curve25519Kex { ) -> Result { // Computing the exchange hash, see page 7 of RFC 5656. buffer.clear(); - buffer.extend_ssh_string(&exchange.client_id); - buffer.extend_ssh_string(&exchange.server_id); - buffer.extend_ssh_string(&exchange.client_kex_init); - buffer.extend_ssh_string(&exchange.server_kex_init); + exchange.client_id.encode(buffer)?; + exchange.server_id.encode(buffer)?; + exchange.client_kex_init.encode(buffer)?; + exchange.server_kex_init.encode(buffer)?; buffer.extend(key); - buffer.extend_ssh_string(&exchange.client_ephemeral); - buffer.extend_ssh_string(&exchange.server_ephemeral); + exchange.client_ephemeral.encode(buffer)?; + exchange.server_ephemeral.encode(buffer)?; if let Some(ref shared) = self.shared_secret { - buffer.extend_ssh_mpint(&shared.0); + encode_mpint(&shared.0, buffer)?; } use sha2::Digest; diff --git a/russh/src/kex/dh/mod.rs b/russh/src/kex/dh/mod.rs index e409348d..67ab200f 100644 --- a/russh/src/kex/dh/mod.rs +++ b/russh/src/kex/dh/mod.rs @@ -8,10 +8,10 @@ use log::debug; use num_bigint::BigUint; use sha1::Sha1; use sha2::{Sha256, Sha512}; +use ssh_encoding::Encode; use self::groups::{DhGroup, DH_GROUP1, DH_GROUP14, DH_GROUP16}; use super::{compute_keys, KexAlgorithm, KexType}; -use crate::keys::encoding::Encoding; use crate::session::Exchange; use crate::{cipher, mac, msg, CryptoVec}; @@ -155,8 +155,8 @@ impl KexAlgorithm for DhGroupKex { client_ephemeral.clear(); client_ephemeral.extend(&encoded_pubkey); - buf.push(msg::KEX_ECDH_INIT); - buf.extend_ssh_string(&encoded_pubkey); + msg::KEX_ECDH_INIT.encode(buf)?; + encoded_pubkey.encode(buf)?; Ok(()) } @@ -184,17 +184,17 @@ impl KexAlgorithm for DhGroupKex { ) -> Result { // Computing the exchange hash, see page 7 of RFC 5656. buffer.clear(); - buffer.extend_ssh_string(&exchange.client_id); - buffer.extend_ssh_string(&exchange.server_id); - buffer.extend_ssh_string(&exchange.client_kex_init); - buffer.extend_ssh_string(&exchange.server_kex_init); + exchange.client_id.encode(buffer)?; + exchange.server_id.encode(buffer)?; + exchange.client_kex_init.encode(buffer)?; + exchange.server_kex_init.encode(buffer)?; buffer.extend(key); - buffer.extend_ssh_string(&exchange.client_ephemeral); - buffer.extend_ssh_string(&exchange.server_ephemeral); + exchange.client_ephemeral.encode(buffer)?; + exchange.server_ephemeral.encode(buffer)?; if let Some(ref shared) = self.shared_secret { - buffer.extend_ssh_mpint(shared); + shared.encode(buffer)?; } let mut hasher = D::new(); diff --git a/russh/src/kex/ecdh_nistp.rs b/russh/src/kex/ecdh_nistp.rs index 58ad899c..c7420a17 100644 --- a/russh/src/kex/ecdh_nistp.rs +++ b/russh/src/kex/ecdh_nistp.rs @@ -1,4 +1,5 @@ use std::marker::PhantomData; +use std::ops::Deref; use byteorder::{BigEndian, ByteOrder}; use elliptic_curve::ecdh::{EphemeralSecret, SharedSecret}; @@ -10,9 +11,10 @@ use p256::NistP256; use p384::NistP384; use p521::NistP521; use sha2::{Digest, Sha256, Sha384, Sha512}; +use ssh_encoding::Encode; +use super::encode_mpint; use crate::kex::{compute_keys, KexAlgorithm, KexType}; -use crate::keys::encoding::Encoding; use crate::mac::{self}; use crate::session::Exchange; use crate::{cipher, msg, CryptoVec}; @@ -129,7 +131,7 @@ where client_ephemeral.extend(&client_pubkey.to_sec1_bytes()); buf.push(msg::KEX_ECDH_INIT); - buf.extend_ssh_string(&client_pubkey.to_sec1_bytes()); + client_pubkey.to_sec1_bytes().encode(buf)?; self.local_secret = Some(client_secret); Ok(()) @@ -151,17 +153,17 @@ where ) -> Result { // Computing the exchange hash, see page 7 of RFC 5656. buffer.clear(); - buffer.extend_ssh_string(&exchange.client_id); - buffer.extend_ssh_string(&exchange.server_id); - buffer.extend_ssh_string(&exchange.client_kex_init); - buffer.extend_ssh_string(&exchange.server_kex_init); + exchange.client_id.deref().encode(buffer)?; + exchange.server_id.deref().encode(buffer)?; + exchange.client_kex_init.deref().encode(buffer)?; + exchange.server_kex_init.deref().encode(buffer)?; buffer.extend(key); - buffer.extend_ssh_string(&exchange.client_ephemeral); - buffer.extend_ssh_string(&exchange.server_ephemeral); + exchange.client_ephemeral.deref().encode(buffer)?; + exchange.server_ephemeral.deref().encode(buffer)?; if let Some(ref shared) = self.shared_secret { - buffer.extend_ssh_mpint(shared.raw_secret_bytes()); + encode_mpint(shared.raw_secret_bytes(), buffer)?; } let mut hasher = D::new(); diff --git a/russh/src/kex/mod.rs b/russh/src/kex/mod.rs index 59a58633..823ebae4 100644 --- a/russh/src/kex/mod.rs +++ b/russh/src/kex/mod.rs @@ -23,17 +23,19 @@ use std::cell::RefCell; use std::collections::HashMap; use std::convert::TryFrom; use std::fmt::Debug; +use std::ops::DerefMut; use curve25519::Curve25519KexType; +use delegate::delegate; use dh::{ DhGroup14Sha1KexType, DhGroup14Sha256KexType, DhGroup16Sha512KexType, DhGroup1Sha1KexType, }; use digest::Digest; use ecdh_nistp::{EcdhNistP256KexType, EcdhNistP384KexType, EcdhNistP521KexType}; use once_cell::sync::Lazy; +use ssh_encoding::{Encode, Writer}; use crate::cipher::CIPHERS; -use crate::keys::encoding::Encoding; use crate::mac::{self, MACS}; use crate::session::Exchange; use crate::{cipher, CryptoVec}; @@ -88,6 +90,13 @@ impl AsRef for Name { } } +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + impl TryFrom<&str> for Name { type Error = (); fn try_from(s: &str) -> Result { @@ -199,7 +208,7 @@ pub(crate) fn compute_keys( key.clear(); if let Some(shared) = shared_secret { - buffer.extend_ssh_mpint(shared); + encode_mpint(shared, buffer.deref_mut())?; } buffer.extend(exchange_hash.as_ref()); @@ -216,7 +225,7 @@ pub(crate) fn compute_keys( // extend. buffer.clear(); if let Some(shared) = shared_secret { - buffer.extend_ssh_mpint(shared); + encode_mpint(shared, buffer.deref_mut())?; } buffer.extend(exchange_hash.as_ref()); buffer.extend(key); @@ -284,3 +293,23 @@ pub(crate) fn compute_keys( }) }) } + +// NOTE: using MpInt::from_bytes().encode() will randomly fail, +// I'm assuming it's due to specific byte values / padding but no time to investigate +#[allow(clippy::indexing_slicing)] // length is known +pub(crate) fn encode_mpint(s: &[u8], w: &mut W) -> Result<(), crate::Error> { + // Skip initial 0s. + let mut i = 0; + while i < s.len() && s[i] == 0 { + i += 1 + } + // If the first non-zero is >= 128, write its length (u32, BE), followed by 0. + if s[i] & 0x80 != 0 { + ((s.len() - i + 1) as u32).encode(w)?; + 0u8.encode(w)?; + } else { + ((s.len() - i) as u32).encode(w)?; + } + w.write(&s[i..])?; + Ok(()) +} diff --git a/russh/src/lib.rs b/russh/src/lib.rs index 0360eaf4..cf078e2e 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -95,6 +95,7 @@ use std::fmt::{Debug, Display, Formatter}; use log::debug; use parsing::ChannelOpenConfirmation; pub use russh_cryptovec::CryptoVec; +use ssh_encoding::{Decode, Encode}; use thiserror::Error; #[cfg(test)] @@ -304,6 +305,9 @@ pub enum Error { sequence_number: usize, }, + #[error("Signature: {0}")] + Signature(#[from] signature::Error), + #[error("SshKey: {0}")] SshKey(#[from] ssh_key::Error), @@ -451,21 +455,21 @@ impl Sig { Sig::Custom(ref c) => c, } } - fn from_name(name: &[u8]) -> Result { + fn from_name(name: &str) -> Sig { match name { - b"ABRT" => Ok(Sig::ABRT), - b"ALRM" => Ok(Sig::ALRM), - b"FPE" => Ok(Sig::FPE), - b"HUP" => Ok(Sig::HUP), - b"ILL" => Ok(Sig::ILL), - b"INT" => Ok(Sig::INT), - b"KILL" => Ok(Sig::KILL), - b"PIPE" => Ok(Sig::PIPE), - b"QUIT" => Ok(Sig::QUIT), - b"SEGV" => Ok(Sig::SEGV), - b"TERM" => Ok(Sig::TERM), - b"USR1" => Ok(Sig::USR1), - x => Ok(Sig::Custom(std::str::from_utf8(x)?.to_string())), + "ABRT" => Sig::ABRT, + "ALRM" => Sig::ALRM, + "FPE" => Sig::FPE, + "HUP" => Sig::HUP, + "ILL" => Sig::ILL, + "INT" => Sig::INT, + "KILL" => Sig::KILL, + "PIPE" => Sig::PIPE, + "QUIT" => Sig::QUIT, + "SEGV" => Sig::SEGV, + "TERM" => Sig::TERM, + "USR1" => Sig::USR1, + x => Sig::Custom(x.to_string()), } } } @@ -497,6 +501,24 @@ impl ChannelOpenFailure { /// The identifier of a channel. pub struct ChannelId(u32); +impl Decode for ChannelId { + type Error = ssh_encoding::Error; + + fn decode(reader: &mut impl ssh_encoding::Reader) -> Result { + Ok(Self(u32::decode(reader)?)) + } +} + +impl Encode for ChannelId { + fn encoded_len(&self) -> Result { + self.0.encoded_len() + } + + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error> { + self.0.encode(writer) + } +} + impl From for u32 { fn from(c: ChannelId) -> u32 { c.0 diff --git a/russh/src/mac/mod.rs b/russh/src/mac/mod.rs index 088f50be..8b705f54 100644 --- a/russh/src/mac/mod.rs +++ b/russh/src/mac/mod.rs @@ -17,11 +17,13 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::marker::PhantomData; +use delegate::delegate; use digest::typenum::{U20, U32, U64}; use hmac::Hmac; use once_cell::sync::Lazy; use sha1::Sha1; use sha2::{Sha256, Sha512}; +use ssh_encoding::Encode; use self::crypto::CryptoMacAlgorithm; use self::crypto_etm::CryptoEtmMacAlgorithm; @@ -53,6 +55,13 @@ impl AsRef for Name { } } +impl Encode for Name { + delegate! { to self.as_ref() { + fn encoded_len(&self) -> Result; + fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>; + }} +} + impl TryFrom<&str> for Name { type Error = (); fn try_from(s: &str) -> Result { diff --git a/russh/src/negotiation.rs b/russh/src/negotiation.rs index b08aa324..313548b1 100644 --- a/russh/src/negotiation.rs +++ b/russh/src/negotiation.rs @@ -13,15 +13,15 @@ // limitations under the License. // use std::borrow::Cow; -use std::str::from_utf8; use log::debug; use rand::RngCore; +use russh_keys::helpers::NameList; +use ssh_encoding::{Decode, Encode}; use ssh_key::{Algorithm, Certificate, EcdsaCurve, HashAlg, PrivateKey, PublicKey}; use crate::cipher::CIPHERS; use crate::kex::{EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER}; -use crate::keys::encoding::{Encoding, Reader}; #[cfg(not(target_arch = "wasm32"))] use crate::server::Config; use crate::{cipher, compression, kex, mac, msg, AlgorithmKind, CryptoVec, Error}; @@ -181,10 +181,8 @@ impl<'a> Named<'a> for Certificate { } } -pub(crate) fn parse_kex_algo_list(list: &[u8]) -> Vec<&str> { - list.split(|&x| x == b',') - .map(|x| from_utf8(x).unwrap_or_default()) - .collect() +pub(crate) fn parse_kex_algo_list(list: &str) -> Vec<&str> { + list.split(',').collect() } pub(crate) trait Select { @@ -202,14 +200,16 @@ pub(crate) trait Select { pref: &Preferred, available_host_keys: Option<&[PrivateKey]>, ) -> Result { - let mut r = buffer.reader(17); + let Some(mut r) = &buffer.get(17..) else { + return Err(Error::Inconsistent); + }; // Key exchange - let kex_string = r.read_string()?; + let kex_string = String::decode(&mut r)?; let (kex_both_first, kex_algorithm) = Self::select( &pref.kex, - &parse_kex_algo_list(kex_string), + &parse_kex_algo_list(&kex_string), AlgorithmKind::Kex, )?; @@ -226,7 +226,7 @@ pub(crate) trait Select { } else { EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER }], - &parse_kex_algo_list(kex_string), + &parse_kex_algo_list(&kex_string), AlgorithmKind::Kex, ) .is_ok(); @@ -236,7 +236,7 @@ pub(crate) trait Select { // Host key - let key_string: &[u8] = r.read_string()?; + let key_string = String::decode(&mut r)?; let possible_host_key_algos = match available_host_keys { Some(available_host_keys) => pref.possible_host_key_algos_for_keys(available_host_keys), None => pref.key.iter().map(ToOwned::to_owned).collect::>(), @@ -244,19 +244,19 @@ pub(crate) trait Select { let (key_both_first, key_algorithm) = Self::select( &possible_host_key_algos[..], - &parse_kex_algo_list(key_string), + &parse_kex_algo_list(&key_string), AlgorithmKind::Key, )?; // Cipher - let cipher_string = r.read_string()?; + let cipher_string = String::decode(&mut r)?; let (_cipher_both_first, cipher) = Self::select( &pref.cipher, - &parse_kex_algo_list(cipher_string), + &parse_kex_algo_list(&cipher_string), AlgorithmKind::Cipher, )?; - r.read_string()?; // cipher server-to-client. + String::decode(&mut r)?; // cipher server-to-client. debug!("kex {}", line!()); // MAC @@ -265,7 +265,7 @@ pub(crate) trait Select { let client_mac = match Self::select( &pref.mac, - &parse_kex_algo_list(r.read_string()?), + &parse_kex_algo_list(&String::decode(&mut r)?), AlgorithmKind::Mac, ) { Ok((_, m)) => m, @@ -279,7 +279,7 @@ pub(crate) trait Select { }; let server_mac = match Self::select( &pref.mac, - &parse_kex_algo_list(r.read_string()?), + &parse_kex_algo_list(&String::decode(&mut r)?), AlgorithmKind::Mac, ) { Ok((_, m)) => m, @@ -299,7 +299,7 @@ pub(crate) trait Select { let client_compression = compression::Compression::new( &Self::select( &pref.compression, - &parse_kex_algo_list(r.read_string()?), + &parse_kex_algo_list(&String::decode(&mut r)?), AlgorithmKind::Compression, )? .1, @@ -310,16 +310,16 @@ pub(crate) trait Select { let server_compression = compression::Compression::new( &Self::select( &pref.compression, - &parse_kex_algo_list(r.read_string()?), + &parse_kex_algo_list(&String::decode(&mut r)?), AlgorithmKind::Compression, )? .1, ); debug!("client_compression = {:?}", client_compression); - r.read_string()?; // languages client-to-server - r.read_string()?; // languages server-to-client + String::decode(&mut r)?; // languages client-to-server + String::decode(&mut r)?; // languages server-to-client - let follows = r.read_byte()? != 0; + let follows = u8::decode(&mut r)? != 0; Ok(Names { kex: kex_algorithm, key: key_algorithm, @@ -404,43 +404,92 @@ pub fn write_kex( rand::thread_rng().fill_bytes(&mut cookie); buf.extend(&cookie); // cookie - buf.extend_list(prefs.kex.iter().filter(|k| { - !(if server_config.is_some() { - [ - crate::kex::EXTENSION_SUPPORT_AS_CLIENT, - crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, - ] - } else { - [ - crate::kex::EXTENSION_SUPPORT_AS_SERVER, - crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, - ] - }) - .contains(*k) - })); // kex algo + NameList( + prefs + .kex + .iter() + .filter(|k| { + !(if server_config.is_some() { + [ + crate::kex::EXTENSION_SUPPORT_AS_CLIENT, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + ] + } else { + [ + crate::kex::EXTENSION_SUPPORT_AS_SERVER, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, + ] + }) + .contains(*k) + }) + .map(|x| x.as_ref().to_owned()) + .collect(), + ) + .encode(buf)?; // kex algo if let Some(server_config) = server_config { // Only advertise host key algorithms that we have keys for. - buf.extend_list( + NameList( prefs .key .iter() - .filter(|algo| server_config.keys.iter().any(|k| k.algorithm() == **algo)), - ); + .filter(|algo| server_config.keys.iter().any(|k| k.algorithm() == **algo)) + .map(|x| x.to_string()) + .collect(), + ) + .encode(buf)?; } else { - buf.extend_list(prefs.key.iter()); + NameList(prefs.key.iter().map(ToString::to_string).collect()).encode(buf)?; } - buf.extend_list(prefs.cipher.iter()); // cipher client to server - buf.extend_list(prefs.cipher.iter()); // cipher server to client + // cipher client to server + NameList( + prefs + .cipher + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(buf)?; + + // cipher server to client + NameList( + prefs + .cipher + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(buf)?; + + // mac client to server + NameList(prefs.mac.iter().map(|x| x.as_ref().to_string()).collect()).encode(buf)?; - buf.extend_list(prefs.mac.iter()); // mac client to server - buf.extend_list(prefs.mac.iter()); // mac server to client - buf.extend_list(prefs.compression.iter()); // compress client to server - buf.extend_list(prefs.compression.iter()); // compress server to client + // mac server to client + NameList(prefs.mac.iter().map(|x| x.as_ref().to_string()).collect()).encode(buf)?; + + // compress client to server + NameList( + prefs + .compression + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(buf)?; + + // compress server to client + NameList( + prefs + .compression + .iter() + .map(|x| x.as_ref().to_string()) + .collect(), + ) + .encode(buf)?; - buf.write_empty_list(); // languages client to server - buf.write_empty_list(); // languagesserver to client + Vec::::new().encode(buf)?; // languages client to server + Vec::::new().encode(buf)?; // languages server to client buf.push(0); // doesn't follow buf.extend(&[0, 0, 0, 0]); // reserved diff --git a/russh/src/parsing.rs b/russh/src/parsing.rs index fe80c974..6323f242 100644 --- a/russh/src/parsing.rs +++ b/russh/src/parsing.rs @@ -1,4 +1,6 @@ -use crate::keys::encoding::{Encoding, Position}; +use russh_keys::helpers::map_err; +use ssh_encoding::{Decode, Encode, Reader}; + use crate::{msg, CryptoVec}; #[derive(Debug)] @@ -10,33 +12,30 @@ pub struct OpenChannelMessage { } impl OpenChannelMessage { - pub fn parse(r: &mut Position) -> Result { + pub fn parse(r: &mut R) -> Result { // https://tools.ietf.org/html/rfc4254#section-5.1 - let typ = r.read_string().map_err(crate::Error::from)?; - let sender = r.read_u32().map_err(crate::Error::from)?; - let window = r.read_u32().map_err(crate::Error::from)?; - let maxpacket = r.read_u32().map_err(crate::Error::from)?; - - let typ = match typ { - b"session" => ChannelType::Session, - b"x11" => { - let originator_address = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); - let originator_port = r.read_u32().map_err(crate::Error::from)?; + let typ = map_err!(String::decode(r))?; + let sender = map_err!(u32::decode(r))?; + let window = map_err!(u32::decode(r))?; + let maxpacket = map_err!(u32::decode(r))?; + + let typ = match typ.as_str() { + "session" => ChannelType::Session, + "x11" => { + let originator_address = map_err!(String::decode(r))?; + let originator_port = map_err!(u32::decode(r))?; ChannelType::X11 { originator_address, originator_port, } } - b"direct-tcpip" => ChannelType::DirectTcpip(TcpChannelInfo::new(r)?), - b"forwarded-tcpip" => ChannelType::ForwardedTcpIp(TcpChannelInfo::new(r)?), - b"forwarded-streamlocal@openssh.com" => { - ChannelType::ForwardedStreamLocal(StreamLocalChannelInfo::new(r)?) + "direct-tcpip" => ChannelType::DirectTcpip(TcpChannelInfo::decode(r)?), + "forwarded-tcpip" => ChannelType::ForwardedTcpIp(TcpChannelInfo::decode(r)?), + "forwarded-streamlocal@openssh.com" => { + ChannelType::ForwardedStreamLocal(StreamLocalChannelInfo::decode(r)?) } - b"auth-agent@openssh.com" => ChannelType::AgentForward, - t => ChannelType::Unknown { typ: t.to_vec() }, + "auth-agent@openssh.com" => ChannelType::AgentForward, + _ => ChannelType::Unknown { typ }, }; Ok(Self { @@ -54,34 +53,41 @@ impl OpenChannelMessage { sender_channel: u32, window_size: u32, packet_size: u32, - ) { + ) -> Result<(), crate::Error> { push_packet!(buffer, { - buffer.push(msg::CHANNEL_OPEN_CONFIRMATION); - buffer.push_u32_be(self.recipient_channel); // remote channel number. - buffer.push_u32_be(sender_channel); // our channel number. - buffer.push_u32_be(window_size); - buffer.push_u32_be(packet_size); + msg::CHANNEL_OPEN_CONFIRMATION.encode(buffer)?; + self.recipient_channel.encode(buffer)?; // remote channel number. + sender_channel.encode(buffer)?; // our channel number. + window_size.encode(buffer)?; + packet_size.encode(buffer)?; }); + Ok(()) } /// Pushes a failure message to the vec. - pub fn fail(&self, buffer: &mut CryptoVec, reason: u8, message: &[u8]) { + pub fn fail( + &self, + buffer: &mut CryptoVec, + reason: u8, + message: &[u8], + ) -> Result<(), crate::Error> { push_packet!(buffer, { - buffer.push(msg::CHANNEL_OPEN_FAILURE); - buffer.push_u32_be(self.recipient_channel); - buffer.push_u32_be(reason as u32); - buffer.extend_ssh_string(message); - buffer.extend_ssh_string(b"en"); + msg::CHANNEL_OPEN_FAILURE.encode(buffer)?; + self.recipient_channel.encode(buffer)?; + (reason as u32).encode(buffer)?; + message.encode(buffer)?; + "en".encode(buffer)?; }); + Ok(()) } /// Pushes an unknown type error to the vec. - pub fn unknown_type(&self, buffer: &mut CryptoVec) { + pub fn unknown_type(&self, buffer: &mut CryptoVec) -> Result<(), crate::Error> { self.fail( buffer, msg::SSH_OPEN_UNKNOWN_CHANNEL_TYPE, b"Unknown channel type", - ); + ) } } @@ -97,7 +103,7 @@ pub enum ChannelType { ForwardedStreamLocal(StreamLocalChannelInfo), AgentForward, Unknown { - typ: Vec, + typ: String, }, } @@ -114,26 +120,23 @@ pub struct StreamLocalChannelInfo { pub socket_path: String, } -impl StreamLocalChannelInfo { - fn new(r: &mut Position) -> Result { - let socket_path = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); +impl Decode for StreamLocalChannelInfo { + type Error = ssh_encoding::Error; + fn decode(r: &mut impl Reader) -> Result { + let socket_path = String::decode(r)?.to_owned(); Ok(Self { socket_path }) } } -impl TcpChannelInfo { - fn new(r: &mut Position) -> Result { - let host_to_connect = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); - let port_to_connect = r.read_u32().map_err(crate::Error::from)?; - let originator_address = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)? - .to_owned(); - let originator_port = r.read_u32().map_err(crate::Error::from)?; +impl Decode for TcpChannelInfo { + type Error = ssh_encoding::Error; + + fn decode(r: &mut impl Reader) -> Result { + let host_to_connect = String::decode(r)?; + let port_to_connect = u32::decode(r)?; + let originator_address = String::decode(r)?; + let originator_port = u32::decode(r)?; Ok(Self { host_to_connect, @@ -152,12 +155,14 @@ pub(crate) struct ChannelOpenConfirmation { pub maximum_packet_size: u32, } -impl ChannelOpenConfirmation { - pub fn parse(r: &mut Position) -> Result { - let recipient_channel = r.read_u32().map_err(crate::Error::from)?; - let sender_channel = r.read_u32().map_err(crate::Error::from)?; - let initial_window_size = r.read_u32().map_err(crate::Error::from)?; - let maximum_packet_size = r.read_u32().map_err(crate::Error::from)?; +impl Decode for ChannelOpenConfirmation { + type Error = ssh_encoding::Error; + + fn decode(r: &mut impl Reader) -> Result { + let recipient_channel = u32::decode(r)?; + let sender_channel = u32::decode(r)?; + let initial_window_size = u32::decode(r)?; + let maximum_packet_size = u32::decode(r)?; Ok(Self { recipient_channel, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index b7d2ce22..3ecfa471 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -18,17 +18,20 @@ use std::time::SystemTime; use auth::*; use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; use cert::PublicKeyOrCertificate; use log::{debug, error, info, trace, warn}; use negotiation::Select; +use russh_keys::helpers::NameList; +use russh_keys::map_err; use signature::Verifier; -use ssh_key::{Algorithm, PublicKey, Signature}; +use ssh_encoding::{Decode, Encode, Reader}; +use ssh_key::{PublicKey, Signature}; use tokio::time::Instant; use {msg, negotiation}; use super::super::*; use super::*; -use crate::keys::encoding::{Encoding, Position, Reader}; use crate::msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED; use crate::parsing::{ChannelOpenConfirmation, ChannelType, OpenChannelMessage}; @@ -108,7 +111,7 @@ impl Session { enc.last_rekey = std::time::Instant::now(); // Ok, NEWKEYS received, now encrypted. - enc.flush_all_pending(); + enc.flush_all_pending()?; let mut pending = std::mem::take(&mut self.pending_reads); for p in pending.drain(..) { self.process_packet(handler, &p).await?; @@ -167,32 +170,33 @@ impl Session { #[allow(clippy::unwrap_used)] let enc = self.common.encrypted.as_mut().unwrap(); // If we've successfully read a packet. - match enc.state { - EncryptedState::WaitingAuthServiceRequest { - ref mut accepted, .. - } if buf.first() == Some(&msg::SERVICE_REQUEST) => { - let mut r = buf.reader(1); - let request = r.read_string().map_err(crate::Error::from)?; - debug!("request: {:?}", std::str::from_utf8(request)); - if request == b"ssh-userauth" { + match (&mut enc.state, buf.split_first()) { + ( + EncryptedState::WaitingAuthServiceRequest { + ref mut accepted, .. + }, + Some((&msg::SERVICE_REQUEST, mut r)), + ) => { + let request = map_err!(String::decode(&mut r))?; + debug!("request: {:?}", request); + if request == "ssh-userauth" { let auth_request = server_accept_service( self.common.config.as_ref().auth_banner, self.common.config.as_ref().methods, &mut enc.write, - ); + )?; *accepted = true; enc.state = EncryptedState::WaitingAuthRequest(auth_request); } Ok(()) } - EncryptedState::WaitingAuthRequest(_) - if buf.first() == Some(&msg::USERAUTH_REQUEST) => - { + (EncryptedState::WaitingAuthRequest(_), Some((&msg::USERAUTH_REQUEST, mut r))) => { enc.server_read_auth_request( rejection_wait_until, initial_none_rejection_wait_until, handler, buf, + &mut r, &mut self.common.auth_user, ) .await?; @@ -203,16 +207,17 @@ impl Session { } Ok(()) } - EncryptedState::WaitingAuthRequest(ref mut auth) - if buf.first() == Some(&msg::USERAUTH_INFO_RESPONSE) => - { + ( + EncryptedState::WaitingAuthRequest(ref mut auth), + Some((&msg::USERAUTH_INFO_RESPONSE, mut r)), + ) => { let resp = read_userauth_info_response( rejection_wait_until, handler, &mut enc.write, auth, &self.common.auth_user, - buf, + &mut r, ) .await?; if resp { @@ -223,12 +228,14 @@ impl Session { Ok(()) } } - EncryptedState::InitCompression => { + (EncryptedState::InitCompression, Some((msg, mut r))) => { enc.server_compression.init_compress(&mut enc.compress); enc.state = EncryptedState::Authenticated; - self.server_read_authenticated(handler, buf).await + self.server_read_authenticated(handler, *msg, &mut r).await + } + (EncryptedState::Authenticated, Some((msg, mut r))) => { + self.server_read_authenticated(handler, *msg, &mut r).await } - EncryptedState::Authenticated => self.server_read_authenticated(handler, buf).await, _ => Ok(()), } } @@ -238,26 +245,26 @@ fn server_accept_service( banner: Option<&str>, methods: MethodSet, buffer: &mut CryptoVec, -) -> AuthRequest { +) -> Result { push_packet!(buffer, { buffer.push(msg::SERVICE_ACCEPT); - buffer.extend_ssh_string(b"ssh-userauth"); + "ssh-userauth".encode(buffer)?; }); if let Some(banner) = banner { push_packet!(buffer, { buffer.push(msg::USERAUTH_BANNER); - buffer.extend_ssh_string(banner.as_bytes()); - buffer.extend_ssh_string(b""); + banner.encode(buffer)?; + "".encode(buffer)?; }) } - AuthRequest { + Ok(AuthRequest { methods, partial_success: false, // not used immediately anway. current: None, rejection_count: 0, - } + }) } impl Encrypted { @@ -267,24 +274,18 @@ impl Encrypted { mut until: Instant, initial_auth_until: Instant, handler: &mut H, - buf: &[u8], + original_packet: &[u8], + r: &mut &[u8], auth_user: &mut String, ) -> Result<(), H::Error> { // https://tools.ietf.org/html/rfc4252#section-5 - let mut r = buf.reader(1); - let user = r.read_string().map_err(crate::Error::from)?; - let user = std::str::from_utf8(user).map_err(crate::Error::from)?; - let service_name = r.read_string().map_err(crate::Error::from)?; - let method = r.read_string().map_err(crate::Error::from)?; - debug!( - "name: {:?} {:?} {:?}", - user, - std::str::from_utf8(service_name), - std::str::from_utf8(method) - ); + let user = map_err!(String::decode(r))?; + let service_name = map_err!(String::decode(r))?; + let method = map_err!(String::decode(r))?; + debug!("name: {user:?} {service_name:?} {method:?}",); - if service_name == b"ssh-connection" { - if method == b"password" { + if service_name == "ssh-connection" { + if method == "password" { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a @@ -292,11 +293,10 @@ impl Encrypted { unreachable!() }; auth_user.clear(); - auth_user.push_str(user); - r.read_byte().map_err(crate::Error::from)?; - let password = r.read_string().map_err(crate::Error::from)?; - let password = std::str::from_utf8(password).map_err(crate::Error::from)?; - let auth = handler.auth_password(user, password).await?; + auth_user.push_str(&user); + map_err!(u8::decode(r))?; + let password = map_err!(String::decode(r))?; + let auth = handler.auth_password(&user, &password).await?; if let Auth::Accept = auth { server_auth_request_success(&mut self.write); self.state = EncryptedState::InitCompression; @@ -311,13 +311,20 @@ impl Encrypted { auth_request.methods -= MethodSet::PASSWORD; } auth_request.partial_success = false; - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } Ok(()) - } else if method == b"publickey" { - self.server_read_auth_request_pk(until, handler, buf, auth_user, user, r) - .await - } else if method == b"none" { + } else if method == "publickey" { + self.server_read_auth_request_pk( + until, + handler, + original_packet, + auth_user, + &user, + r, + ) + .await + } else if method == "none" { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a @@ -325,11 +332,9 @@ impl Encrypted { unreachable!() }; - if method == b"none" { - until = initial_auth_until - } + until = initial_auth_until; - let auth = handler.auth_none(user).await?; + let auth = handler.auth_none(&user).await?; if let Auth::Accept = auth { server_auth_request_success(&mut self.write); self.state = EncryptedState::InitCompression; @@ -344,10 +349,10 @@ impl Encrypted { auth_request.methods -= MethodSet::NONE; } auth_request.partial_success = false; - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } Ok(()) - } else if method == b"keyboard-interactive" { + } else if method == "keyboard-interactive" { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a @@ -355,16 +360,15 @@ impl Encrypted { unreachable!() }; auth_user.clear(); - auth_user.push_str(user); - let _ = r.read_string().map_err(crate::Error::from)?; // language_tag, deprecated. - let submethods = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + auth_user.push_str(&user); + let _ = map_err!(String::decode(r))?; // language_tag, deprecated. + let submethods = map_err!(String::decode(r))?; debug!("{:?}", submethods); auth_request.current = Some(CurrentRequest::KeyboardInteractive { submethods: submethods.to_string(), }); let auth = handler - .auth_keyboard_interactive(user, submethods, None) + .auth_keyboard_interactive(&user, &submethods, None) .await?; if reply_userauth_info_response(until, auth_request, &mut self.write, auth).await? { self.state = EncryptedState::InitCompression @@ -378,7 +382,7 @@ impl Encrypted { } else { unreachable!() }; - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; Ok(()) } } else { @@ -397,10 +401,10 @@ impl Encrypted { &mut self, until: Instant, handler: &mut H, - buf: &[u8], + original_packet: &[u8], auth_user: &mut String, user: &str, - mut r: Position<'_>, + r: &mut &[u8], ) -> Result<(), H::Error> { let auth_request = if let EncryptedState::WaitingAuthRequest(ref mut a) = self.state { a @@ -408,11 +412,11 @@ impl Encrypted { unreachable!() }; - let is_real = r.read_byte().map_err(crate::Error::from)?; - let pubkey_algo = r.read_string().map_err(crate::Error::from)?; - let pubkey_key = r.read_string().map_err(crate::Error::from)?; + let is_real = map_err!(u8::decode(r))?; - let key_or_cert = PublicKeyOrCertificate::decode(pubkey_algo, pubkey_key); + let pubkey_algo = map_err!(String::decode(r))?; + let pubkey_key = map_err!(Bytes::decode(r))?; + let key_or_cert = PublicKeyOrCertificate::decode(&pubkey_algo, &pubkey_key); // Parse the public key or certificate match key_or_cert { @@ -427,14 +431,14 @@ impl Encrypted { let now = SystemTime::now(); if now < cert.valid_after_time() || now > cert.valid_before_time() { warn!("Certificate is expired or not yet valid"); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; return Ok(()); } // Verify the certificate’s signature if cert.verify_signature().is_err() { warn!("Certificate signature is invalid"); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; return Ok(()); } @@ -444,7 +448,8 @@ impl Encrypted { }; if is_real != 0 { - let pos0 = r.position; + let pos0 = r.as_ptr(); + let sent_pk_ok = if let Some(CurrentRequest::PublicKey { sent_pk_ok, .. }) = auth_request.current { @@ -453,29 +458,18 @@ impl Encrypted { false }; - let signature = r.read_string().map_err(crate::Error::from)?; - let mut s = signature.reader(0); - let algo = s.read_string().map_err(crate::Error::from)?; - - let sig = s.read_string().map_err(crate::Error::from)?; - - let mut sig_buf = sig.to_vec(); - let algo = Algorithm::new(str::from_utf8(algo).map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + let encoded_signature = map_err!(Vec::::decode(r))?; - if algo == Algorithm::SkEcdsaSha2NistP256 || algo == Algorithm::SkEd25519 { - // https://github.com/RustCrypto/SSH/issues/312 - let flags = s.read_byte().map_err(crate::Error::from)?; - sig_buf.push(flags); - let counter = s.read_u32().map_err(crate::Error::from)?; - sig_buf.extend_from_slice(&counter.to_be_bytes()); - } - - #[allow(clippy::indexing_slicing)] - let sig = Signature::new(algo, sig_buf).map_err(crate::Error::from)?; + let sig = map_err!(Signature::decode(&mut encoded_signature.as_slice()))?; - #[allow(clippy::indexing_slicing)] // length checked - let init = &buf[0..pos0]; + // SAFETY: both original_packet and pos0 are coming + // from the same allocation (pos0 is derived from + // a slice of the original_packet) + let init = { + let init_len = unsafe { pos0.offset_from(original_packet.as_ptr()) }; + #[allow(clippy::indexing_slicing)] // length checked + &original_packet[0..init_len as usize] + }; let is_valid = if sent_pk_ok && user == auth_user { true @@ -494,11 +488,11 @@ impl Encrypted { if SIGNATURE_BUFFER.with(|buf| { let mut buf = buf.borrow_mut(); buf.clear(); - buf.extend_ssh_string(session_id); + map_err!(session_id.encode(&mut *buf))?; buf.extend(init); - Verifier::verify(&pubkey, &buf, &sig).is_ok() - }) { + Ok(Verifier::verify(&pubkey, &buf, &sig).is_ok()) + })? { debug!("signature verified"); let auth = match pk_or_cert { PublicKeyOrCertificate::PublicKey(ref pk) => { @@ -521,14 +515,14 @@ impl Encrypted { } auth_request.partial_success = false; auth_user.clear(); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } } else { debug!("signature wrong"); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } } else { - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } Ok(()) } else { @@ -538,15 +532,15 @@ impl Encrypted { match auth { Auth::Accept => { let mut public_key = CryptoVec::new(); - public_key.extend(pubkey_key); + public_key.extend(&pubkey_key); let mut algo = CryptoVec::new(); - algo.extend(pubkey_algo); + algo.extend(pubkey_algo.as_bytes()); debug!("pubkey_key: {:?}", pubkey_key); push_packet!(self.write, { self.write.push(msg::USERAUTH_PK_OK); - self.write.extend_ssh_string(pubkey_algo); - self.write.extend_ssh_string(pubkey_key); + map_err!(pubkey_algo.encode(&mut self.write))?; + map_err!(pubkey_key.encode(&mut self.write))?; }); auth_request.current = Some(CurrentRequest::PublicKey { @@ -564,7 +558,7 @@ impl Encrypted { } auth_request.partial_success = false; auth_user.clear(); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; } } Ok(()) @@ -575,7 +569,7 @@ impl Encrypted { | ssh_key::Error::AlgorithmUnsupported { .. } | ssh_key::Error::CertificateValidation { .. } => { debug!("public key error: {e}"); - reject_auth_request(until, &mut self.write, auth_request).await; + reject_auth_request(until, &mut self.write, auth_request).await?; Ok(()) } e => Err(crate::Error::from(e).into()), @@ -588,17 +582,18 @@ async fn reject_auth_request( until: Instant, write: &mut CryptoVec, auth_request: &mut AuthRequest, -) { +) -> Result<(), Error> { debug!("rejecting {:?}", auth_request); push_packet!(write, { write.push(msg::USERAUTH_FAILURE); - write.extend_list(auth_request.methods.into_iter()); + NameList::from(auth_request.methods).encode(write)?; write.push(auth_request.partial_success as u8); }); auth_request.current = None; auth_request.rejection_count += 1; debug!("packet pushed"); - tokio::time::sleep_until(until).await + tokio::time::sleep_until(until).await; + Ok(()) } fn server_auth_request_success(buffer: &mut CryptoVec) { @@ -607,27 +602,31 @@ fn server_auth_request_success(buffer: &mut CryptoVec) { }) } -async fn read_userauth_info_response( +async fn read_userauth_info_response( until: Instant, handler: &mut H, write: &mut CryptoVec, auth_request: &mut AuthRequest, user: &str, - b: &[u8], + r: &mut R, ) -> Result { if let Some(CurrentRequest::KeyboardInteractive { ref submethods }) = auth_request.current { - let mut r = b.reader(1); - let n = r.read_u32().map_err(crate::Error::from)?; - let response = Response { pos: r, n }; + let n = map_err!(u32::decode(r))?; + + let mut responses = Vec::with_capacity(n as usize); + for _ in 0..n { + responses.push(Bytes::decode(r).ok()) + } + let auth = handler - .auth_keyboard_interactive(user, submethods, Some(response)) + .auth_keyboard_interactive(user, submethods, Some(Response(&mut responses.into_iter()))) .await?; let resp = reply_userauth_info_response(until, auth_request, write, auth) .await .map_err(H::Error::from)?; Ok(resp) } else { - reject_auth_request(until, write, auth_request).await; + reject_auth_request(until, write, auth_request).await?; Ok(false) } } @@ -650,7 +649,7 @@ async fn reply_userauth_info_response( auth_request.methods = proceed_with_methods; } auth_request.partial_success = false; - reject_auth_request(until, write, auth_request).await; + reject_auth_request(until, write, auth_request).await?; Ok(false) } Auth::Partial { @@ -659,16 +658,17 @@ async fn reply_userauth_info_response( prompts, } => { push_packet!(write, { - write.push(msg::USERAUTH_INFO_REQUEST); - write.extend_ssh_string(name.as_bytes()); - write.extend_ssh_string(instructions.as_bytes()); - write.extend_ssh_string(b""); // lang, should be empty - write.push_u32_be(prompts.len() as u32); + msg::USERAUTH_INFO_REQUEST.encode(write)?; + name.as_ref().encode(write)?; + instructions.as_ref().encode(write)?; + "".encode(write)?; // lang, should be empty + prompts.len().encode(write)?; for &(ref a, b) in prompts.iter() { - write.extend_ssh_string(a.as_bytes()); - write.push(b as u8); + a.as_ref().encode(write)?; + (b as u8).encode(write)?; } - }); + Ok::<(), crate::Error>(()) + })?; Ok(false) } Auth::UnsupportedMethod => unreachable!(), @@ -676,26 +676,19 @@ async fn reply_userauth_info_response( } impl Session { - async fn server_read_authenticated( + async fn server_read_authenticated( &mut self, handler: &mut H, - buf: &[u8], + msg: u8, + r: &mut R, ) -> Result<(), H::Error> { - #[allow(clippy::indexing_slicing)] // length checked - { - trace!( - "authenticated buf = {:?}", - &buf[..std::cmp::min(buf.len(), 100)] - ); - } - match buf.first() { - Some(&msg::CHANNEL_OPEN) => self - .server_handle_channel_open(handler, buf) + match msg { + msg::CHANNEL_OPEN => self + .server_handle_channel_open(handler, r) .await .map(|_| ()), - Some(&msg::CHANNEL_CLOSE) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + msg::CHANNEL_CLOSE => { + let channel_num = map_err!(ChannelId::decode(r))?; if let Some(ref mut enc) = self.common.encrypted { enc.channels.remove(&channel_num); } @@ -703,30 +696,28 @@ impl Session { debug!("handler.channel_close {:?}", channel_num); handler.channel_close(channel_num, self).await } - Some(&msg::CHANNEL_EOF) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + msg::CHANNEL_EOF => { + let channel_num = map_err!(ChannelId::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { chan.send(ChannelMsg::Eof).unwrap_or(()) } debug!("handler.channel_eof {:?}", channel_num); handler.channel_eof(channel_num, self).await } - Some(&msg::CHANNEL_EXTENDED_DATA) | Some(&msg::CHANNEL_DATA) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); + msg::CHANNEL_EXTENDED_DATA | msg::CHANNEL_DATA => { + let channel_num = map_err!(ChannelId::decode(r))?; - let ext = if buf.first() == Some(&msg::CHANNEL_DATA) { + let ext = if msg == msg::CHANNEL_DATA { None } else { - Some(r.read_u32().map_err(crate::Error::from)?) + Some(map_err!(u32::decode(r))?) }; trace!("handler.data {:?} {:?}", ext, channel_num); - let data = r.read_string().map_err(crate::Error::from)?; + let data = map_err!(Bytes::decode(r))?; let target = self.target_window_size; if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, data, target) { + if enc.adjust_window_size(channel_num, &data, target)? { let window = handler.adjust_window(channel_num, self.target_window_size); if window > 0 { self.target_window_size = window @@ -738,26 +729,25 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { chan.send(ChannelMsg::ExtendedData { ext, - data: CryptoVec::from_slice(data), + data: CryptoVec::from_slice(&data), }) .unwrap_or(()) } - handler.extended_data(channel_num, ext, data, self).await + handler.extended_data(channel_num, ext, &data, self).await } else { if let Some(chan) = self.channels.get(&channel_num) { chan.send(ChannelMsg::Data { - data: CryptoVec::from_slice(data), + data: CryptoVec::from_slice(&data), }) .unwrap_or(()) } - handler.data(channel_num, data, self).await + handler.data(channel_num, &data, self).await } } - Some(&msg::CHANNEL_WINDOW_ADJUST) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let amount = r.read_u32().map_err(crate::Error::from)?; + msg::CHANNEL_WINDOW_ADJUST => { + let channel_num = map_err!(ChannelId::decode(r))?; + let amount = map_err!(u32::decode(r))?; let mut new_size = 0; if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel_num) { @@ -768,7 +758,7 @@ impl Session { } } if let Some(ref mut enc) = self.common.encrypted { - enc.flush_pending(channel_num); + enc.flush_pending(channel_num)?; } if let Some(chan) = self.channels.get(&channel_num) { *chan.window_size().lock().await = new_size; @@ -780,10 +770,9 @@ impl Session { handler.window_adjusted(channel_num, new_size, self).await } - Some(&msg::CHANNEL_OPEN_CONFIRMATION) => { + msg::CHANNEL_OPEN_CONFIRMATION => { debug!("channel_open_confirmation"); - let mut reader = buf.reader(1); - let msg = ChannelOpenConfirmation::parse(&mut reader)?; + let msg = map_err!(ChannelOpenConfirmation::decode(r))?; let local_id = ChannelId(msg.recipient_channel); if let Some(ref mut enc) = self.common.encrypted { @@ -818,29 +807,26 @@ impl Session { .await } - Some(&msg::CHANNEL_REQUEST) => { - let mut r = buf.reader(1); - let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); - let req_type = r.read_string().map_err(crate::Error::from)?; - let wants_reply = r.read_byte().map_err(crate::Error::from)?; + msg::CHANNEL_REQUEST => { + let channel_num = map_err!(ChannelId::decode(r))?; + let req_type = map_err!(String::decode(r))?; + let wants_reply = map_err!(u8::decode(r))?; if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel_num) { channel.wants_reply = wants_reply != 0; } } - match req_type { - b"pty-req" => { - let term = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let col_width = r.read_u32().map_err(crate::Error::from)?; - let row_height = r.read_u32().map_err(crate::Error::from)?; - let pix_width = r.read_u32().map_err(crate::Error::from)?; - let pix_height = r.read_u32().map_err(crate::Error::from)?; + match req_type.as_str() { + "pty-req" => { + let term = map_err!(String::decode(r))?; + let col_width = map_err!(u32::decode(r))?; + let row_height = map_err!(u32::decode(r))?; + let pix_width = map_err!(u32::decode(r))?; + let pix_height = map_err!(u32::decode(r))?; let mut modes = [(Pty::TTY_OP_END, 0); 130]; let mut i = 0; { - let mode_string = r.read_string().map_err(crate::Error::from)?; + let mode_string = map_err!(Bytes::decode(r))?; while 5 * i < mode_string.len() { #[allow(clippy::indexing_slicing)] // length checked let code = mode_string[5 * i]; @@ -867,7 +853,7 @@ impl Session { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::RequestPty { want_reply: true, - term: term.into(), + term: term.clone(), col_width, row_height, pix_width, @@ -881,7 +867,7 @@ impl Session { handler .pty_request( channel_num, - term, + &term, col_width, row_height, pix_width, @@ -891,22 +877,18 @@ impl Session { ) .await } - b"x11-req" => { - let single_connection = r.read_byte().map_err(crate::Error::from)? != 0; - let x11_auth_protocol = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let x11_auth_cookie = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let x11_screen_number = r.read_u32().map_err(crate::Error::from)?; + "x11-req" => { + let single_connection = map_err!(u8::decode(r))? != 0; + let x11_auth_protocol = map_err!(String::decode(r))?; + let x11_auth_cookie = map_err!(String::decode(r))?; + let x11_screen_number = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::RequestX11 { want_reply: true, single_connection, - x11_authentication_cookie: x11_auth_cookie.into(), - x11_authentication_protocol: x11_auth_protocol.into(), + x11_authentication_cookie: x11_auth_cookie.clone(), + x11_authentication_protocol: x11_auth_protocol.clone(), x11_screen_number, }); } @@ -915,42 +897,38 @@ impl Session { .x11_request( channel_num, single_connection, - x11_auth_protocol, - x11_auth_cookie, + &x11_auth_protocol, + &x11_auth_cookie, x11_screen_number, self, ) .await } - b"env" => { - let env_variable = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let env_value = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + "env" => { + let env_variable = map_err!(String::decode(r))?; + let env_value = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::SetEnv { want_reply: true, - variable_name: env_variable.into(), - variable_value: env_value.into(), + variable_name: env_variable.clone(), + variable_value: env_value.clone(), }); } debug!("handler.env_request {:?}", channel_num); handler - .env_request(channel_num, env_variable, env_value, self) + .env_request(channel_num, &env_variable, &env_value, self) .await } - b"shell" => { + "shell" => { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::RequestShell { want_reply: true }); } debug!("handler.shell_request {:?}", channel_num); handler.shell_request(channel_num, self).await } - b"auth-agent-req@openssh.com" => { + "auth-agent-req@openssh.com" => { if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::AgentForward { want_reply: true }); } @@ -964,36 +942,34 @@ impl Session { } Ok(()) } - b"exec" => { - let req = r.read_string().map_err(crate::Error::from)?; + "exec" => { + let req = map_err!(Bytes::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::Exec { want_reply: true, - command: req.into(), + command: req.to_vec(), }); } debug!("handler.exec_request {:?}", channel_num); - handler.exec_request(channel_num, req, self).await + handler.exec_request(channel_num, &req, self).await } - b"subsystem" => { - let name = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + "subsystem" => { + let name = map_err!(String::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::RequestSubsystem { want_reply: true, - name: name.into(), + name: name.clone(), }); } debug!("handler.subsystem_request {:?}", channel_num); - handler.subsystem_request(channel_num, name, self).await + handler.subsystem_request(channel_num, &name, self).await } - b"window-change" => { - let col_width = r.read_u32().map_err(crate::Error::from)?; - let row_height = r.read_u32().map_err(crate::Error::from)?; - let pix_width = r.read_u32().map_err(crate::Error::from)?; - let pix_height = r.read_u32().map_err(crate::Error::from)?; + "window-change" => { + let col_width = map_err!(u32::decode(r))?; + let row_height = map_err!(u32::decode(r))?; + let pix_width = map_err!(u32::decode(r))?; + let pix_height = map_err!(u32::decode(r))?; if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::WindowChange { @@ -1016,8 +992,8 @@ impl Session { ) .await } - b"signal" => { - let signal = Sig::from_name(r.read_string().map_err(crate::Error::from)?)?; + "signal" => { + let signal = Sig::from_name(&map_err!(String::decode(r))?); if let Some(chan) = self.channels.get(&channel_num) { chan.send(ChannelMsg::Signal { signal: signal.clone(), @@ -1028,33 +1004,30 @@ impl Session { handler.signal(channel_num, signal, self).await } x => { - warn!("unknown channel request {}", String::from_utf8_lossy(x)); - self.channel_failure(channel_num); + warn!("unknown channel request {x}"); + self.channel_failure(channel_num)?; Ok(()) } } } - Some(&msg::GLOBAL_REQUEST) => { - let mut r = buf.reader(1); - let req_type = r.read_string().map_err(crate::Error::from)?; - self.common.wants_reply = r.read_byte().map_err(crate::Error::from)? != 0; - match req_type { - b"tcpip-forward" => { - let address = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let port = r.read_u32().map_err(crate::Error::from)?; + msg::GLOBAL_REQUEST => { + let req_type = map_err!(String::decode(r))?; + self.common.wants_reply = map_err!(u8::decode(r))? != 0; + match req_type.as_str() { + "tcpip-forward" => { + let address = map_err!(String::decode(r))?; + let port = map_err!(u32::decode(r))?; debug!("handler.tcpip_forward {:?} {:?}", address, port); let mut returned_port = port; let result = handler - .tcpip_forward(address, &mut returned_port, self) + .tcpip_forward(&address, &mut returned_port, self) .await?; if let Some(ref mut enc) = self.common.encrypted { if result { push_packet!(enc.write, { enc.write.push(msg::REQUEST_SUCCESS); if self.common.wants_reply && port == 0 && returned_port != 0 { - enc.write.push_u32_be(returned_port); + map_err!(returned_port.encode(&mut enc.write))?; } }) } else { @@ -1063,13 +1036,11 @@ impl Session { } Ok(()) } - b"cancel-tcpip-forward" => { - let address = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let port = r.read_u32().map_err(crate::Error::from)?; + "cancel-tcpip-forward" => { + let address = map_err!(String::decode(r))?; + let port = map_err!(u32::decode(r))?; debug!("handler.cancel_tcpip_forward {:?} {:?}", address, port); - let result = handler.cancel_tcpip_forward(address, port, self).await?; + let result = handler.cancel_tcpip_forward(&address, port, self).await?; if let Some(ref mut enc) = self.common.encrypted { if result { push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) @@ -1079,13 +1050,11 @@ impl Session { } Ok(()) } - b"streamlocal-forward@openssh.com" => { - let server_socket_path = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + "streamlocal-forward@openssh.com" => { + let server_socket_path = map_err!(String::decode(r))?; debug!("handler.streamlocal_forward {:?}", server_socket_path); let result = handler - .streamlocal_forward(server_socket_path, self) + .streamlocal_forward(&server_socket_path, self) .await?; if let Some(ref mut enc) = self.common.encrypted { if result { @@ -1096,13 +1065,11 @@ impl Session { } Ok(()) } - b"cancel-streamlocal-forward@openssh.com" => { - let socket_path = - std::str::from_utf8(r.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + "cancel-streamlocal-forward@openssh.com" => { + let socket_path = map_err!(String::decode(r))?; debug!("handler.cancel_streamlocal_forward {:?}", socket_path); let result = handler - .cancel_streamlocal_forward(socket_path, self) + .cancel_streamlocal_forward(&socket_path, self) .await?; if let Some(ref mut enc) = self.common.encrypted { if result { @@ -1123,19 +1090,13 @@ impl Session { } } } - Some(&msg::CHANNEL_OPEN_FAILURE) => { + msg::CHANNEL_OPEN_FAILURE => { debug!("channel_open_failure"); - let mut buf_pos = buf.reader(1); - let channel_num = ChannelId(buf_pos.read_u32().map_err(crate::Error::from)?); - let reason = - ChannelOpenFailure::from_u32(buf_pos.read_u32().map_err(crate::Error::from)?) - .unwrap_or(ChannelOpenFailure::Unknown); - let description = - std::str::from_utf8(buf_pos.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; - let language_tag = - std::str::from_utf8(buf_pos.read_string().map_err(crate::Error::from)?) - .map_err(crate::Error::from)?; + let channel_num = map_err!(ChannelId::decode(r))?; + let reason = ChannelOpenFailure::from_u32(map_err!(u32::decode(r))?) + .unwrap_or(ChannelOpenFailure::Unknown); + let description = map_err!(String::decode(r))?; + let language_tag = map_err!(String::decode(r))?; trace!("Channel open failure description: {description}"); trace!("Channel open failure language tag: {language_tag}"); @@ -1152,19 +1113,18 @@ impl Session { Ok(()) } - Some(&msg::REQUEST_SUCCESS) => { + msg::REQUEST_SUCCESS => { trace!("Global Request Success"); match self.open_global_requests.pop_front() { Some(GlobalRequestResponse::Keepalive) => { // ignore keepalives } Some(GlobalRequestResponse::TcpIpForward(return_channel)) => { - let result = if buf.len() == 1 { + let result = if r.is_finished() { // If a specific port was requested, the reply has no data Some(0) } else { - let mut r = buf.reader(1); - match r.read_u32() { + match u32::decode(r) { Ok(port) => Some(port), Err(e) => { error!("Error parsing port for TcpIpForward request: {e:?}"); @@ -1183,7 +1143,7 @@ impl Session { } Ok(()) } - Some(&msg::REQUEST_FAILURE) => { + msg::REQUEST_FAILURE => { trace!("global request failure"); match self.open_global_requests.pop_front() { Some(GlobalRequestResponse::Keepalive) => { @@ -1208,13 +1168,12 @@ impl Session { } } - async fn server_handle_channel_open( + async fn server_handle_channel_open( &mut self, handler: &mut H, - buf: &[u8], + r: &mut R, ) -> Result { - let mut r = buf.reader(1); - let msg = OpenChannelMessage::parse(&mut r)?; + let msg = OpenChannelMessage::parse(r)?; let sender_channel = if let Some(ref mut enc) = self.common.encrypted { enc.new_channel_id() @@ -1250,7 +1209,7 @@ impl Session { let mut result = handler.channel_open_session(channel, self).await; if let Ok(allowed) = &mut result { self.channels.insert(sender_channel, reference); - self.finalize_channel_open(&msg, channel_params, *allowed); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } @@ -1263,7 +1222,7 @@ impl Session { .await; if let Ok(allowed) = &mut result { self.channels.insert(sender_channel, reference); - self.finalize_channel_open(&msg, channel_params, *allowed); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } @@ -1280,7 +1239,7 @@ impl Session { .await; if let Ok(allowed) = &mut result { self.channels.insert(sender_channel, reference); - self.finalize_channel_open(&msg, channel_params, *allowed); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } @@ -1297,7 +1256,7 @@ impl Session { .await; if let Ok(allowed) = &mut result { self.channels.insert(sender_channel, reference); - self.finalize_channel_open(&msg, channel_params, *allowed); + self.finalize_channel_open(&msg, channel_params, *allowed)?; } result } @@ -1307,7 +1266,7 @@ impl Session { &mut enc.write, msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED, b"Unsupported channel type", - ); + )?; } Ok(false) } @@ -1317,14 +1276,14 @@ impl Session { &mut enc.write, msg::SSH_OPEN_ADMINISTRATIVELY_PROHIBITED, b"Unsupported channel type", - ); + )?; } Ok(false) } ChannelType::Unknown { typ } => { - debug!("unknown channel type: {}", String::from_utf8_lossy(typ)); + debug!("unknown channel type: {typ}"); if let Some(ref mut enc) = self.common.encrypted { - msg.unknown_type(&mut enc.write); + msg.unknown_type(&mut enc.write)?; } Ok(false) } @@ -1336,7 +1295,7 @@ impl Session { open: &OpenChannelMessage, channel: ChannelParams, allowed: bool, - ) { + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if allowed { open.confirm( @@ -1344,15 +1303,16 @@ impl Session { channel.sender_channel.0, channel.sender_window_size, channel.sender_maximum_packet_size, - ); + )?; enc.channels.insert(channel.sender_channel, channel); } else { open.fail( &mut enc.write, SSH_OPEN_ADMINISTRATIVELY_PROHIBITED, b"Rejected", - ); + )?; } } + Ok(()) } } diff --git a/russh/src/server/kex.rs b/russh/src/server/kex.rs index a278b2e4..85a4a5d0 100644 --- a/russh/src/server/kex.rs +++ b/russh/src/server/kex.rs @@ -1,12 +1,13 @@ use std::cell::RefCell; +use std::ops::DerefMut; use log::debug; -use russh_keys::add_signature; +use russh_keys::helpers::EncodedExt; +use ssh_encoding::Encode; use super::*; use crate::cipher::SealingKey; use crate::kex::KEXES; -use crate::keys::encoding::{Encoding, Reader}; use crate::negotiation::Select; use crate::{msg, negotiation}; @@ -87,9 +88,13 @@ impl KexDh { Ok(Kex::Dh(self)) } else { // Else, process it. - assert!(buf.first() == Some(&msg::KEX_ECDH_INIT)); - let mut r = buf.reader(1); - self.exchange.client_ephemeral.extend(r.read_string()?); + let Some((&msg::KEX_ECDH_INIT, mut r)) = buf.split_first() else { + return Err(Error::Inconsistent); + }; + + self.exchange + .client_ephemeral + .extend(&Bytes::decode(&mut r)?); let mut kex = KEXES.get(&self.names.kex).ok_or(Error::UnknownAlgo)?.make(); @@ -111,12 +116,10 @@ impl KexDh { debug!("server kexdhdone.exchange = {:?}", kexdhdone.exchange); let mut pubkey_vec = CryptoVec::new(); - pubkey_vec.extend_ssh_string( - config.keys[kexdhdone.key] - .public_key() - .to_bytes()? - .as_slice(), - ); + config.keys[kexdhdone.key] + .public_key() + .to_bytes()? + .encode(&mut pubkey_vec)?; let hash = kexdhdone.kex.compute_exchange_hash( &pubkey_vec, @@ -126,20 +129,24 @@ impl KexDh { debug!("exchange hash: {:?}", hash); buffer.clear(); buffer.push(msg::KEX_ECDH_REPLY); - buffer.extend_ssh_string( - config.keys[kexdhdone.key] - .public_key() - .to_bytes()? - .as_slice(), - ); + config.keys[kexdhdone.key] + .public_key() + .to_bytes()? + .encode(buffer.deref_mut())?; + // Server ephemeral - buffer.extend_ssh_string(&kexdhdone.exchange.server_ephemeral); + kexdhdone + .exchange + .server_ephemeral + .encode(buffer.deref_mut())?; + // Hash signature debug!("signing with key {:?}", kexdhdone.key); debug!("hash: {:?}", hash); debug!("key: {:?}", config.keys[kexdhdone.key]); - add_signature(&config.keys[kexdhdone.key], &hash, &mut buffer)?; + let signature = signature::Signer::try_sign(&config.keys[kexdhdone.key], &hash)?; + signature.encoded()?.encode(&mut *buffer)?; cipher.write(&buffer, write_buffer); cipher.write(&[msg::NEWKEYS], write_buffer); diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index fb8d333d..be7e17b2 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -36,8 +36,10 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; +use bytes::Bytes; use futures::future::Future; use log::{debug, error}; +use russh_keys::map_err; use russh_util::runtime::JoinHandle; use ssh_key::{Certificate, PrivateKey}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; @@ -121,21 +123,12 @@ impl Default for Config { /// A client's response in a challenge-response authentication. /// /// You should iterate it to get `&[u8]` response slices. -#[derive(Debug)] -pub struct Response<'a> { - pos: russh_keys::encoding::Position<'a>, - n: u32, -} +pub struct Response<'a>(&'a mut (dyn Iterator> + Send)); -impl<'a> Iterator for Response<'a> { - type Item = &'a [u8]; +impl Iterator for Response<'_> { + type Item = Bytes; fn next(&mut self) -> Option { - if self.n == 0 { - None - } else { - self.n -= 1; - self.pos.read_string().ok() - } + self.0.next().flatten() } } @@ -695,10 +688,7 @@ where // Writing SSH id. let mut write_buffer = SSHBuffer::new(); write_buffer.send_ssh_id(&config.as_ref().server_id); - stream - .write_all(&write_buffer.buffer[..]) - .await - .map_err(crate::Error::from)?; + map_err!(stream.write_all(&write_buffer.buffer[..]).await)?; // Reading SSH id and allocating a session. let mut stream = SshRead::new(stream); @@ -844,7 +834,7 @@ async fn reply( }, newkeys, ); - session.maybe_send_ext_info(); + session.maybe_send_ext_info()?; if session.common.strict_kex { *seqn = Wrapping(0); } diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 26d8da64..763ac180 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -3,6 +3,8 @@ use std::sync::Arc; use log::debug; use negotiation::parse_kex_algo_list; +use russh_keys::helpers::NameList; +use russh_keys::map_err; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver}; use tokio::sync::{oneshot, Mutex}; @@ -10,7 +12,6 @@ use tokio::sync::{oneshot, Mutex}; use super::*; use crate::channels::{Channel, ChannelMsg, ChannelRef}; use crate::kex::EXTENSION_SUPPORT_AS_CLIENT; -use crate::keys::encoding::{Encoding, Reader}; use crate::msg; /// A connected server session. This type is unique to a client. @@ -435,10 +436,7 @@ impl Session { R: AsyncRead + AsyncWrite + Unpin + Send + 'static, { self.flush()?; - stream - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; + map_err!(stream.write_all(&self.common.write_buffer.buffer).await)?; self.common.write_buffer.buffer.clear(); let (stream_read, mut stream_write) = stream.split(); @@ -517,7 +515,7 @@ impl Session { } self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); sent_keepalive = true; - self.keepalive_request(); + self.keepalive_request()?; } () = &mut inactivity_timer => { debug!("timeout"); @@ -526,31 +524,31 @@ impl Session { msg = self.receiver.recv(), if !self.is_rekeying() => { match msg { Some(Msg::Channel(id, ChannelMsg::Data { data })) => { - self.data(id, data); + self.data(id, data)?; } Some(Msg::Channel(id, ChannelMsg::ExtendedData { ext, data })) => { - self.extended_data(id, ext, data); + self.extended_data(id, ext, data)?; } Some(Msg::Channel(id, ChannelMsg::Eof)) => { - self.eof(id); + self.eof(id)?; } Some(Msg::Channel(id, ChannelMsg::Close)) => { - self.close(id); + self.close(id)?; } Some(Msg::Channel(id, ChannelMsg::Success)) => { - self.channel_success(id); + self.channel_success(id)?; } Some(Msg::Channel(id, ChannelMsg::Failure)) => { - self.channel_failure(id); + self.channel_failure(id)?; } Some(Msg::Channel(id, ChannelMsg::XonXoff { client_can_do })) => { - self.xon_xoff_request(id, client_can_do); + self.xon_xoff_request(id, client_can_do)?; } Some(Msg::Channel(id, ChannelMsg::ExitStatus { exit_status })) => { - self.exit_status_request(id, exit_status); + self.exit_status_request(id, exit_status)?; } Some(Msg::Channel(id, ChannelMsg::ExitSignal { signal_name, core_dumped, error_message, lang_tag })) => { - self.exit_signal_request(id, signal_name, core_dumped, &error_message, &lang_tag); + self.exit_signal_request(id, signal_name, core_dumped, &error_message, &lang_tag)?; } Some(Msg::Channel(id, ChannelMsg::WindowAdjusted { new_size })) => { debug!("window adjusted to {:?} for channel {:?}", new_size, id); @@ -580,13 +578,13 @@ impl Session { self.channels.insert(id, channel_ref); } Some(Msg::TcpIpForward { address, port, reply_channel }) => { - self.tcpip_forward(&address, port, reply_channel); + self.tcpip_forward(&address, port, reply_channel)?; } Some(Msg::CancelTcpIpForward { address, port, reply_channel }) => { - self.cancel_tcpip_forward(&address, port, reply_channel); + self.cancel_tcpip_forward(&address, port, reply_channel)?; } Some(Msg::Disconnect {reason, description, language_tag}) => { - self.common.disconnect(reason, &description, &language_tag); + self.common.disconnect(reason, &description, &language_tag)?; } Some(_) => { // should be unreachable, since the receiver only gets @@ -600,10 +598,11 @@ impl Session { } } self.flush()?; - stream_write - .write_all(&self.common.write_buffer.buffer) - .await - .map_err(crate::Error::from)?; + map_err!( + stream_write + .write_all(&self.common.write_buffer.buffer) + .await + )?; self.common.write_buffer.buffer.clear(); if self.common.received_data { @@ -633,7 +632,7 @@ impl Session { } debug!("disconnected"); // Shutdown - stream_write.shutdown().await.map_err(crate::Error::from)?; + map_err!(stream_write.shutdown().await)?; loop { if let Some((stream_read, buffer, opening_cipher)) = is_reading.take() { reading.set(start_reading(stream_read, buffer, opening_cipher)); @@ -706,11 +705,11 @@ impl Session { Ok(()) } - pub fn flush_pending(&mut self, channel: ChannelId) -> usize { + pub fn flush_pending(&mut self, channel: ChannelId) -> Result { if let Some(ref mut enc) = self.common.encrypted { enc.flush_pending(channel) } else { - 0 + Ok(0) } } @@ -736,8 +735,13 @@ impl Session { } /// Sends a disconnect message. - pub fn disconnect(&mut self, reason: Disconnect, description: &str, language_tag: &str) { - self.common.disconnect(reason, description, language_tag); + pub fn disconnect( + &mut self, + reason: Disconnect, + description: &str, + language_tag: &str, + ) -> Result<(), Error> { + self.common.disconnect(reason, description, language_tag) } /// Send a "success" reply to a /global/ request (requests without @@ -764,7 +768,7 @@ impl Session { /// Send a "success" reply to a channel request. Always call this /// function if the request was successful (it checks whether the /// client expects an answer). - pub fn channel_success(&mut self, channel: ChannelId) { + pub fn channel_success(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel) { assert!(channel.confirmed); @@ -772,16 +776,17 @@ impl Session { channel.wants_reply = false; debug!("channel_success {:?}", channel); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_SUCCESS); - enc.write.push_u32_be(channel.recipient_channel); + msg::CHANNEL_SUCCESS.encode(&mut enc.write)?; + channel.recipient_channel.encode(&mut enc.write)?; }) } } } + Ok(()) } /// Send a "failure" reply to a global request. - pub fn channel_failure(&mut self, channel: ChannelId) { + pub fn channel_failure(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get_mut(&channel) { assert!(channel.confirmed); @@ -789,11 +794,12 @@ impl Session { channel.wants_reply = false; push_packet!(enc.write, { enc.write.push(msg::CHANNEL_FAILURE); - enc.write.push_u32_be(channel.recipient_channel); + channel.recipient_channel.encode(&mut enc.write)?; }) } } } + Ok(()) } /// Send a "failure" reply to a request to open a channel open. @@ -803,26 +809,27 @@ impl Session { reason: ChannelOpenFailure, description: &str, language: &str, - ) { + ) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.common.encrypted { push_packet!(enc.write, { enc.write.push(msg::CHANNEL_OPEN_FAILURE); - enc.write.push_u32_be(channel.0); - enc.write.push_u32_be(reason as u32); - enc.write.extend_ssh_string(description.as_bytes()); - enc.write.extend_ssh_string(language.as_bytes()); + channel.encode(&mut enc.write)?; + (reason as u32).encode(&mut enc.write)?; + description.encode(&mut enc.write)?; + language.encode(&mut enc.write)?; }) } + Ok(()) } /// Close a channel. - pub fn close(&mut self, channel: ChannelId) { - self.common.byte(channel, msg::CHANNEL_CLOSE); + pub fn close(&mut self, channel: ChannelId) -> Result<(), Error> { + self.common.byte(channel, msg::CHANNEL_CLOSE) } /// Send EOF to a channel - pub fn eof(&mut self, channel: ChannelId) { - self.common.byte(channel, msg::CHANNEL_EOF); + pub fn eof(&mut self, channel: ChannelId) -> Result<(), Error> { + self.common.byte(channel, msg::CHANNEL_EOF) } /// Send data to a channel. On session channels, `extended` can be @@ -831,7 +838,7 @@ impl Session { /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned. - pub fn data(&mut self, channel: ChannelId, data: CryptoVec) { + pub fn data(&mut self, channel: ChannelId, data: CryptoVec) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { enc.data(channel, data) } else { @@ -845,7 +852,12 @@ impl Session { /// /// The number of bytes added to the "sending pipeline" (to be /// processed by the event loop) is returned. - pub fn extended_data(&mut self, channel: ChannelId, extended: u32, data: CryptoVec) { + pub fn extended_data( + &mut self, + channel: ChannelId, + extended: u32, + data: CryptoVec, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { enc.extended_data(channel, extended, data) } else { @@ -856,51 +868,62 @@ impl Session { /// Inform the client of whether they may perform /// control-S/control-Q flow control. See /// [RFC4254](https://tools.ietf.org/html/rfc4254#section-6.8). - pub fn xon_xoff_request(&mut self, channel: ChannelId, client_can_do: bool) { + pub fn xon_xoff_request( + &mut self, + channel: ChannelId, + client_can_do: bool, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"xon-xoff"); - enc.write.push(0); - enc.write.push(client_can_do as u8); + channel.recipient_channel.encode(&mut enc.write)?; + "xon-xoff".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + (client_can_do as u8).encode(&mut enc.write)?; }) } } + Ok(()) } /// Ping the client to verify there is still connectivity. - pub fn keepalive_request(&mut self) { + pub fn keepalive_request(&mut self) -> Result<(), Error> { let want_reply = u8::from(true); if let Some(ref mut enc) = self.common.encrypted { self.open_global_requests .push_back(GlobalRequestResponse::Keepalive); push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"keepalive@openssh.com"); - enc.write.push(want_reply); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "keepalive@openssh.com".encode(&mut enc.write)?; + want_reply.encode(&mut enc.write)?; }) } + Ok(()) } /// Send the exit status of a program. - pub fn exit_status_request(&mut self, channel: ChannelId, exit_status: u32) { + pub fn exit_status_request( + &mut self, + channel: ChannelId, + exit_status: u32, + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"exit-status"); - enc.write.push(0); - enc.write.push_u32_be(exit_status) + channel.recipient_channel.encode(&mut enc.write)?; + "exit-status".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + exit_status.encode(&mut enc.write)?; }) } } + Ok(()) } /// If the program was killed by a signal, send the details about the signal to the client. @@ -911,28 +934,29 @@ impl Session { core_dumped: bool, error_message: &str, language_tag: &str, - ) { + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { if let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { - enc.write.push(msg::CHANNEL_REQUEST); - - enc.write.push_u32_be(channel.recipient_channel); - enc.write.extend_ssh_string(b"exit-signal"); - enc.write.push(0); - enc.write.extend_ssh_string(signal.name().as_bytes()); - enc.write.push(core_dumped as u8); - enc.write.extend_ssh_string(error_message.as_bytes()); - enc.write.extend_ssh_string(language_tag.as_bytes()); + msg::CHANNEL_REQUEST.encode(&mut enc.write)?; + + channel.recipient_channel.encode(&mut enc.write)?; + "exit-signal".encode(&mut enc.write)?; + 0u8.encode(&mut enc.write)?; + signal.name().encode(&mut enc.write)?; + (core_dumped as u8).encode(&mut enc.write)?; + error_message.encode(&mut enc.write)?; + language_tag.encode(&mut enc.write)?; }) } } + Ok(()) } /// Opens a new session channel on the client. pub fn channel_open_session(&mut self) -> Result { - self.channel_open_generic(b"session", |_| ()) + self.channel_open_generic(b"session", |_| Ok(())) } /// Opens a direct TCP/IP channel on the client. @@ -944,10 +968,11 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"direct-tcpip", |write| { - write.extend_ssh_string(host_to_connect.as_bytes()); - write.push_u32_be(port_to_connect); // sender channel id. - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + host_to_connect.encode(write)?; + port_to_connect.encode(write)?; // sender channel id. + originator_address.encode(write)?; + originator_port.encode(write)?; // sender channel id. + Ok(()) }) } @@ -964,10 +989,11 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"forwarded-tcpip", |write| { - write.extend_ssh_string(connected_address.as_bytes()); - write.push_u32_be(connected_port); // sender channel id. - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); // sender channel id. + connected_address.encode(write)?; + connected_port.encode(write)?; // sender channel id. + originator_address.encode(write)?; + originator_port.encode(write)?; // sender channel id. + Ok(()) }) } @@ -976,8 +1002,9 @@ impl Session { socket_path: &str, ) -> Result { self.channel_open_generic(b"forwarded-streamlocal@openssh.com", |write| { - write.extend_ssh_string(socket_path.as_bytes()); - write.extend_ssh_string(b""); + socket_path.encode(write)?; + "".encode(write)?; + Ok(()) }) } @@ -990,19 +1017,20 @@ impl Session { originator_port: u32, ) -> Result { self.channel_open_generic(b"x11", |write| { - write.extend_ssh_string(originator_address.as_bytes()); - write.push_u32_be(originator_port); + originator_address.encode(write)?; + originator_port.encode(write)?; + Ok(()) }) } /// Opens a new agent channel on the client. pub fn channel_open_agent(&mut self) -> Result { - self.channel_open_generic(b"auth-agent@openssh.com", |_| ()) + self.channel_open_generic(b"auth-agent@openssh.com", |_| Ok(())) } fn channel_open_generic(&mut self, kind: &[u8], write_suffix: F) -> Result where - F: FnOnce(&mut CryptoVec), + F: FnOnce(&mut CryptoVec) -> Result<(), Error>, { let result = if let Some(ref mut enc) = self.common.encrypted { if !matches!( @@ -1018,20 +1046,26 @@ impl Session { ); push_packet!(enc.write, { enc.write.push(msg::CHANNEL_OPEN); - enc.write.extend_ssh_string(kind); + kind.encode(&mut enc.write)?; // sender channel id. - enc.write.push_u32_be(sender_channel.0); + sender_channel.encode(&mut enc.write)?; // window. - enc.write - .push_u32_be(self.common.config.as_ref().window_size); + self.common + .config + .as_ref() + .window_size + .encode(&mut enc.write)?; // max packet size. - enc.write - .push_u32_be(self.common.config.as_ref().maximum_packet_size); + self.common + .config + .as_ref() + .maximum_packet_size + .encode(&mut enc.write)?; - write_suffix(&mut enc.write); + write_suffix(&mut enc.write)?; }); sender_channel } else { @@ -1048,7 +1082,7 @@ impl Session { address: &str, port: u32, reply_channel: Option>>, - ) { + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { let want_reply = reply_channel.is_some(); if let Some(reply_channel) = reply_channel { @@ -1058,12 +1092,13 @@ impl Session { } push_packet!(enc.write, { enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"tcpip-forward"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + "tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; }); } + Ok(()) } /// Cancels a previously tcpip_forward request. @@ -1072,7 +1107,7 @@ impl Session { address: &str, port: u32, reply_channel: Option>, - ) { + ) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { let want_reply = reply_channel.is_some(); if let Some(reply_channel) = reply_channel { @@ -1081,13 +1116,14 @@ impl Session { ); } push_packet!(enc.write, { - enc.write.push(msg::GLOBAL_REQUEST); - enc.write.extend_ssh_string(b"cancel-tcpip-forward"); - enc.write.push(want_reply as u8); - enc.write.extend_ssh_string(address.as_bytes()); - enc.write.push_u32_be(port); + msg::GLOBAL_REQUEST.encode(&mut enc.write)?; + "cancel-tcpip-forward".encode(&mut enc.write)?; + (want_reply as u8).encode(&mut enc.write)?; + address.encode(&mut enc.write)?; + port.encode(&mut enc.write)?; }); } + Ok(()) } /// Returns the SSH ID (Protocol Version + Software Version) the client sent when connecting @@ -1103,17 +1139,19 @@ impl Session { &self.common.remote_sshid } - pub(crate) fn maybe_send_ext_info(&mut self) { + pub(crate) fn maybe_send_ext_info(&mut self) -> Result<(), Error> { if let Some(ref mut enc) = self.common.encrypted { // If client sent a ext-info-c message in the kex list, it supports RFC 8308 extension negotiation. let mut key_extension_client = false; if let Some(e) = &enc.exchange { - let mut r = e.client_kex_init.as_ref().reader(17); - if let Ok(kex_string) = r.read_string() { + let Some(mut r) = &e.client_kex_init.as_ref().get(17..) else { + return Ok(()); + }; + if let Ok(kex_string) = String::decode(&mut r) { use super::negotiation::Select; key_extension_client = super::negotiation::Server::select( &[EXTENSION_SUPPORT_AS_CLIENT], - &parse_kex_algo_list(kex_string), + &parse_kex_algo_list(&kex_string), AlgorithmKind::Kex, ) .is_ok(); @@ -1122,16 +1160,26 @@ impl Session { if !key_extension_client { debug!("RFC 8308 Extension Negotiation not supported by client"); - return; + return Ok(()); } push_packet!(enc.write, { - enc.write.push(msg::EXT_INFO); - enc.write.push_u32_be(1); - enc.write.extend_ssh_string(b"server-sig-algs"); - enc.write - .extend_list(self.common.config.preferred.key.iter()); + msg::EXT_INFO.encode(&mut enc.write)?; + 1u32.encode(&mut enc.write)?; + "server-sig-algs".encode(&mut enc.write)?; + + NameList( + self.common + .config + .preferred + .key + .iter() + .map(|x| x.to_string()) + .collect(), + ) + .encode(&mut enc.write)?; }); } + Ok(()) } } diff --git a/russh/src/session.rs b/russh/src/session.rs index f901aef4..fbeea036 100644 --- a/russh/src/session.rs +++ b/russh/src/session.rs @@ -19,11 +19,11 @@ use std::num::Wrapping; use byteorder::{BigEndian, ByteOrder}; use log::{debug, trace}; +use ssh_encoding::Encode; use tokio::sync::oneshot; use crate::cipher::SealingKey; use crate::kex::KexAlgorithm; -use crate::keys::encoding::Encoding; use crate::sshbuffer::SSHBuffer; use crate::{ auth, cipher, mac, msg, negotiation, ChannelId, ChannelParams, CryptoVec, Disconnect, Limits, @@ -138,31 +138,39 @@ impl CommonSession { } /// Send a disconnect message. - pub fn disconnect(&mut self, reason: Disconnect, description: &str, language_tag: &str) { + pub fn disconnect( + &mut self, + reason: Disconnect, + description: &str, + language_tag: &str, + ) -> Result<(), crate::Error> { let disconnect = |buf: &mut CryptoVec| { push_packet!(buf, { - buf.push(msg::DISCONNECT); - buf.push_u32_be(reason as u32); - buf.extend_ssh_string(description.as_bytes()); - buf.extend_ssh_string(language_tag.as_bytes()); + msg::DISCONNECT.encode(buf)?; + (reason as u32).encode(buf)?; + description.encode(buf)?; + language_tag.encode(buf)?; }); + Ok(()) }; if !self.disconnected { self.disconnected = true; - if let Some(ref mut enc) = self.encrypted { + return if let Some(ref mut enc) = self.encrypted { disconnect(&mut enc.write) } else { disconnect(&mut self.write_buffer.buffer) - } + }; } + Ok(()) } /// Send a single byte message onto the channel. #[cfg(not(target_arch = "wasm32"))] - pub fn byte(&mut self, channel: ChannelId, msg: u8) { + pub fn byte(&mut self, channel: ChannelId, msg: u8) -> Result<(), crate::Error> { if let Some(ref mut enc) = self.encrypted { - enc.byte(channel, msg) + enc.byte(channel, msg)? } + Ok(()) } pub(crate) fn maybe_reset_seqn(&mut self) { @@ -173,13 +181,14 @@ impl CommonSession { } impl Encrypted { - pub fn byte(&mut self, channel: ChannelId, msg: u8) { + pub fn byte(&mut self, channel: ChannelId, msg: u8) -> Result<(), crate::Error> { if let Some(channel) = self.channels.get(&channel) { push_packet!(self.write, { self.write.push(msg); - self.write.push_u32_be(channel.recipient_channel); + channel.recipient_channel.encode(&mut self.write)?; }); } + Ok(()) } /* @@ -189,21 +198,23 @@ impl Encrypted { } */ - pub fn eof(&mut self, channel: ChannelId) { + pub fn eof(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(channel) = self.has_pending_data_mut(channel) { channel.pending_eof = true; } else { - self.byte(channel, msg::CHANNEL_EOF); + self.byte(channel, msg::CHANNEL_EOF)?; } + Ok(()) } - pub fn close(&mut self, channel: ChannelId) { + pub fn close(&mut self, channel: ChannelId) -> Result<(), crate::Error> { if let Some(channel) = self.has_pending_data_mut(channel) { channel.pending_close = true; } else { - self.byte(channel, msg::CHANNEL_CLOSE); + self.byte(channel, msg::CHANNEL_CLOSE)?; self.channels.remove(&channel); } + Ok(()) } pub fn sender_window_size(&self, channel: ChannelId) -> usize { @@ -214,7 +225,12 @@ impl Encrypted { } } - pub fn adjust_window_size(&mut self, channel: ChannelId, data: &[u8], target: u32) -> bool { + pub fn adjust_window_size( + &mut self, + channel: ChannelId, + data: &[u8], + target: u32, + ) -> Result { if let Some(channel) = self.channels.get_mut(&channel) { trace!( "adjust_window_size, channel = {}, size = {},", @@ -233,32 +249,39 @@ impl Encrypted { ); push_packet!(self.write, { self.write.push(msg::CHANNEL_WINDOW_ADJUST); - self.write.push_u32_be(channel.recipient_channel); - self.write.push_u32_be(target - channel.sender_window_size); + channel.recipient_channel.encode(&mut self.write)?; + (target - channel.sender_window_size).encode(&mut self.write)?; }); channel.sender_window_size = target; - return true; + return Ok(true); } } - false + Ok(false) } - fn flush_channel(write: &mut CryptoVec, channel: &mut ChannelParams) -> ChannelFlushResult { + fn flush_channel( + write: &mut CryptoVec, + channel: &mut ChannelParams, + ) -> Result { let mut pending_size = 0; while let Some((buf, a, from)) = channel.pending_data.pop_front() { - let size = Self::data_noqueue(write, channel, &buf, a, from); + let size = Self::data_noqueue(write, channel, &buf, a, from)?; pending_size += size; if from + size < buf.len() { channel.pending_data.push_front((buf, a, from + size)); - return ChannelFlushResult::Incomplete { + return Ok(ChannelFlushResult::Incomplete { wrote: pending_size, - }; + }); } } - ChannelFlushResult::complete(pending_size, channel) + Ok(ChannelFlushResult::complete(pending_size, channel)) } - fn handle_flushed_channel(&mut self, channel: ChannelId, flush_result: ChannelFlushResult) { + fn handle_flushed_channel( + &mut self, + channel: ChannelId, + flush_result: ChannelFlushResult, + ) -> Result<(), crate::Error> { if let ChannelFlushResult::Complete { wrote: _, pending_eof, @@ -266,33 +289,35 @@ impl Encrypted { } = flush_result { if pending_eof { - self.eof(channel); + self.eof(channel)?; } if pending_close { - self.close(channel); + self.close(channel)?; } } + Ok(()) } - pub fn flush_pending(&mut self, channel: ChannelId) -> usize { + pub fn flush_pending(&mut self, channel: ChannelId) -> Result { let mut pending_size = 0; let mut maybe_flush_result = Option::::None; if let Some(channel) = self.channels.get_mut(&channel) { - let flush_result = Self::flush_channel(&mut self.write, channel); + let flush_result = Self::flush_channel(&mut self.write, channel)?; pending_size += flush_result.wrote(); maybe_flush_result = Some(flush_result); } if let Some(flush_result) = maybe_flush_result { - self.handle_flushed_channel(channel, flush_result) + self.handle_flushed_channel(channel, flush_result)? } - pending_size + Ok(pending_size) } - pub fn flush_all_pending(&mut self) { + pub fn flush_all_pending(&mut self) -> Result<(), crate::Error> { for channel in self.channels.values_mut() { - Self::flush_channel(&mut self.write, channel); + Self::flush_channel(&mut self.write, channel)?; } + Ok(()) } fn has_pending_data_mut(&mut self, channel: ChannelId) -> Option<&mut ChannelParams> { @@ -318,9 +343,9 @@ impl Encrypted { buf0: &[u8], a: Option, from: usize, - ) -> usize { + ) -> Result { if from >= buf0.len() { - return 0; + return Ok(0); } let mut buf = if buf0.len() as u32 > from as u32 + channel.recipient_window_size { #[allow(clippy::indexing_slicing)] // length checked @@ -337,16 +362,16 @@ impl Encrypted { match a { None => push_packet!(write, { write.push(msg::CHANNEL_DATA); - write.push_u32_be(channel.recipient_channel); + channel.recipient_channel.encode(write)?; #[allow(clippy::indexing_slicing)] // length checked - write.extend_ssh_string(&buf[..off]); + buf[..off].encode(write)?; }), Some(ext) => push_packet!(write, { write.push(msg::CHANNEL_EXTENDED_DATA); - write.push_u32_be(channel.recipient_channel); - write.push_u32_be(ext); + channel.recipient_channel.encode(write)?; + ext.encode(write)?; #[allow(clippy::indexing_slicing)] // length checked - write.extend_ssh_string(&buf[..off]); + buf[..off].encode(write)?; }), } trace!( @@ -361,37 +386,44 @@ impl Encrypted { } } trace!("buf.len() = {:?}, buf_len = {:?}", buf.len(), buf_len); - buf_len + Ok(buf_len) } - pub fn data(&mut self, channel: ChannelId, buf0: CryptoVec) { + pub fn data(&mut self, channel: ChannelId, buf0: CryptoVec) -> Result<(), crate::Error> { if let Some(channel) = self.channels.get_mut(&channel) { assert!(channel.confirmed); if !channel.pending_data.is_empty() || self.rekey.is_some() { channel.pending_data.push_back((buf0, None, 0)); - return; + return Ok(()); } - let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, None, 0); + let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, None, 0)?; if buf_len < buf0.len() { channel.pending_data.push_back((buf0, None, buf_len)) } } else { debug!("{:?} not saved for this session", channel); } + Ok(()) } - pub fn extended_data(&mut self, channel: ChannelId, ext: u32, buf0: CryptoVec) { + pub fn extended_data( + &mut self, + channel: ChannelId, + ext: u32, + buf0: CryptoVec, + ) -> Result<(), crate::Error> { if let Some(channel) = self.channels.get_mut(&channel) { assert!(channel.confirmed); if !channel.pending_data.is_empty() { channel.pending_data.push_back((buf0, Some(ext), 0)); - return; + return Ok(()); } - let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, Some(ext), 0); + let buf_len = Self::data_noqueue(&mut self.write, channel, &buf0, Some(ext), 0)?; if buf_len < buf0.len() { channel.pending_data.push_back((buf0, Some(ext), buf_len)) } } + Ok(()) } pub fn flush( diff --git a/russh/src/tests.rs b/russh/src/tests.rs index fc22c127..f1507bd1 100644 --- a/russh/src/tests.rs +++ b/russh/src/tests.rs @@ -116,7 +116,7 @@ mod compress { session: &mut Session, ) -> Result<(), Self::Error> { debug!("server data = {:?}", std::str::from_utf8(data)); - session.data(channel, CryptoVec::from_slice(data)); + session.data(channel, CryptoVec::from_slice(data))?; Ok(()) } } @@ -235,7 +235,7 @@ mod channels { session: &mut client::Session, ) -> Result<(), Self::Error> { assert_eq!(data, &b"hello world!"[..]); - session.data(channel, CryptoVec::from_slice(&b"hey there!"[..])); + session.data(channel, CryptoVec::from_slice(&b"hey there!"[..]))?; Ok(()) } }