Skip to content

Commit

Permalink
progress on cuda-checkpoint, blocked on restoring with the same PID: N…
Browse files Browse the repository at this point in the history
  • Loading branch information
thundergolfer committed Nov 11, 2024
1 parent 94f5103 commit 26e9e70
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 57 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ clap = { version = "4.5.18", features = ["derive"] }
camino = "1.1.9"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
memfd-exec = "0.2.1"

[dev-dependencies]
num_cpus = "1.12"
Expand Down
42 changes: 42 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use std::env;
use std::fs;
use std::path::Path;
use std::process::Command;

/// Build script entry point: downloads NVIDIA's prebuilt `cuda-checkpoint` binary into
/// `OUT_DIR` and marks it executable.
///
/// The download is delegated to the `curl` command-line tool, which therefore must be
/// available on the build machine.
///
/// # Panics
///
/// Panics if `OUT_DIR` is not set, if `curl` cannot be started, if the download fails
/// (including HTTP error responses), or if the file permissions cannot be updated.
fn main() {
    // Define the URL of the file to download
    let url = "https://github.com/NVIDIA/cuda-checkpoint/blob/main/bin/x86_64_Linux/cuda-checkpoint?raw=true";
    let filename = "cuda-checkpoint";

    // Determine the output directory for the binary
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR environment variable is not set");
    let dest_path = Path::new(&out_dir).join(filename);

    // Download the binary using curl
    let status = Command::new("curl")
        // Without --fail, curl exits 0 on HTTP errors (404, 500, ...) and writes the
        // server's error page to the destination, which would then be treated as the binary.
        .arg("--fail")
        .arg("-L") // Follow redirects
        .arg("-o")
        .arg(&dest_path)
        .arg(url)
        .status()
        .expect("Failed to execute curl");

    if !status.success() {
        panic!("Failed to download cuda-checkpoint");
    }

    // Make the binary executable
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&dest_path)
            .expect("Failed to retrieve metadata")
            .permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&dest_path, perms).expect("Failed to set permissions");
    }

    // Tell cargo when this build script needs to run again
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=OUT_DIR");
}
2 changes: 1 addition & 1 deletion examples/dumpme.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#
# I think this is because telefork isn't restoring any file descriptors.

for i in {1..10}
for i in {1..20}
do
echo "step $i"
sleep 1
Expand Down
47 changes: 47 additions & 0 deletions simpler_counter.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include <stdio.h>
#include <unistd.h>
#include <cuda_runtime.h>

// Counter living in device memory; incremented by the `increment` kernel and copied back
// to the host by `main` on every loop iteration.
__device__ int counter = 100;

// Kernel: adds one to the device-side `counter`.
__global__ void increment()
{
    counter += 1;
}

// Terminates the program with exit code 1 when a CUDA runtime call did not succeed.
//
// result: status returned by the CUDA call being checked.
// msg:    short description of the step, echoed in the error output.
void checkCuda(cudaError_t result, const char *msg) {
    if (result == cudaSuccess) {
        return;
    }

    const char *reason = cudaGetErrorString(result);
    fprintf(stderr, "CUDA Error: %s - %s\n", msg, reason);
    exit(1);
}

// Test program: increments a device-side counter once per second, forever, printing the
// value read back from the device after each increment.
int main(void)
{
    // Initialize CUDA
    checkCuda(cudaFree(0), "Initializing CUDA");

    // Initialize counter to 100 on the device
    int initialCounter = 100;
    checkCuda(cudaMemcpyToSymbol(counter, &initialCounter, sizeof(counter)), "Initializing counter");

    while (true) {
        int hCounter = 0;

        // Launch the increment kernel. A kernel launch returns no status itself, so a
        // failed launch has to be picked up explicitly via cudaGetLastError().
        increment<<<1, 1>>>();
        checkCuda(cudaGetLastError(), "Kernel launch");
        checkCuda(cudaDeviceSynchronize(), "Kernel execution");

        // Copy the counter from device to host
        checkCuda(cudaMemcpyFromSymbol(&hCounter, counter, sizeof(counter)), "Copying counter to host");

        // Print the current counter value; flush so it shows up immediately even when
        // stdout is redirected to a pipe or file (where it is not line-buffered).
        printf("%d\n", hCounter);
        fflush(stdout);

        // Wait for 1 second
        sleep(1);
    }

    // Not reached: the loop above never terminates.
    return 0;
}

13 changes: 11 additions & 2 deletions src/cmd.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{teledump, telepad, wait_for_exit};
use crate::{cuda, teledump, telepad, wait_for_exit};
use std::fs::File;
use std::io::ErrorKind;
use std::path::Path;
Expand All @@ -9,19 +9,24 @@ pub fn dump(
pid: i32,
path: impl AsRef<Path>,
leave_running: bool,
cuda: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let mut output = File::create(&path).map_err(|e| {
Box::new(std::io::Error::new(
ErrorKind::Other,
format!("Failed to create file: {}", e),
))
})?;
if cuda {
info!("toggling cuda state for pid {:?}", pid);
cuda::checkpoint(pid)?;
}
info!("dumping pid {:?}", pid);
teledump(pid, &mut output, leave_running)?;
Ok(())
}

pub fn restore(path: impl AsRef<Path>) -> Result<(), Box<dyn std::error::Error>> {
pub fn restore(path: impl AsRef<Path>, cuda: bool) -> Result<(), Box<dyn std::error::Error>> {
let mut input = File::open(&path).map_err(|e| {
Box::new(std::io::Error::new(
ErrorKind::Other,
Expand All @@ -30,6 +35,10 @@ pub fn restore(path: impl AsRef<Path>) -> Result<(), Box<dyn std::error::Error>>
})?;
info!("restoring from {:?}", path.as_ref());
let child = telepad(&mut input, 1)?;
if cuda {
info!("toggling cuda state for pid {:?}", child.as_raw());
cuda::checkpoint(child.as_raw())?;
}
let status = wait_for_exit(child).unwrap();
info!("child exited with status = {}", status);
Ok(())
Expand Down
30 changes: 30 additions & 0 deletions src/cuda.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use memfd_exec::{MemFdExecutable, Stdio};

/// Provides the bytes of the cuda-checkpoint executable embedded in this crate.
///
/// The file is read from `$OUT_DIR/cuda-checkpoint` at compile time. The cuda-checkpoint
/// executable is used to checkpoint/restore CUDA state.
#[cfg(target_os = "linux")]
fn get_cuda_checkpoint_binary() -> &'static [u8] {
    static CUDA_CHECKPOINT: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/cuda-checkpoint"));
    CUDA_CHECKPOINT
}

/// Run `cuda-checkpoint --toggle --pid <pid>` against the process `pid`.
///
/// The embedded cuda-checkpoint executable is run through [`MemFdExecutable`]; this call
/// blocks until it has exited.
///
/// # Errors
///
/// Returns an error if the executable cannot be spawned or waited on, or if it finishes
/// with a non-zero raw wait status.
///
/// Ref: https://github.com/NVIDIA/cuda-checkpoint
pub fn checkpoint(pid: i32) -> Result<(), Box<dyn std::error::Error>> {
    let pid_arg = pid.to_string();

    // The `MemFdExecutable` struct is at near feature-parity with `std::process::Command`,
    // so you can use it in the same way. The only difference is that you must provide the
    // executable contents as bytes as well as telling it the argv[0] to use.
    let child = MemFdExecutable::new("cuda-checkpoint", get_cuda_checkpoint_binary())
        .arg("--toggle")
        .args(["--pid", pid_arg.as_str()])
        // We'll capture the stdout of the process, so we need to set up a pipe.
        .stdout(Stdio::piped())
        // Spawn the process as a forked child
        .spawn()?;

    // Get the output and status code of the process (this will block until the process
    // exits)
    let output = child.wait_with_output()?;

    // Report failure to the caller instead of panicking: this function already returns a
    // `Result`, and callers propagate it with `?`.
    let raw_status = output.status.into_raw();
    if raw_status != 0 {
        return Err(format!(
            "cuda-checkpoint --toggle --pid {} failed with raw wait status {}",
            pid, raw_status
        )
        .into());
    }
    Ok(())
}
110 changes: 59 additions & 51 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use proc_maps;
// We use these to serialize our state over the wire
use bincode;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use tracing::{debug, info, warn};

use std::collections::HashMap;
// Error handling
Expand All @@ -39,6 +39,7 @@ use std::net::{TcpStream, ToSocketAddrs};
use std::os::unix::io::FromRawFd;

pub mod cmd;
pub mod cuda;

type Result<T> = std::result::Result<T, Box<dyn Error>>;
const PAGE_SIZE: usize = 4096;
Expand Down Expand Up @@ -673,11 +674,11 @@ fn remote_lseek(child: Pid, syscall: SyscallLoc, fd: u32, offset: u64) -> Result
let SyscallLoc(loc) = syscall;
let regs = ptrace::getregs(child)?;
let syscall_regs = libc::user_regs_struct {
rip: loc as u64, // syscall instr (rip is the instruction pointer)
rax: 8, // lseek (rax holds the syscall number)
rdi: fd as u64, // (first argument to syscall goes in rdi)
rsi: offset as u64, // (second argument to syscall goes in rsi)
rdx: libc::SEEK_SET as u64, // (third argument to syscall goes in rdx)
rip: loc as u64, // syscall instr (rip is the instruction pointer)
rax: 8, // lseek (rax holds the syscall number)
rdi: fd as u64, // (first argument to syscall goes in rdi)
rsi: offset as u64, // (second argument to syscall goes in rsi)
rdx: libc::SEEK_SET as u64, // (third argument to syscall goes in rdx)
..regs
};
// == 2. Set the modified regs
Expand All @@ -696,7 +697,13 @@ fn remote_lseek(child: Pid, syscall: SyscallLoc, fd: u32, offset: u64) -> Result

/// TODO
fn restore_file_descriptors(child: Pid, syscall: SyscallLoc, cm: ConnectionMap) -> Result<()> {
fn restore_file(child: Pid, syscall: SyscallLoc, fd: u32, path: String, offset: u64) -> Result<()> {
fn restore_file(
child: Pid,
syscall: SyscallLoc,
fd: u32,
path: String,
offset: u64,
) -> Result<()> {
let open_fd = remote_open(child, syscall, &path, libc::O_RDONLY)?;
tracing::debug!("opened file descriptor {} for {}", open_fd, path);
remote_dup2(child, syscall, open_fd, fd)?;
Expand All @@ -713,7 +720,12 @@ fn restore_file_descriptors(child: Pid, syscall: SyscallLoc, cm: ConnectionMap)
warn!("skipping tcp file descriptor {}", fd);
}
Connection::File(FileConnection { path, offset }) => {
tracing::debug!("restoring file descriptor {} for {} at offset {}", fd, path, offset);
tracing::debug!(
"restoring file descriptor {} for {} at offset {}",
fd,
path,
offset
);
restore_file(child, syscall, fd, path, offset)?;
}
Connection::Stdio(_) => {
Expand Down Expand Up @@ -861,15 +873,6 @@ pub fn telepad(inp: &mut dyn Read, pass_to_child: i32) -> Result<Pid> {
// This lets the other process be stopped without triggering out waitpid,
// as well as to be debugged by a different ptrace-er

for i in 1..10000 {
if i == 1 {
tracing::debug!("step {}", i);
let regs = ptrace::getregs(child)?;
tracing::debug!("regs = {:?}", regs);
}
single_step(child)?;
}

tracing::debug!("detaching from child");
ptrace::detach(child, None)?;

Expand Down Expand Up @@ -1021,45 +1024,50 @@ fn scan_file_descriptors(pid: i32) -> Result<ConnectionMap> {
let entry = entry?;
let fd_path = entry.path();
let fd = fd_path.file_name().unwrap().to_string_lossy();
debug!("processing file descriptor {}", fd);
// Read the symbolic link to get the file descriptor target
let target = std::fs::read_link(&fd_path)?;
let metadata = std::fs::metadata(&target)?;
let file_type = metadata.file_type();
debug!("reading file descriptor {} target {} metadata", fd, target.display());
let metadata = std::fs::metadata(&target).ok();
info!("file descriptor {}: {:?}", fd, target);

if file_type.is_file() {
let fd = fd.parse::<u32>().unwrap();
let offset = get_fd_offset(pid, fd)?.unwrap_or(0);
cm.insert(
fd,
Connection::File(FileConnection {
path: target.to_string_lossy().to_string(),
offset,
}),
);
} else if file_type.is_dir() {
cm.insert(
fd.parse::<u32>().unwrap(),
Connection::File(FileConnection {
path: target.to_string_lossy().to_string(),
offset: 0,
}),
);
} else if file_type.is_socket() {
cm.insert(
fd.parse::<u32>().unwrap(),
Connection::Tcp(TcpConnection {
local_addr: target.to_string_lossy().to_string(),
remote_addr: target.to_string_lossy().to_string(),
}),
);
} else if file_type.is_char_device() {
let fd = fd.parse::<u32>().unwrap();
if matches!(fd, 0..=2) {
cm.insert(fd, Connection::Stdio(StdioConnection {}));
if let Some(file_type) = metadata.map(|m| m.file_type()) {
if file_type.is_file() {
let fd = fd.parse::<u32>().unwrap();
let offset = get_fd_offset(pid, fd)?.unwrap_or(0);
cm.insert(
fd,
Connection::File(FileConnection {
path: target.to_string_lossy().to_string(),
offset,
}),
);
} else if file_type.is_dir() {
cm.insert(
fd.parse::<u32>().unwrap(),
Connection::File(FileConnection {
path: target.to_string_lossy().to_string(),
offset: 0,
}),
);
} else if file_type.is_socket() {
cm.insert(
fd.parse::<u32>().unwrap(),
Connection::Tcp(TcpConnection {
local_addr: target.to_string_lossy().to_string(),
remote_addr: target.to_string_lossy().to_string(),
}),
);
} else if file_type.is_char_device() {
let fd = fd.parse::<u32>().unwrap();
if matches!(fd, 0..=2) {
cm.insert(fd, Connection::Stdio(StdioConnection {}));
} else {
warn!("saving unsupported file descriptor");
cm.insert(fd, Connection::Invalid);
}
} else {
warn!("saving unsupported file descriptor");
cm.insert(fd, Connection::Invalid);
cm.insert(fd.parse::<u32>().unwrap(), Connection::Invalid);
}
} else {
warn!("saving unsupported file descriptor");
Expand Down
13 changes: 10 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,17 @@ enum Command {
/// Restore the process running after dumping.
#[clap(long)]
leave_running: bool,
/// Enable CUDA support.
#[clap(long)]
cuda: bool,
},
/// Restore a process from a dumped file.
Restore {
/// The dumped file to restore from.
path: Utf8PathBuf,
/// Enable CUDA support.
#[clap(long)]
cuda: bool,
},
}

Expand All @@ -68,11 +74,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
process_id,
path,
leave_running,
cuda,
} => {
cmd::dump(process_id, path, leave_running)?;
cmd::dump(process_id, path, leave_running, cuda)?;
}
Command::Restore { path } => {
cmd::restore(path)?;
Command::Restore { path, cuda } => {
cmd::restore(path, cuda)?;
}
}
Ok(())
Expand Down

0 comments on commit 26e9e70

Please sign in to comment.