Skip to content

Commit

Permalink
feat: improve USB performances by storing endpoints (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
cli-s1n authored Dec 18, 2024
1 parent 3feda38 commit c54942f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 47 deletions.
5 changes: 2 additions & 3 deletions adb_client/src/device/adb_usb_device.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use rusb::constants::LIBUSB_CLASS_VENDOR_SPEC;
use rusb::Device;
use rusb::DeviceDescriptor;
use rusb::UsbContext;
Expand Down Expand Up @@ -57,8 +58,6 @@ fn search_adb_devices() -> Result<Option<(u16, u16)>> {
}

fn is_adb_device<T: UsbContext>(device: &Device<T>, des: &DeviceDescriptor) -> bool {
const ADB_CLASS: u8 = 0xff;

const ADB_SUBCLASS: u8 = 0x42;
const ADB_PROTOCOL: u8 = 0x1;

Expand All @@ -77,7 +76,7 @@ fn is_adb_device<T: UsbContext>(device: &Device<T>, des: &DeviceDescriptor) -> b
let class = interface_des.class_code();
let subcl = interface_des.sub_class_code();
if proto == ADB_PROTOCOL
&& ((class == ADB_CLASS && subcl == ADB_SUBCLASS)
&& ((class == LIBUSB_CLASS_VENDOR_SPEC && subcl == ADB_SUBCLASS)
|| (class == BULK_CLASS && subcl == BULK_ADB_SUBCLASS))
{
return true;
Expand Down
2 changes: 1 addition & 1 deletion adb_client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub enum RustADBError {
#[error("Cannot get home directory")]
NoHomeDirectory,
/// Generic USB error
#[error(transparent)]
#[error("USB Error: {0}")]
UsbError(#[from] rusb::Error),
/// USB device not found
#[error("USB Device not found: {0} {1}")]
Expand Down
103 changes: 60 additions & 43 deletions adb_client/src/transports/usb_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
Result, RustADBError,
};

#[derive(Debug)]
#[derive(Clone, Debug)]
struct Endpoint {
iface: u8,
address: u8,
Expand All @@ -22,6 +22,8 @@ struct Endpoint {
pub struct USBTransport {
device: Device<GlobalContext>,
handle: Option<Arc<DeviceHandle<GlobalContext>>>,
read_endpoint: Option<Endpoint>,
write_endpoint: Option<Endpoint>,
}

impl USBTransport {
Expand Down Expand Up @@ -49,6 +51,8 @@ impl USBTransport {
Self {
device: rusb_device,
handle: None,
read_endpoint: None,
write_endpoint: None,
}
}

Expand All @@ -62,43 +66,34 @@ impl USBTransport {
.cloned()
}

fn get_read_endpoint(&self) -> Result<Endpoint> {
self.read_endpoint
.as_ref()
.ok_or(RustADBError::IOError(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"no read endpoint setup",
)))
.cloned()
}

fn get_write_endpoint(&self) -> Result<&Endpoint> {
self.write_endpoint
.as_ref()
.ok_or(RustADBError::IOError(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"no write endpoint setup",
)))
}

fn configure_endpoint(handle: &DeviceHandle<GlobalContext>, endpoint: &Endpoint) -> Result<()> {
handle.claim_interface(endpoint.iface)?;
Ok(())
}

fn find_readable_endpoint(&self) -> Result<Endpoint> {
let handle = self.get_raw_connection()?;
for n in 0..handle.device().device_descriptor()?.num_configurations() {
let config_desc = match handle.device().config_descriptor(n) {
Ok(c) => c,
Err(_) => continue,
};
fn find_endpoints(&self, handle: &DeviceHandle<GlobalContext>) -> Result<(Endpoint, Endpoint)> {
let mut read_endpoint: Option<Endpoint> = None;
let mut write_endpoint: Option<Endpoint> = None;

for interface in config_desc.interfaces() {
for interface_desc in interface.descriptors() {
for endpoint_desc in interface_desc.endpoint_descriptors() {
if endpoint_desc.direction() == Direction::In
&& endpoint_desc.transfer_type() == TransferType::Bulk
&& interface_desc.class_code() == LIBUSB_CLASS_VENDOR_SPEC
&& interface_desc.sub_class_code() == 0x42
&& interface_desc.protocol_code() == 0x01
{
return Ok(Endpoint {
iface: interface_desc.interface_number(),
address: endpoint_desc.address(),
});
}
}
}
}
}

Err(RustADBError::USBNoDescriptorFound)
}

fn find_writable_endpoint(&self) -> Result<Endpoint> {
let handle = self.get_raw_connection()?;
for n in 0..handle.device().device_descriptor()?.num_configurations() {
let config_desc = match handle.device().config_descriptor(n) {
Ok(c) => c,
Expand All @@ -108,16 +103,31 @@ impl USBTransport {
for interface in config_desc.interfaces() {
for interface_desc in interface.descriptors() {
for endpoint_desc in interface_desc.endpoint_descriptors() {
if endpoint_desc.direction() == Direction::Out
&& endpoint_desc.transfer_type() == TransferType::Bulk
if endpoint_desc.transfer_type() == TransferType::Bulk
&& interface_desc.class_code() == LIBUSB_CLASS_VENDOR_SPEC
&& interface_desc.sub_class_code() == 0x42
&& interface_desc.protocol_code() == 0x01
{
return Ok(Endpoint {
let endpoint = Endpoint {
iface: interface_desc.interface_number(),
address: endpoint_desc.address(),
});
};
match endpoint_desc.direction() {
Direction::In => {
if let Some(write_endpoint) = write_endpoint {
return Ok((endpoint, write_endpoint));
} else {
read_endpoint = Some(endpoint);
}
}
Direction::Out => {
if let Some(read_endpoint) = read_endpoint {
return Ok((read_endpoint, endpoint));
} else {
write_endpoint = Some(endpoint);
}
}
}
}
}
}
Expand All @@ -130,7 +140,18 @@ impl USBTransport {

impl ADBTransport for USBTransport {
fn connect(&mut self) -> crate::Result<()> {
self.handle = Some(Arc::new(self.device.open()?));
let device = self.device.open()?;

let (read_endpoint, write_endpoint) = self.find_endpoints(&device)?;

Self::configure_endpoint(&device, &read_endpoint)?;
self.read_endpoint = Some(read_endpoint);

Self::configure_endpoint(&device, &write_endpoint)?;
self.write_endpoint = Some(write_endpoint);

self.handle = Some(Arc::new(device));

Ok(())
}

Expand All @@ -146,11 +167,9 @@ impl ADBMessageTransport for USBTransport {
message: ADBTransportMessage,
timeout: Duration,
) -> Result<()> {
let endpoint = self.find_writable_endpoint()?;
let endpoint = self.get_write_endpoint()?;
let handle = self.get_raw_connection()?;

Self::configure_endpoint(&handle, &endpoint)?;

let message_bytes = message.header().as_bytes()?;
let mut total_written = 0;
loop {
Expand All @@ -177,11 +196,9 @@ impl ADBMessageTransport for USBTransport {
}

fn read_message_with_timeout(&mut self, timeout: Duration) -> Result<ADBTransportMessage> {
let endpoint = self.find_readable_endpoint()?;
let endpoint = self.get_read_endpoint()?;
let handle = self.get_raw_connection()?;

Self::configure_endpoint(&handle, &endpoint)?;

let mut data = [0; 24];
let mut total_read = 0;
loop {
Expand Down

0 comments on commit c54942f

Please sign in to comment.