From f13ebf5e8c937843d4d033fdb9e6bbc202882a51 Mon Sep 17 00:00:00 2001 From: Matan Green Date: Wed, 22 Jan 2025 15:53:14 +0200 Subject: [PATCH 1/6] WIP: Added Go DI Exploration Testing --- .../testutil/exploration_e2e_test.go | 1625 +++++++++++++++++ .../patches/protobuf/integration_test.go | 586 ++++++ .../patches/protobuf/test.bash | 7 + .../protobuf/testing/prototest/message.go | 911 +++++++++ 4 files changed, 3129 insertions(+) create mode 100644 pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go create mode 100644 pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go create mode 100644 pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash create mode 100644 pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go diff --git a/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go b/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go new file mode 100644 index 00000000000000..d1b0b669a8f0a1 --- /dev/null +++ b/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go @@ -0,0 +1,1625 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. 
+ +//go:build linux_bpf + +package testutil + +import ( + "bufio" + "bytes" + "debug/dwarf" + "debug/elf" + "encoding/json" + "fmt" + "html/template" + "io" + "os" + "os/exec" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "syscall" + "testing" + "time" + + "github.com/cilium/ebpf" + "github.com/cilium/ebpf/features" + "github.com/cilium/ebpf/rlimit" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" + + "github.com/DataDog/datadog-agent/pkg/dynamicinstrumentation" + "github.com/DataDog/datadog-agent/pkg/dynamicinstrumentation/diconfig" + "github.com/DataDog/datadog-agent/pkg/dynamicinstrumentation/ditypes" +) + +type ProcessState int + +const ( + StateNew ProcessState = iota + StateAnalyzing + StateRunning + StateExited +) + +func (s ProcessState) String() string { + switch s { + case StateNew: + return "NEW" + case StateAnalyzing: + return "ANALYZING" + case StateRunning: + return "RUNNING" + case StateExited: + return "EXITED" + default: + return "UNKNOWN" + } +} + +type ProcessInfo struct { + PID int + BinaryPath string + ParentPID int + State ProcessState + Children []*ProcessInfo + StartTime time.Time + Analyzed bool +} + +type ProcessTracker struct { + t *testing.T + mu sync.RWMutex + processes map[int]*ProcessInfo + mainPID int + stopChan chan struct{} + analyzedBinaries map[string]bool + analyzedPIDs map[int]bool + done chan struct{} +} + +type ProbeManager struct { + t *testing.T + installedProbes sync.Map // maps pid -> map[string]struct{} + dataReceived sync.Map // maps pid -> map[string]bool + mu sync.Mutex +} + +func NewProbeManager(t *testing.T) *ProbeManager { + return &ProbeManager{ + t: t, + } +} + +func (pm *ProbeManager) Install(pid int, function string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + // Get or create the map of installed probes for this PID + v, _ := pm.installedProbes.LoadOrStore(pid, make(map[string]struct{})) + probes := v.(map[string]struct{}) + + // Install the 
probe + probes[function] = struct{}{} + pm.t.Logf("πŸ”§ Installing probe: PID=%d Function=%s", pid, function) + + // Your actual probe installation logic here using GoDI + // Example: + // err := pm.godi.InstallProbe(pid, function) + return nil +} + +func (pm *ProbeManager) Remove(pid int, function string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + if v, ok := pm.installedProbes.Load(pid); ok { + probes := v.(map[string]struct{}) + delete(probes, function) + pm.t.Logf("πŸ”§ Removing probe: PID=%d Function=%s", pid, function) + + // Your actual probe removal logic here + } + return nil +} + +func (pm *ProbeManager) CollectData(pid int, function string) (bool, error) { + // Check if we've received data for this probe + // This is where you'd check your actual data collection mechanism + + // For testing, let's simulate data collection + // In reality, you'd check if your probe has published any data + if v, ok := pm.dataReceived.Load(pid); ok { + dataMap := v.(map[string]bool) + return dataMap[function], nil + } + return false, nil +} + +func NewProcessTracker(t *testing.T) *ProcessTracker { + return &ProcessTracker{ + t: t, + processes: make(map[int]*ProcessInfo), + stopChan: make(chan struct{}), + analyzedBinaries: make(map[string]bool), + analyzedPIDs: make(map[int]bool), + done: make(chan struct{}), + } +} + +func (pt *ProcessTracker) markAnalyzed(pid int, path string) { + pt.mu.Lock() + defer pt.mu.Unlock() + pt.analyzedPIDs[pid] = true + pt.analyzedBinaries[path] = true +} + +func getProcessArgs(pid int) ([]string, error) { + // Construct the path to the /proc//cmdline file + procFile := fmt.Sprintf("/proc/%d/cmdline", pid) + + // Read the file content + data, err := os.ReadFile(procFile) + if err != nil { + return nil, err + } + + // The arguments are null-byte separated, split them + args := strings.Split(string(data), "\x00") + // Remove any trailing empty string caused by the trailing null byte + if len(args) > 0 && args[len(args)-1] == "" { + args = 
args[:len(args)-1] + } + return args, nil +} + +func getProcessCwd(pid int) (string, error) { + // Construct the path to the /proc//cwd symlink + procFile := fmt.Sprintf("/proc/%d/cwd", pid) + + // Read the symlink to find the current working directory + cwd, err := os.Readlink(procFile) + if err != nil { + return "", err + } + return cwd, nil +} + +func getProcessEnv(pid int) ([]string, error) { + // Construct the path to the /proc//environ file + procFile := fmt.Sprintf("/proc/%d/environ", pid) + + // Open and read the file + data, err := os.ReadFile(procFile) + if err != nil { + return nil, err + } + + // The environment variables are null-byte separated, split them + env := strings.Split(string(data), "\x00") + // Remove any trailing empty string caused by the trailing null byte + if len(env) > 0 && env[len(env)-1] == "" { + env = env[:len(env)-1] + } + return env, nil +} + +func hasDWARFInfo(binaryPath string) (bool, error) { + f, err := elf.Open(binaryPath) + if err != nil { + return false, fmt.Errorf("failed to open binary: %w", err) + } + defer f.Close() + + // Try both approaches: section lookup and DWARF data reading + debugSections := false + for _, section := range f.Sections { + if strings.HasPrefix(section.Name, ".debug_") { + fmt.Printf("Found debug section: %s (size: %d)\n", section.Name, section.Size) + debugSections = true + } + } + + // Try to actually read DWARF data + dwarfData, err := f.DWARF() + if err != nil { + return debugSections, fmt.Errorf("DWARF read error: %w", err) + } + + // Verify we can read some DWARF data + reader := dwarfData.Reader() + entry, err := reader.Next() + if err != nil { + return debugSections, fmt.Errorf("DWARF entry read error: %w", err) + } + if entry != nil { + fmt.Printf("Found DWARF entry of type: %v\n", entry.Tag) + return true, nil + } + + return false, nil +} + +type BinaryInfo struct { + path string + hasDebug bool +} + +type FunctionInfo struct { + PackageName string + FunctionName string + FullName string 
+ ProbeId string +} + +func NewFunctionInfo(packageName, functionName, fullName string) FunctionInfo { + return FunctionInfo{ + PackageName: packageName, + FunctionName: functionName, + FullName: fullName, + ProbeId: uuid.NewString(), + } +} + +func extractPackageAndFunction(fullName string) FunctionInfo { + // Handle empty input + if fullName == "" { + return FunctionInfo{} + } + + // First, find the last index of "." before any parentheses + parenIndex := strings.Index(fullName, "(") + lastDot := -1 + if parenIndex != -1 { + // If we have parentheses, look for the last dot before them + lastDot = strings.LastIndex(fullName[:parenIndex], ".") + } else { + // If no parentheses, just find the last dot + lastDot = strings.LastIndex(fullName, ".") + } + + if lastDot == -1 { + return FunctionInfo{} + } + + // Split into package and function parts + pkgPath := fullName[:lastDot] + funcPart := fullName[lastDot+1:] + + return NewFunctionInfo(pkgPath, funcPart, fullName) +} + +func listAllFunctions(filePath string) ([]FunctionInfo, error) { + var functions []FunctionInfo + var errors []string + + ef, err := elf.Open(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open file: %v", err) + } + defer ef.Close() + + dwarfData, err := ef.DWARF() + if err != nil { + return nil, fmt.Errorf("failed to load DWARF data: %v", err) + } + + reader := dwarfData.Reader() + + for { + entry, err := reader.Next() + if err != nil { + return nil, fmt.Errorf("error reading DWARF entry: %v", err) + } + if entry == nil { + break + } + + if entry.Tag == dwarf.TagSubprogram { + funcName, ok := entry.Val(dwarf.AttrName).(string) + if !ok || funcName == "" { + continue + } + + info := extractPackageAndFunction(funcName) + if info.FunctionName == "" { + errors = append(errors, fmt.Sprintf("could not extract function name from %q", funcName)) + continue + } + + functions = append(functions, info) + } + } + + if len(functions) == 0 { + if len(errors) > 0 { + return nil, fmt.Errorf("failed 
to extract any functions. Errors: %s", strings.Join(errors, "; ")) + } + return nil, fmt.Errorf("no functions found in the binary") + } + + return functions, nil +} + +// func isStandardPackage(pkg string) bool { +// // List of common standard library packages that might be nested +// stdPkgs := map[string]bool{ +// "encoding/json": true, +// "compress/flate": true, +// "compress/gzip": true, +// "encoding/base64": true, +// // Add more as needed +// } +// return stdPkgs[pkg] +// } + +// func listAllFunctions(filePath string) ([]FunctionInfo, error) { +// var functions []FunctionInfo + +// // Open the ELF file +// ef, err := elf.Open(filePath) +// if err != nil { +// return nil, fmt.Errorf("failed to open file: %v", err) +// } +// defer ef.Close() + +// // Retrieve symbols from the ELF file +// symbols, err := ef.Symbols() +// if err != nil { +// return nil, fmt.Errorf("failed to read symbols: %v", err) +// } + +// // Iterate over symbols and filter function symbols +// for _, sym := range symbols { +// if elf.ST_TYPE(sym.Info) == elf.STT_FUNC { +// // Extract function name +// functionName := sym.Name + +// // Extract package name from section index (if applicable) +// // DWARF data or additional analysis can refine this +// packageName := "" + +// // Add to result +// functions = append(functions, FunctionInfo{ +// PackageName: packageName, +// FunctionName: functionName, +// }) +// } +// } +// return functions, nil +// } + +func shouldProfileFunction(name string) bool { + // First, immediately reject known system/internal functions + if strings.HasPrefix(name, "*ZN") || // Sanitizer/LLVM functions + strings.HasPrefix(name, "_") || // Internal functions + strings.Contains(name, "_sanitizer") || + strings.Contains(name, "runtime.") { + return false + } + + // Extract package from function name + parts := strings.Split(name, ".") + if len(parts) < 2 { + return false + } + + pkgPath := parts[0] + if len(parts) > 2 { + pkgPath = strings.Join(parts[:len(parts)-1], 
"/") + } + + // Check if it's in our repository packages + for repoPkg := range g_RepoInfo.Packages { + if strings.Contains(pkgPath, repoPkg) { + return true + } + } + + return false +} + +// func shouldProfileFunction(name string) bool { +// // Skip standard library packages +// stdlibPrefixes := []string{ +// "bufio.", +// "bytes.", +// "context.", +// "crypto.", +// "compress/", +// "database/", +// "debug/", +// "encoding/", +// "errors.", +// "flag.", +// "fmt.", +// "io.", +// "log.", +// "math.", +// "net.", +// "os.", +// "path.", +// "reflect.", +// "regexp.", +// "runtime.", +// "sort.", +// "strconv.", +// "strings.", +// "sync.", +// "syscall.", +// "time.", +// "unicode.", +// } + +// // Definitely skip these system internals +// skipPrefixes := []string{ +// "runtime.", +// "runtime/race", +// "*ZN", // LLVM/Clang internals +// "type..", // Go type metadata +// "gc.", // Garbage collector +// "gosb.", // Go sandbox +// "_rt.", // Runtime helpers +// "reflect.", // Reflection internals +// } + +// skipContains := []string{ +// "_sanitizer", +// "_tsan", +// ".constprop.", // Compiler generated constants +// ".isra.", // LLVM optimized functions +// ".part.", // Partial functions from compiler +// "__gcc_", // GCC internals +// "_cgo_", // CGO generated code +// "goexit", // Go runtime exit handlers +// "gcproc", // GC procedures +// ".loc.", // Location metadata +// "runtimeΒ·", // Runtime internals (different dot) +// } + +// // Quick reject for standard library and system functions +// for _, prefix := range append(stdlibPrefixes, skipPrefixes...) 
{ +// if strings.HasPrefix(name, prefix) { +// return false +// } +// } + +// for _, substr := range skipContains { +// if strings.Contains(name, substr) { +// return false +// } +// } + +// // High priority user functions - definitely profile these +// priorityPrefixes := []string{ +// "main.", +// "cmd.", +// "github.com/", +// "golang.org/x/", +// "google.golang.org/", +// "k8s.io/", +// } + +// for _, prefix := range priorityPrefixes { +// if strings.HasPrefix(name, prefix) { +// return true +// } +// } + +// // Function looks like a normal Go function (CapitalizedName) +// if len(name) > 0 && unicode.IsUpper(rune(name[0])) { +// return true +// } + +// // If it contains a dot and doesn't look like a compiler-generated name +// if strings.Contains(name, ".") && +// !strings.Contains(name, "$") && +// !strings.Contains(name, "__") { +// return true +// } + +// // If we get here, it's probably a system function +// return false +// } + +var NUMBER_OF_PROBES int = 100 + +func filterFunctions(funcs []FunctionInfo) []FunctionInfo { + var validFuncs []FunctionInfo + + // First pass: collect only functions from our packages + for _, f := range funcs { + // Combine package and function name for filtering + fullName := fmt.Sprintf("%s.%s", f.PackageName, f.FunctionName) + if shouldProfileFunction(fullName) { + validFuncs = append(validFuncs, f) + } + } + + // If we have no valid functions, return empty list + if len(validFuncs) == 0 { + return nil + } + + // Sort valid functions for consistent ordering + sort.Slice(validFuncs, func(i, j int) bool { + // Sort alphabetically by full name (package + function) + fullNameI := fmt.Sprintf("%s.%s", validFuncs[i].PackageName, validFuncs[i].FunctionName) + fullNameJ := fmt.Sprintf("%s.%s", validFuncs[j].PackageName, validFuncs[j].FunctionName) + return fullNameI < fullNameJ + }) + + // Return all if we have 10 or fewer + if len(validFuncs) <= NUMBER_OF_PROBES { + return validFuncs + } + + // Only take first 10 if we have more + 
return validFuncs[:NUMBER_OF_PROBES] +} + +// func filterFunctions(funcs []string) []string { +// var validFuncs []string + +// // First pass: collect only functions from our packages +// for _, f := range funcs { +// if shouldProfileFunction(f) { +// validFuncs = append(validFuncs, f) +// } +// } + +// // If we have no valid functions, return empty list +// if len(validFuncs) == 0 { +// return nil +// } + +// // Sort for consistent ordering +// sort.Strings(validFuncs) + +// // Return all if we have 10 or fewer +// if len(validFuncs) <= NUMBER_OF_PROBES { +// return validFuncs +// } + +// // Only take first 10 if we have more +// return validFuncs[:NUMBER_OF_PROBES] +// } + +func ExtractFunctions(binaryPath string) ([]FunctionInfo, error) { + // Open the binary + file, err := elf.Open(binaryPath) + if err != nil { + return nil, fmt.Errorf("failed to open binary: %v", err) + } + defer file.Close() + + // Get DWARF data + dwarfData, err := file.DWARF() + if err != nil { + return nil, fmt.Errorf("failed to load DWARF data: %v", err) + } + + // Prepare result + var functions []FunctionInfo + + // Iterate over DWARF entries + reader := dwarfData.Reader() + for { + entry, err := reader.Next() + if err != nil { + return nil, fmt.Errorf("error reading DWARF: %v", err) + } + if entry == nil { + break // End of entries + } + + // Check for subprogram (function) entries + if entry.Tag == dwarf.TagSubprogram { + // Extract function name + funcName, _ := entry.Val(dwarf.AttrName).(string) + + // Extract package/module name (if available) + var packageName string + if compDir, ok := entry.Val(dwarf.AttrCompDir).(string); ok { + packageName = compDir + } + + // Add to the result + if funcName != "" { + functions = append(functions, FunctionInfo{ + PackageName: packageName, + FunctionName: funcName, + }) + } + } + } + + return functions, nil +} + +// hasDWARF checks if the given binary contains DWARF debug information. 
+func hasDWARF(binaryPath string) (bool, error) { + // Open the binary file + file, err := elf.Open(binaryPath) + if err != nil { + return false, fmt.Errorf("failed to open binary: %v", err) + } + defer file.Close() + + // Check if DWARF data exists + _, err = file.DWARF() + if err != nil { + // Check if the error indicates missing DWARF information + if err.Error() == "no DWARF data" { + return false, nil + } + // Otherwise, propagate the error + return false, fmt.Errorf("failed to check DWARF data: %v", err) + } + + // DWARF data exists + return true, nil +} + +var analyzedBinaries []BinaryInfo +var waitForAttach bool = true + +func InspectBinary(t *testing.T, binaryPath string, pid int) error { + // // check that we can analyse the binary without targeting a specific function + // err := diconfig.AnalyzeBinary(&ditypes.ProcessInfo{BinaryPath: binaryPath}) + // if err != nil { + // // log.Fatalln("Failed to analyze", binaryPath, "--", err) + // return nil + // } + + // targets, err := ExtractFunctions(binaryPath) + // if err != nil { + // // log.Fatalf("Error extracting functions: %v", err) + // return nil + // } + + // hasDwarf, err := hasDWARF(binaryPath) + // if err != nil || !hasDwarf { + // // log.Fatalf("Error checking for DWARF info: %v", err) + // return nil + // } + + allFuncs, err := listAllFunctions(binaryPath) + if err != nil { + analyzedBinaries = append(analyzedBinaries, BinaryInfo{ + path: binaryPath, + hasDebug: false, + }) + + return nil + } + + // targets := filterFunctions(allFuncs) + targets := allFuncs + + // Get process arguments + args, err := getProcessArgs(pid) + if err != nil { + return fmt.Errorf("Failed to process args: %v", err) + } + + // Get process current working directory + cwd, err := getProcessCwd(pid) + if err != nil { + return fmt.Errorf("Failed to get Cwd: %v", err) + } + + // // Get process environment variables + // env, err := getProcessEnv(pid) + // if err != nil { + // return fmt.Errorf("Failed to get Env: %v", err) + 
// } + + LogDebug(t, "\n=======================================") + LogDebug(t, "πŸ” ANALYZING BINARY: %s", binaryPath) + LogDebug(t, "πŸ” ARGS: %v", args) + LogDebug(t, "πŸ” CWD: %s", cwd) + LogDebug(t, "πŸ” Elected %d target functions:", len(targets)) + for _, f := range targets { + LogDebug(t, " β†’ Package: %s, Function: %s, FullName: %s", f.PackageName, f.FunctionName, f.FullName) + } + + // hasDWARF, dwarfErr := hasDWARFInfo(binaryPath) + // if dwarfErr != nil { + // log.Printf("Error checking DWARF info: %v", dwarfErr) + // } else { + // log.Printf("Binary has DWARF info: %v", hasDWARF) + // } + // LogDebug(t, "πŸ” ENV: %v", env) + LogDebug(t, "=======================================") + + // Check if the binary exists + if _, err := os.Stat(binaryPath); err != nil { + return fmt.Errorf("(1) binary inspection failed: %v", err) + } + + analyzedBinaries = append(analyzedBinaries, BinaryInfo{ + path: binaryPath, + hasDebug: len(targets) > 0, + }) + + // i := 0 + // // Re-check binary existence + // for { + // if _, err := os.Stat(binaryPath); err != nil { + // time.Sleep(10 * time.Hour) + // return fmt.Errorf("(2) binary inspection failed: %v", err) + // } + + // // if strings.HasSuffix(binaryPath, "generate-protos") { + // // break + // // } + + // if strings.HasSuffix(binaryPath, "conformance.test") { + // time.Sleep(10 * time.Second) + // break + // } + + // i++ + // if i > 11 { + // break + // } + + // // time.Sleep(100 * time.Millisecond) + // } + + LogDebug(t, "βœ… Analysis complete for: %s", binaryPath) + LogDebug(t, "=======================================\n") + + t.Logf("About to request instrumentations for binary: %s, pid: %d.", binaryPath, pid) + + cfgTemplate, err := template.New("config_template").Parse(explorationTestConfigTemplateText) + require.NoError(t, err) + + b := []byte{} + var buf *bytes.Buffer + + // if waitForAttach { + // pid := os.Getpid() + // t.Logf("(1) Waiting to attach for PID: %d", pid) + // time.Sleep(30 * time.Second) + 
// waitForAttach = false + // } + + requesterdFuncs := 0 + for _, f := range targets { + + // if !strings.Contains(f.FullName, "blabla_blabla") { + // continue + // } + + if !strings.Contains(f.FullName, "FullName") { + continue + } + + // if f.FullName != "regexp.(*bitState).shouldVisit" { + // continue + // } + + // if f.FullName != "google.golang.org/protobuf/encoding/protodelim_test.(*notBufioReader).UnreadRune" { + // continue + // } + + buf = bytes.NewBuffer(b) + err = cfgTemplate.Execute(buf, f) + if err != nil { + continue + } + + // LogDebug(t, "Requesting instrumentation for %v", f) + t.Logf("Requesting instrumentation for %v", f) + _, err := g_ConfigManager.ConfigWriter.Write(buf.Bytes()) + + if err != nil { + continue + } + + requesterdFuncs++ + } + + if !waitForAttach { + time.Sleep(100 * time.Second) + } + + if requesterdFuncs > 0 { + // if waitForAttach { + // pid := os.Getpid() + // t.Logf("(2) Waiting to attach for PID: %d", pid) + // time.Sleep(30 * time.Second) + // waitForAttach = false + // } + + // Wait for probes to be instrumented + time.Sleep(2 * time.Second) + + t.Logf("Requested to instrument %d functions for binary: %s, pid: %d.", requesterdFuncs, binaryPath, pid) + } + + return nil +} + +func (pt *ProcessTracker) addProcess(pid int, parentPID int) *ProcessInfo { + pt.mu.Lock() + defer pt.mu.Unlock() + + if proc, exists := pt.processes[pid]; exists { + return proc + } + + binaryPath := getBinaryPath(pid) + proc := &ProcessInfo{ + PID: pid, + ParentPID: parentPID, + BinaryPath: binaryPath, + State: StateNew, + StartTime: time.Now(), + Analyzed: false, + } + + pt.processes[pid] = proc + + // Add to parent's children if parent exists + if parent, exists := pt.processes[parentPID]; exists { + parent.Children = append(parent.Children, proc) + } + + pt.LogTrace("πŸ‘Ά New process: PID=%d, Parent=%d, Binary=%s", pid, parentPID, binaryPath) + return proc +} + +func getBinaryPath(pid int) string { + path, err := 
os.Readlink(fmt.Sprintf("/proc/%d/exe", pid)) + if err != nil { + return "" + } + + // Resolve any symlinks + realPath, err := filepath.EvalSymlinks(path) + if err == nil { + path = realPath + } + + return path +} + +func (pt *ProcessTracker) analyzeBinary(pid int, info *ProcessInfo) error { + if info == nil { + return fmt.Errorf("nil process info") + } + + pt.mu.Lock() + info.State = StateAnalyzing + pt.mu.Unlock() + + // pt.LogTrace("πŸ”Ž Analyzing binary PID=%d Path=%s", pid, info.BinaryPath) + + // Perform analysis + if err := InspectBinary(pt.t, info.BinaryPath, pid); err != nil { + pt.mu.Lock() + info.State = StateNew + pt.mu.Unlock() + return fmt.Errorf("binary analysis failed: %v", err) + } + + pt.mu.Lock() + info.State = StateRunning + pt.mu.Unlock() + + return nil +} + +func getParentPID(pid int) int { + ppidStr, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) + if err != nil { + return 0 + } + fields := strings.Fields(string(ppidStr)) + if len(fields) < 4 { + return 0 + } + ppid, _ := strconv.Atoi(fields[3]) + return ppid +} + +func (pt *ProcessTracker) scanProcessTree() error { + if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGSTOP); err != nil { + if err != unix.ESRCH { + pt.LogTrace("⚠️ Failed to stop PID %d: %v", -g_cmd.Process.Pid, err) + } + return nil + } + + // pt.profiler.OnProcessesPaused() + + defer func() { + if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGCONT); err != nil { + if err != unix.ESRCH { + pt.LogTrace("⚠️ Failed to resume PID %d: %v", -g_cmd.Process.Pid, err) + } + } else { + // pt.LogTrace("▢️ Resumed process: PID=%d", -g_cmd.Process.Pid) + } + + // pt.profiler.OnProcessesResumed() + + // if err := unix.Kill(pid, unix.SIGCONT); err != nil { + // if err != unix.ESRCH { + // pt.LogTrace("⚠️ Failed to resume PID %d: %v", pid, err) + // } + // } else { + // pt.LogTrace("▢️ Resumed process: PID=%d", pid) + // } + }() + + // Get all processes + allPids := make(map[int]bool) + if entries, err := os.ReadDir("/proc"); 
err == nil { + for _, entry := range entries { + if pid, err := strconv.Atoi(entry.Name()); err == nil { + allPids[pid] = true + } + } + } + + // Record our own process tree for exclusion + ourProcessTree := make(map[int]bool) + ourPid := os.Getpid() + findAncestors(ourPid, ourProcessTree) + + var toAnalyze []struct { + pid int + path string + ppid int + } + + // Check each PID + for pid := range allPids { + // Skip if already analyzed + pt.mu.RLock() + if pt.analyzedPIDs[pid] { + pt.mu.RUnlock() + continue + } + pt.mu.RUnlock() + + // Skip if in our process tree + if ourProcessTree[pid] { + continue + } + + // Get process path + binaryPath := getBinaryPath(pid) + if binaryPath == "" { + continue + } + + // Get parent PID + ppid := getParentPID(pid) + + // Skip if parent is in our tree + if ourProcessTree[ppid] { + continue + } + + // Always analyze: + // 1. Test binaries (.test) + // 2. Go build executables in /tmp + // 3. Children of test binaries + shouldAnalyze := false + + if strings.HasSuffix(binaryPath, ".test") { + shouldAnalyze = true + pt.LogTrace("Found test binary: %s (PID=%d)", binaryPath, pid) + } else if strings.Contains(binaryPath, "/go-build") && strings.Contains(binaryPath, "/exe/") { + shouldAnalyze = true + pt.LogTrace("Found build binary: %s (PID=%d)", binaryPath, pid) + } else { + // Check if parent is a test binary + parentPath := getBinaryPath(ppid) + if strings.HasSuffix(parentPath, ".test") { + shouldAnalyze = true + pt.LogTrace("Found child of test: %s (PID=%d, Parent=%d)", binaryPath, pid, ppid) + } + } + + if shouldAnalyze { + // Verify process still exists + if _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)); err == nil { + toAnalyze = append(toAnalyze, struct { + pid int + path string + ppid int + }{pid, binaryPath, ppid}) + + // Add to process tree + if pt.processes[pid] == nil { + pt.addProcess(pid, ppid) + } + } + } + } + + if len(toAnalyze) > 0 { + pt.LogTrace("\nπŸ” Found %d processes to analyze:", len(toAnalyze)) + for _, p := 
range toAnalyze { + pt.LogTrace(" PID=%d PPID=%d Path=%s", p.pid, p.ppid, p.path) + } + } + + var activePids []int + for _, p := range toAnalyze { + activePids = append(activePids, p.pid) + } + + // if pt.profiler!= nil { + // pt.profiler.OnTick(activePids) + // } + + // Process in small batches + batchSize := 2 + for i := 0; i < len(toAnalyze); i += batchSize { + end := i + batchSize + if end > len(toAnalyze) { + end = len(toAnalyze) + } + + var wg sync.WaitGroup + for _, p := range toAnalyze[i:end] { + wg.Add(1) + go func(pid int, path string) { + defer wg.Done() + + // Verify process still exists + if _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)); err != nil { + return + } + + pt.LogTrace("πŸ” Stopping process for analysis: PID=%d Path=%s", pid, path) + + // Get process info + pt.mu.RLock() + proc := pt.processes[pid] + pt.mu.RUnlock() + + if proc == nil { + return + } + + // Stop process + // if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGSTOP); err != nil { + // if err != unix.ESRCH { + // pt.LogTrace("⚠️ Failed to stop PID %d: %v", pid, err) + // } + // return + // } + + // if err := unix.Kill(pid, unix.SIGSTOP); err != nil { + // if err != unix.ESRCH { + // pt.LogTrace("⚠️ Failed to stop PID %d: %v", pid, err) + // } + // return + // } + + // Ensure process gets resumed + // defer func() { + // if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGCONT); err != nil { + // if err != unix.ESRCH { + // pt.LogTrace("⚠️ Failed to resume PID %d: %v", pid, err) + // } + // } else { + // pt.LogTrace("▢️ Resumed process: PID=%d", pid) + // } + + // // if err := unix.Kill(pid, unix.SIGCONT); err != nil { + // // if err != unix.ESRCH { + // // pt.LogTrace("⚠️ Failed to resume PID %d: %v", pid, err) + // // } + // // } else { + // // pt.LogTrace("▢️ Resumed process: PID=%d", pid) + // // } + // }() + + // Wait a bit after stopping + // time.Sleep(1 * time.Millisecond) + + // Analyze with timeout + if err := pt.analyzeBinary(pid, proc); err != nil { + 
pt.LogTrace("⚠️ Analysis failed: %v", err) + } else { + proc.Analyzed = true + pt.markAnalyzed(pid, path) + // pt.LogTrace("βœ… Analysis complete: PID=%d", pid) + } + + // go func() { + // if err := pt.analyzeBinary(pid, proc); err != nil { + // pt.LogTrace("⚠️ Analysis failed: %v", err) + // done <- false + // return + // } + + // proc.Analyzed = true + // pt.markAnalyzed(pid, path) + // pt.LogTrace("βœ… Analysis complete: PID=%d", pid) + // done <- true + // }() + }(p.pid, p.path) + } + wg.Wait() + + // Wait between batches + time.Sleep(10 * time.Microsecond) + } + + return nil +} + +func (pt *ProcessTracker) Cleanup() { +} + +// Helper to record process tree starting from a PID +func findAncestors(pid int, tree map[int]bool) { + for pid > 1 { + if tree[pid] { + return // Already visited + } + tree[pid] = true + + // Get parent + ppid := getParentPID(pid) + if ppid <= 1 { + return + } + pid = ppid + } +} + +var g_cmd *exec.Cmd + +func (pt *ProcessTracker) StartTracking(command string, args []string, dir string) error { + // ctx, cancel := context.WithCancel(context.Background()) + // defer cancel() + + // if err := pt.profiler.Start(ctx); err != nil { + // return fmt.Errorf("failed to start profiler: %w", err) + // } + // defer pt.profiler.Stop() + + cmd := exec.Command(command, args...) 
+ g_cmd = cmd + + if dir != "" { + cmd.Dir = dir + } + cmd.Env = append( + os.Environ(), + "PWD="+dir, + "DD_DYNAMIC_INSTRUMENTATION_ENABLED=true", + "DD_SERVICE=go-di-exploration-test-service") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start command: %v", err) + } + + pt.mainPID = cmd.Process.Pid + pt.addProcess(pt.mainPID, os.Getpid()) + + // Start scanning with high frequency initially + go func() { + // Initial high-frequency scanning + initialTicker := time.NewTicker(1 * time.Millisecond) + defer initialTicker.Stop() + + // After initial period, reduce frequency slightly + // time.AfterFunc(5*time.Second, func() { + // initialTicker.Stop() + // }) + + // regularTicker := time.NewTicker(10 * time.Millisecond) + // defer regularTicker.Stop() + + logTicker := time.NewTicker(10 * time.Second) + defer logTicker.Stop() + + for { + select { + case <-pt.stopChan: + return + case <-initialTicker.C: + if err := pt.scanProcessTree(); err != nil { + pt.LogTrace("⚠️ Error scanning: %v", err) + } + case <-logTicker.C: + // pt.logProcessTree() + } + } + }() + + err := cmd.Wait() + close(pt.stopChan) + + pt.LogTrace("Analyzed %d binaries.", len(analyzedBinaries)) + + for _, binary := range analyzedBinaries { + pt.LogTrace("Analyzed %s (debug info: %v)", binary.path, binary.hasDebug) + } + + return err +} + +func (pt *ProcessTracker) logProcessTree() { + pt.mu.RLock() + defer pt.mu.RUnlock() + + pt.t.Log("\n🌳 Process Tree:") + var printNode func(proc *ProcessInfo, prefix string) + printNode = func(proc *ProcessInfo, prefix string) { + state := "➑️" + switch proc.State { + case StateAnalyzing: + state = "πŸ”" + case StateRunning: + state = "▢️" + case StateExited: + state = "⏹️" + } + + analyzed := "" + if proc.Analyzed { + analyzed = "βœ“" + } + + pt.LogTrace("%s%s [PID=%d] %s%s (Parent=%d)", + prefix, state, proc.PID, 
filepath.Base(proc.BinaryPath), analyzed, proc.ParentPID) + + for _, child := range proc.Children { + printNode(child, prefix+" ") + } + } + + if main, exists := pt.processes[pt.mainPID]; exists { + printNode(main, "") + } +} + +var DEBUG bool = false +var TRACE bool = false + +func (pt *ProcessTracker) LogTrace(format string, args ...any) { + if TRACE { + pt.t.Logf(format, args...) + } +} + +func LogDebug(t *testing.T, format string, args ...any) { + if DEBUG { + t.Logf(format, args...) + } +} + +var g_RepoInfo *RepoInfo +var g_ConfigManager *diconfig.ReaderConfigManager + +func TestExplorationGoDI(t *testing.T) { + require.NoError(t, rlimit.RemoveMemlock(), "Failed to remove memlock limit") + if features.HaveMapType(ebpf.RingBuf) != nil { + t.Skip("Ringbuffers not supported on this kernel") + } + + eventOutputWriter := &explorationEventOutputTestWriter{ + t: t, + } + + opts := &dynamicinstrumentation.DIOptions{ + RateLimitPerProbePerSecond: 0.0, + ReaderWriterOptions: dynamicinstrumentation.ReaderWriterOptions{ + CustomReaderWriters: true, + SnapshotWriter: eventOutputWriter, + DiagnosticWriter: os.Stderr, + }, + } + + var ( + GoDI *dynamicinstrumentation.GoDI + err error + ) + + GoDI, err = dynamicinstrumentation.RunDynamicInstrumentation(opts) + require.NoError(t, err) + t.Cleanup(GoDI.Close) + + cm, ok := GoDI.ConfigManager.(*diconfig.ReaderConfigManager) + if !ok { + t.Fatal("Config manager is of wrong type") + } + + g_ConfigManager = cm + + tempDir := initializeTempDir(t, "/tmp/protobuf-integration-1060272402") + modulePath := filepath.Join(tempDir, "src", "google.golang.org", "protobuf") + + t.Log("Setting up test environment...") + g_RepoInfo = cloneProtobufRepo(t, modulePath, "30f628eeb303f2c29be7a381bf78aa3e3aabd317") + copyPatches(t, "exploration_tests/patches/protobuf", modulePath) + + t.Log("Starting process tracking...") + tracker := NewProcessTracker(t) + err = tracker.StartTracking("./test.bash", nil, modulePath) + require.NoError(t, err) +} + 
+type explorationEventOutputTestWriter struct { + t *testing.T + expectedResult map[string]*ditypes.CapturedValue +} + +func (e *explorationEventOutputTestWriter) Write(p []byte) (n int, err error) { + var snapshot ditypes.SnapshotUpload + if err := json.Unmarshal(p, &snapshot); err != nil { + e.t.Error("failed to unmarshal snapshot", err) + } + + funcName := snapshot.Debugger.ProbeInSnapshot.Type + "." + snapshot.Debugger.ProbeInSnapshot.Method + e.t.Logf("Received snapshot for function: %s", funcName) + + return len(p), nil +} + +func initializeTempDir(t *testing.T, predefinedTempDir string) string { + if predefinedTempDir != "" { + return predefinedTempDir + } + tempDir, err := os.MkdirTemp("", "protobuf-integration-") + require.NoError(t, err) + require.NoError(t, os.Chmod(tempDir, 0755)) + t.Log("tempDir:", tempDir) + return tempDir +} + +// RepoInfo holds scanned repository package information +type RepoInfo struct { + Packages map[string]bool // Package names found in repo + RepoPath string // Path to the repo + CommitHash string // Current commit hash (optional) +} + +func ScanRepoPackages(repoPath string) (*RepoInfo, error) { + info := &RepoInfo{ + Packages: make(map[string]bool), + RepoPath: repoPath, + } + + // Get git hash if available + if _, err := os.Stat(filepath.Join(repoPath, ".git")); err == nil { + if hash, err := exec.Command("git", "-C", repoPath, "rev-parse", "HEAD").Output(); err == nil { + info.CommitHash = strings.TrimSpace(string(hash)) + } + } + + err := filepath.Walk(repoPath, func(path string, f os.FileInfo, err error) error { + if err != nil { + return nil + } + + // Skip certain directories + if f.IsDir() { + dirname := filepath.Base(path) + if dirname == ".git" || + dirname == ".cache" || + dirname == "vendor" || + dirname == "testdata" || + strings.HasPrefix(dirname, ".") || + strings.HasPrefix(dirname, "tmp") { + return filepath.SkipDir + } + return nil + } + + // Only process .go files + if !strings.HasSuffix(path, ".go") { + 
return nil + } + + // Skip test files and generated files + if strings.HasSuffix(path, "_test.go") || + strings.HasSuffix(path, ".pb.go") { + return nil + } + + // Ensure the file is within the repo (not in .cache etc) + relPath, err := filepath.Rel(repoPath, path) + if err != nil || strings.Contains(relPath, "..") { + return nil + } + + content, err := os.ReadFile(path) + if err != nil { + return nil + } + + scanner := bufio.NewScanner(bytes.NewReader(content)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "package ") { + pkgDir := filepath.Dir(relPath) + if pkgDir != "." { + info.Packages[pkgDir] = true + } + break + } + } + return nil + }) + + if len(info.Packages) == 0 { + return nil, fmt.Errorf("no packages found in repository at %s", repoPath) + } + + return info, err +} + +func cloneProtobufRepo(t *testing.T, modulePath string, commitHash string) *RepoInfo { + if _, err := os.Stat(modulePath); os.IsNotExist(err) { + cmd := exec.Command("git", "clone", "https://github.com/protocolbuffers/protobuf-go", modulePath) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + require.NoError(t, cmd.Run(), "Failed to clone repository") + } + + if commitHash != "" { + cmd := exec.Command("git", "checkout", commitHash) + cmd.Dir = modulePath + require.NoError(t, cmd.Run(), "Failed to checkout commit hash") + } + + // Scan packages after clone/checkout + info, err := ScanRepoPackages(modulePath) + require.NoError(t, err, "Failed to scan repo packages") + + // Log the organized package information + var pkgs []string + for pkg := range info.Packages { + if strings.Contains(pkg, "/tmp") { + continue + } + pkgs = append(pkgs, pkg) + } + sort.Strings(pkgs) + + t.Logf("πŸ“¦ Found %d packages in protobuf repo:", len(pkgs)) + + // Group packages by their top-level directory + groups := make(map[string][]string) + for _, pkg := range pkgs { + parts := strings.SplitN(pkg, "/", 2) + topLevel := parts[0] + groups[topLevel] = 
append(groups[topLevel], pkg) + } + + // Print grouped packages + var topLevels []string + for k := range groups { + topLevels = append(topLevels, k) + } + sort.Strings(topLevels) + + for _, topLevel := range topLevels { + t.Logf(" %s/", topLevel) + for _, pkg := range groups[topLevel] { + t.Logf(" β†’ %s", pkg) + } + } + + return info +} + +func copyPatches(t *testing.T, src, dst string) { + require.NoError(t, copyDir(src, dst), "Failed to copy patches") +} + +func copyDir(src, dst string) error { + entries, err := os.ReadDir(src) + if err != nil { + return err + } + if err := os.MkdirAll(dst, 0755); err != nil { + return err + } + + for _, entry := range entries { + srcPath := filepath.Join(src, entry.Name()) + dstPath := filepath.Join(dst, entry.Name()) + + info, err := entry.Info() + if err != nil { + return err + } + + if info.IsDir() { + if err = copyDir(srcPath, dstPath); err != nil { + return err + } + } else { + if err = copyFile(srcPath, dstPath); err != nil { + return err + } + } + } + return nil +} + +func copyFile(srcFile, dstFile string) error { + src, err := os.Open(srcFile) + if err != nil { + return err + } + defer src.Close() + + if err = os.MkdirAll(filepath.Dir(dstFile), 0755); err != nil { + return err + } + + dst, err := os.Create(dstFile) + if err != nil { + return err + } + defer dst.Close() + + _, err = io.Copy(dst, src) + return err +} + +var explorationTestConfigTemplateText = ` +{ + "go-di-exploration-test-service": { + "{{.ProbeId}}": { + "id": "{{.ProbeId}}", + "version": 0, + "type": "LOG_PROBE", + "language": "go", + "where": { + "typeName": "{{.PackageName}}", + "methodName": "{{.FunctionName}}" + }, + "tags": [], + "template": "Executed {{.PackageName}}.{{.FunctionName}}, it took {@duration}ms", + "segments": [ + { + "str": "Executed {{.PackageName}}.{{.FunctionName}}, it took " + }, + { + "dsl": "@duration", + "json": { + "ref": "@duration" + } + }, + { + "str": "ms" + } + ], + "captureSnapshot": false, + "capture": { + 
"maxReferenceDepth": 10 + }, + "sampling": { + "snapshotsPerSecond": 5000 + }, + "evaluateAt": "EXIT" + } + } +} +` diff --git a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go new file mode 100644 index 00000000000000..ad9b259b8a85d4 --- /dev/null +++ b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go @@ -0,0 +1,586 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "crypto/sha256" + "flag" + "fmt" + "io" + "io/fs" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "runtime/debug" + "strings" + "sync" + "testing" + "time" + + "google.golang.org/protobuf/internal/version" +) + +var ( + regenerate = flag.Bool("regenerate", false, "regenerate files") + buildRelease = flag.Bool("buildRelease", false, "build release binaries") + + protobufVersion = "27.0" + + golangVersions = func() []string { + // Version policy: oldest supported version of Go, plus the version before that. 
+ // This matches the version policy of the Google Cloud Client Libraries: + // https://cloud.google.com/go/getting-started/supported-go-versions + return []string{ + "1.21.13", + "1.22.6", + "1.23.0", + } + }() + golangLatest = golangVersions[len(golangVersions)-1] + + staticcheckVersion = "2024.1.1" + staticcheckSHA256s = map[string]string{ + "darwin/amd64": "b67380b84b81d5765b478b7ad888dd7ce53b2c0861103bafa946ac84dc9244ce", + "darwin/arm64": "09cb10e4199f7c6356c2ed5dc45e877c3087ef775d84d39338b52e1a94866074", + "linux/386": "0225fd8b5cf6c762f9c0aedf1380ed4df576d1d54fb68691be895889e10faf0b", + "linux/amd64": "6e9398fcaff2b36e1d15e84a647a3a14733b7c2dd41187afa2c182a4c3b32180", + } + + // purgeTimeout determines the maximum age of unused sub-directories. + purgeTimeout = 30 * 24 * time.Hour // 1 month + + // Variables initialized by mustInitDeps. + modulePath string + protobufPath string +) + +func TestIntegration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + if os.Getenv("GO_BUILDER_NAME") != "" { + // To start off, run on longtest builders, not longtest-race ones. + if race() { + t.Skip("skipping integration test in race mode on builders") + } + // When on a builder, run even if it's not explicitly requested + // provided our caller isn't already running it. + if os.Getenv("GO_PROTOBUF_INTEGRATION_TEST_RUNNING") == "1" { + t.Skip("protobuf integration test is already running, skipping nested invocation") + } + os.Setenv("GO_PROTOBUF_INTEGRATION_TEST_RUNNING", "1") + } else if flag.Lookup("test.run").Value.String() != "^TestIntegration$" { + t.Skip("not running integration test if not explicitly requested via test.bash") + } + + mustInitDeps(t) + mustHandleFlags(t) + + // Report dirt in the working tree quickly, rather than after + // going through all the presubmits. + // + // Fail the test late, so we can test uncommitted changes with -failfast. 
+ // gitDiff := mustRunCommand(t, "git", "diff", "HEAD") + // if strings.TrimSpace(gitDiff) != "" { + // fmt.Printf("WARNING: working tree contains uncommitted changes:\n%v\n", gitDiff) + // } + // gitUntracked := mustRunCommand(t, "git", "ls-files", "--others", "--exclude-standard") + // if strings.TrimSpace(gitUntracked) != "" { + // fmt.Printf("WARNING: working tree contains untracked files:\n%v\n", gitUntracked) + // } + + // Do the relatively fast checks up-front. + t.Run("GeneratedGoFiles", func(t *testing.T) { + diff := mustRunCommand(t, "go", "run", "-tags", "protolegacy", "./internal/cmd/generate-types") + if strings.TrimSpace(diff) != "" { + t.Fatalf("stale generated files:\n%v", diff) + } + diff = mustRunCommand(t, "go", "run", "-tags", "protolegacy", "./internal/cmd/generate-protos") + if strings.TrimSpace(diff) != "" { + t.Fatalf("stale generated files:\n%v", diff) + } + }) + t.Run("FormattedGoFiles", func(t *testing.T) { + files := strings.Split(strings.TrimSpace(mustRunCommand(t, "git", "ls-files", "*.go")), "\n") + diff := mustRunCommand(t, append([]string{"gofmt", "-d"}, files...)...) + if strings.TrimSpace(diff) != "" { + t.Fatalf("unformatted source files:\n%v", diff) + } + }) + t.Run("CopyrightHeaders", func(t *testing.T) { + files := strings.Split(strings.TrimSpace(mustRunCommand(t, "git", "ls-files", "*.go", "*.proto")), "\n") + mustHaveCopyrightHeader(t, files) + }) + + var wg sync.WaitGroup + sema := make(chan bool, (runtime.NumCPU()+1)/2) + for i := range golangVersions { + goVersion := golangVersions[i] + goLabel := "Go" + goVersion + runGo := func(label string, cmd command, args ...string) { + wg.Add(1) + sema <- true + go func() { + defer wg.Done() + defer func() { <-sema }() + t.Run(goLabel+"/"+label, func(t *testing.T) { + args[0] += goVersion + cmd.mustRun(t, args...) 
+ })
+ }()
+ }
+
+ runGo("Normal", command{}, "go", "test", "-race", "./...")
+ runGo("Reflect", command{}, "go", "test", "-race", "-tags", "protoreflect", "./...")
+ if goVersion == golangLatest {
+ runGo("ProtoLegacyRace", command{}, "go", "test", "-race", "-tags", "protolegacy", "./...")
+ runGo("ProtoLegacy", command{}, "go", "test", "-tags", "protolegacy", "./...")
+ runGo("ProtocGenGo", command{Dir: "cmd/protoc-gen-go/testdata"}, "go", "test")
+ runGo("Conformance", command{Dir: "internal/conformance"}, "go", "test", "-execute")
+
+ // Only run the 32-bit compatibility tests for Linux;
+ // avoid Darwin since 10.15 dropped support for i386 code execution.
+ // if runtime.GOOS == "linux" {
+ // runGo("Arch32Bit", command{Env: append(os.Environ(), "GOARCH=386")}, "go", "test", "./...")
+ // }
+ }
+ }
+ wg.Wait()
+
+ t.Run("GoStaticCheck", func(t *testing.T) {
+ checks := []string{
+ "all", // start with all checks enabled
+ "-SA1019", // disable deprecated usage check
+ "-S*", // disable code simplification checks
+ "-ST*", // disable coding style checks
+ "-U*", // disable unused declaration checks
+ }
+ out := mustRunCommand(t, "staticcheck", "-checks="+strings.Join(checks, ","), "-fail=none", "./...")
+
+ // Filter out findings from certain paths.
+ var findings []string + for _, finding := range strings.Split(strings.TrimSpace(out), "\n") { + switch { + case strings.HasPrefix(finding, "internal/testprotos/legacy/"): + default: + findings = append(findings, finding) + } + } + if len(findings) > 0 { + t.Fatalf("staticcheck findings:\n%v", strings.Join(findings, "\n")) + } + }) + // t.Run("CommittedGitChanges", func(t *testing.T) { + // if strings.TrimSpace(gitDiff) != "" { + // t.Fatalf("uncommitted changes") + // } + // }) + // t.Run("TrackedGitFiles", func(t *testing.T) { + // if strings.TrimSpace(gitUntracked) != "" { + // t.Fatalf("untracked files") + // } + // }) +} + +func mustInitDeps(t *testing.T) { + check := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // Determine the directory to place the test directory. + repoRoot, err := os.Getwd() + check(err) + testDir := filepath.Join(repoRoot, ".cache") + check(os.MkdirAll(testDir, 0775)) + + // Delete the current directory if non-empty, + // which only occurs if a dependency failed to initialize properly. + var workingDir string + finishedDirs := map[string]bool{} + defer func() { + if workingDir != "" { + os.RemoveAll(workingDir) // best-effort + } + }() + startWork := func(name string) string { + workingDir = filepath.Join(testDir, name) + return workingDir + } + finishWork := func() { + finishedDirs[workingDir] = true + workingDir = "" + } + + // Delete other sub-directories that are no longer relevant. + defer func() { + now := time.Now() + fis, _ := os.ReadDir(testDir) + for _, fi := range fis { + dir := filepath.Join(testDir, fi.Name()) + if finishedDirs[dir] { + os.Chtimes(dir, now, now) // best-effort + continue + } + fii, err := fi.Info() + check(err) + if now.Sub(fii.ModTime()) < purgeTimeout { + continue + } + fmt.Printf("delete %v\n", fi.Name()) + os.RemoveAll(dir) // best-effort + } + }() + + // The bin directory contains symlinks to each tool by version. 
+ // It is safe to delete this directory and run the test script from scratch. + binPath := startWork("bin") + check(os.RemoveAll(binPath)) + check(os.Mkdir(binPath, 0775)) + check(os.Setenv("PATH", binPath+":"+os.Getenv("PATH"))) + registerBinary := func(name, path string) { + check(os.Symlink(path, filepath.Join(binPath, name))) + } + finishWork() + + // Get the protobuf toolchain. + protobufPath = startWork("protobuf-" + protobufVersion) + if _, err := os.Stat(protobufPath); err != nil { + fmt.Printf("download %v\n", filepath.Base(protobufPath)) + checkoutVersion := protobufVersion + if isCommit := strings.Trim(protobufVersion, "0123456789abcdef") == ""; !isCommit { + // release tags have "v" prefix + checkoutVersion = "v" + protobufVersion + } + command{Dir: testDir}.mustRun(t, "git", "clone", "https://github.com/protocolbuffers/protobuf", "protobuf-"+protobufVersion) + command{Dir: protobufPath}.mustRun(t, "git", "checkout", checkoutVersion) + + if os.Getenv("GO_BUILDER_NAME") != "" { + // If this is running on the Go build infrastructure, + // use pre-built versions of these binaries that the + // builders are configured to provide in $PATH. + protocPath, err := exec.LookPath("protoc") + check(err) + confTestRunnerPath, err := exec.LookPath("conformance_test_runner") + check(err) + check(os.MkdirAll(filepath.Join(protobufPath, "bazel-bin", "conformance"), 0775)) + check(os.Symlink(protocPath, filepath.Join(protobufPath, "bazel-bin", "protoc"))) + check(os.Symlink(confTestRunnerPath, filepath.Join(protobufPath, "bazel-bin", "conformance", "conformance_test_runner"))) + } else { + // In other environments, download and build the protobuf toolchain. + // We avoid downloading the pre-compiled binaries since they do not contain + // the conformance test runner. 
+ fmt.Printf("build %v\n", filepath.Base(protobufPath)) + env := os.Environ() + args := []string{ + "bazel", "build", + ":protoc", + "//conformance:conformance_test_runner", + } + if runtime.GOOS == "darwin" { + // Adding this environment variable appears to be necessary for macOS builds. + env = append(env, "CC=clang") + // And this flag. + args = append(args, + "--macos_minimum_os=13.0", + "--host_macos_minimum_os=13.0", + ) + } + command{ + Dir: protobufPath, + Env: env, + }.mustRun(t, args...) + } + } + check(os.Setenv("PROTOBUF_ROOT", protobufPath)) // for generate-protos + registerBinary("conform-test-runner", filepath.Join(protobufPath, "bazel-bin", "conformance", "conformance_test_runner")) + registerBinary("protoc", filepath.Join(protobufPath, "bazel-bin", "protoc")) + finishWork() + + // Download each Go toolchain version. + for _, v := range golangVersions { + goDir := startWork("go" + v) + if _, err := os.Stat(goDir); err != nil { + fmt.Printf("download %v\n", filepath.Base(goDir)) + url := fmt.Sprintf("https://dl.google.com/go/go%v.%v-%v.tar.gz", v, runtime.GOOS, runtime.GOARCH) + downloadArchive(check, goDir, url, "go", "") // skip SHA256 check as we fetch over https from a trusted domain + } + registerBinary("go"+v, filepath.Join(goDir, "bin", "go")) + finishWork() + } + registerBinary("go", filepath.Join(testDir, "go"+golangLatest, "bin", "go")) + registerBinary("gofmt", filepath.Join(testDir, "go"+golangLatest, "bin", "gofmt")) + + // Download the staticcheck tool. 
+ checkDir := startWork("staticcheck-" + staticcheckVersion) + if _, err := os.Stat(checkDir); err != nil { + fmt.Printf("download %v\n", filepath.Base(checkDir)) + url := fmt.Sprintf("https://github.com/dominikh/go-tools/releases/download/%v/staticcheck_%v_%v.tar.gz", staticcheckVersion, runtime.GOOS, runtime.GOARCH) + downloadArchive(check, checkDir, url, "staticcheck", staticcheckSHA256s[runtime.GOOS+"/"+runtime.GOARCH]) + } + registerBinary("staticcheck", filepath.Join(checkDir, "staticcheck")) + finishWork() + + // GitHub actions sets GOROOT, which confuses invocations of the Go toolchain. + // Explicitly clear GOROOT, so each toolchain uses their default GOROOT. + check(os.Unsetenv("GOROOT")) + + // Set a cache directory outside the test directory. + check(os.Setenv("GOCACHE", filepath.Join(repoRoot, ".gocache"))) +} + +func downloadFile(check func(error), dstPath, srcURL string, perm fs.FileMode) { + resp, err := http.Get(srcURL) + check(err) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) + check(fmt.Errorf("GET %q: non-200 OK status code: %v body: %q", srcURL, resp.Status, body)) + } + + check(os.MkdirAll(filepath.Dir(dstPath), 0775)) + f, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) + check(err) + + _, err = io.Copy(f, resp.Body) + check(err) + + check(f.Close()) +} + +func downloadArchive(check func(error), dstPath, srcURL, skipPrefix, wantSHA256 string) { + check(os.RemoveAll(dstPath)) + + resp, err := http.Get(srcURL) + check(err) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) + check(fmt.Errorf("GET %q: non-200 OK status code: %v body: %q", srcURL, resp.Status, body)) + } + + var r io.Reader = resp.Body + if wantSHA256 != "" { + b, err := io.ReadAll(resp.Body) + check(err) + r = bytes.NewReader(b) + + if gotSHA256 := fmt.Sprintf("%x", sha256.Sum256(b)); gotSHA256 != wantSHA256 
{ + check(fmt.Errorf("checksum validation error:\ngot %v\nwant %v", gotSHA256, wantSHA256)) + } + } + + zr, err := gzip.NewReader(r) + check(err) + + tr := tar.NewReader(zr) + for { + h, err := tr.Next() + if err == io.EOF { + return + } + check(err) + + // Skip directories or files outside the prefix directory. + if len(skipPrefix) > 0 { + if !strings.HasPrefix(h.Name, skipPrefix) { + continue + } + if len(h.Name) > len(skipPrefix) && h.Name[len(skipPrefix)] != '/' { + continue + } + } + + path := strings.TrimPrefix(strings.TrimPrefix(h.Name, skipPrefix), "/") + path = filepath.Join(dstPath, filepath.FromSlash(path)) + mode := os.FileMode(h.Mode & 0777) + switch h.Typeflag { + case tar.TypeReg: + b, err := io.ReadAll(tr) + check(err) + check(os.WriteFile(path, b, mode)) + case tar.TypeDir: + check(os.Mkdir(path, mode)) + } + } +} + +func mustHandleFlags(t *testing.T) { + if *regenerate { + t.Run("Generate", func(t *testing.T) { + fmt.Print(mustRunCommand(t, "go", "generate", "./internal/cmd/generate-types")) + fmt.Print(mustRunCommand(t, "go", "generate", "./internal/cmd/generate-protos")) + files := strings.Split(strings.TrimSpace(mustRunCommand(t, "git", "ls-files", "*.go")), "\n") + mustRunCommand(t, append([]string{"gofmt", "-w"}, files...)...) + }) + } + if *buildRelease { + t.Run("BuildRelease", func(t *testing.T) { + v := version.String() + for _, goos := range []string{"linux", "darwin", "windows"} { + for _, goarch := range []string{"386", "amd64", "arm64"} { + // Avoid Darwin since 10.15 dropped support for i386. + if goos == "darwin" && goarch == "386" { + continue + } + + binPath := filepath.Join("bin", fmt.Sprintf("protoc-gen-go.%v.%v.%v", v, goos, goarch)) + + // Build the binary. + cmd := command{Env: append(os.Environ(), "GOOS="+goos, "GOARCH="+goarch)} + cmd.mustRun(t, "go", "build", "-trimpath", "-ldflags", "-s -w -buildid=", "-o", binPath, "./cmd/protoc-gen-go") + + // Archive and compress the binary. 
+ in, err := os.ReadFile(binPath) + if err != nil { + t.Fatal(err) + } + out := new(bytes.Buffer) + suffix := "" + comment := fmt.Sprintf("protoc-gen-go VERSION=%v GOOS=%v GOARCH=%v", v, goos, goarch) + switch goos { + case "windows": + suffix = ".zip" + zw := zip.NewWriter(out) + zw.SetComment(comment) + fw, _ := zw.Create("protoc-gen-go.exe") + fw.Write(in) + zw.Close() + default: + suffix = ".tar.gz" + gz, _ := gzip.NewWriterLevel(out, gzip.BestCompression) + gz.Comment = comment + tw := tar.NewWriter(gz) + tw.WriteHeader(&tar.Header{ + Name: "protoc-gen-go", + Mode: int64(0775), + Size: int64(len(in)), + }) + tw.Write(in) + tw.Close() + gz.Close() + } + if err := os.WriteFile(binPath+suffix, out.Bytes(), 0664); err != nil { + t.Fatal(err) + } + } + } + }) + } + if *regenerate || *buildRelease { + t.SkipNow() + } +} + +var copyrightRegex = []*regexp.Regexp{ + regexp.MustCompile(`^// Copyright \d\d\d\d The Go Authors\. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file\. +`), + // Generated .pb.go files from main protobuf repo. + regexp.MustCompile(`^// Protocol Buffers - Google's data interchange format +// Copyright \d\d\d\d Google Inc\. All rights reserved\. +`), +} + +func mustHaveCopyrightHeader(t *testing.T, files []string) { + var bad []string +File: + for _, file := range files { + if strings.HasSuffix(file, "internal/testprotos/conformance/editions/test_messages_edition2023.pb.go") { + // TODO(lassefolger) the underlying proto file is checked into + // the protobuf repo without a copyright header. Fix is pending but + // might require a release. 
+ continue
+ }
+ b, err := os.ReadFile(file)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, re := range copyrightRegex {
+ if loc := re.FindIndex(b); loc != nil && loc[0] == 0 { // header must start at byte offset 0
+ continue File
+ }
+ }
+ bad = append(bad, file)
+ }
+ if len(bad) > 0 {
+ t.Fatalf("files with missing/bad copyright headers:\n %v", strings.Join(bad, "\n "))
+ }
+}
+
+// command configures a subprocess invocation: its working directory and environment.
+type command struct {
+ Dir string
+ Env []string
+}
+
+func (c command) mustRun(t *testing.T, args ...string) string {
+ t.Helper()
+ stdout := new(bytes.Buffer)
+ stderr := new(bytes.Buffer)
+
+ var cmdArgs []string
+ if len(args) > 1 && strings.HasPrefix(args[0], "go") && args[1] == "test" {
+ for i, arg := range args {
+ cmdArgs = append(cmdArgs, arg)
+ if i == 1 { // right after "test"
+ cmdArgs = append(cmdArgs, "-ldflags=-w=false -s=false", "-count=1", "-timeout=30m")
+ }
+ }
+ } else {
+ cmdArgs = args
+ }
+
+ cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...)
+ cmd.Dir = "."
+ if c.Dir != "" {
+ cmd.Dir = c.Dir
+ }
+ cmd.Env = os.Environ()
+ if c.Env != nil {
+ cmd.Env = c.Env
+ }
+ cmd.Env = append(cmd.Env, "PWD="+cmd.Dir)
+ cmd.Stdout = stdout
+ cmd.Stderr = stderr
+
+ if err := cmd.Run(); err != nil {
+ t.Fatalf("executing (%v): %v\n%s%s", strings.Join(args, " "), err, stdout.String(), stderr.String())
+ }
+
+ return stdout.String()
+}
+
+func mustRunCommand(t *testing.T, args ...string) string {
+ t.Helper()
+ return command{}.mustRun(t, args...)
+}
+
+// race is an approximation of whether the race detector is on.
+// It's used to skip the integration test on builders, without
+// preventing the integration test from running under the race
+// detector as a '//go:build !race' build constraint would.
+func race() bool { + bi, ok := debug.ReadBuildInfo() + if !ok { + return false + } + for _, setting := range bi.Settings { + if setting.Key == "-race" { + return setting.Value == "true" + } + } + return false +} diff --git a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash new file mode 100644 index 00000000000000..aec89522a116a5 --- /dev/null +++ b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash @@ -0,0 +1,7 @@ +#!/bin/bash +# Copyright 2018 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +go test google.golang.org/protobuf -run='^TestIntegration$' -v -timeout=60m -count=1 -failfast "$@" +exit $? \ No newline at end of file diff --git a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go new file mode 100644 index 00000000000000..8f9af17e604744 --- /dev/null +++ b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go @@ -0,0 +1,911 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package prototest exercises protobuf reflection. +package prototest + +import ( + "bytes" + "fmt" + "math" + "reflect" + "sort" + "strings" + "testing" + + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" +) + +// TODO: Test invalid field descriptors or oneof descriptors. 
+// TODO: This should test the functionality that can be provided by fast-paths. + +// Message tests a message implementation. +type Message struct { + // Resolver is used to determine the list of extension fields to test with. + // If nil, this defaults to using protoregistry.GlobalTypes. + Resolver interface { + FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) + FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) + RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) + } + + // UnmarshalOptions are respected for every Unmarshal call this package + // does. The Resolver and AllowPartial fields are overridden. + UnmarshalOptions proto.UnmarshalOptions +} + +//nolint:all +//go:noinline +func blabla_blabla(x bool) {} + +// Test performs tests on a [protoreflect.MessageType] implementation. +func (test Message) Test(t testing.TB, mt protoreflect.MessageType) { + testType(t, mt) + + // for { + // blabla_blabla(true) + // time.Sleep(1 * time.Second) + // } + + md := mt.Descriptor() + m1 := mt.New() + for i := 0; i < md.Fields().Len(); i++ { + fd := md.Fields().Get(i) + testField(t, m1, fd) + } + if test.Resolver == nil { + test.Resolver = protoregistry.GlobalTypes + } + var extTypes []protoreflect.ExtensionType + test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e protoreflect.ExtensionType) bool { + extTypes = append(extTypes, e) + return true + }) + for _, xt := range extTypes { + testField(t, m1, xt.TypeDescriptor()) + } + for i := 0; i < md.Oneofs().Len(); i++ { + testOneof(t, m1, md.Oneofs().Get(i)) + } + testUnknown(t, m1) + + // Test round-trip marshal/unmarshal. 
+ m2 := mt.New().Interface() + populateMessage(m2.ProtoReflect(), 1, nil) + for _, xt := range extTypes { + m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil)) + } + b, err := proto.MarshalOptions{ + AllowPartial: true, + }.Marshal(m2) + if err != nil { + t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2)) + } + m3 := mt.New().Interface() + unmarshalOpts := test.UnmarshalOptions + unmarshalOpts.AllowPartial = true + unmarshalOpts.Resolver = test.Resolver + if err := unmarshalOpts.Unmarshal(b, m3); err != nil { + t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2)) + } + if !proto.Equal(m2, m3) { + t.Errorf("round-trip marshal/unmarshal did not preserve message\nOriginal:\n%v\nNew:\n%v", prototext.Format(m2), prototext.Format(m3)) + } +} + +func testType(t testing.TB, mt protoreflect.MessageType) { + m := mt.New().Interface() + want := reflect.TypeOf(m) + if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want { + t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want) + } + if got := reflect.TypeOf(m.ProtoReflect().New().Interface()); got != want { + t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().New().Interface()): %v != %v", got, want) + } + if got := reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()); got != want { + t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()): %v != %v", got, want) + } + if mt, ok := mt.(protoreflect.MessageFieldTypes); ok { + testFieldTypes(t, mt) + } +} + +func testFieldTypes(t testing.TB, mt protoreflect.MessageFieldTypes) { + descName := func(d protoreflect.Descriptor) protoreflect.FullName { + if d == nil { + return "" + } + return d.FullName() + } + typeName := func(mt protoreflect.MessageType) protoreflect.FullName { + if mt == nil { + return "" + } + return mt.Descriptor().FullName() + } + 
adjustExpr := func(idx int, expr string) string { + expr = strings.Replace(expr, "fd.", "md.Fields().Get(i).", -1) + expr = strings.Replace(expr, "(fd)", "(md.Fields().Get(i))", -1) + expr = strings.Replace(expr, "mti.", "mt.Message(i).", -1) + expr = strings.Replace(expr, "(i)", fmt.Sprintf("(%d)", idx), -1) + return expr + } + checkEnumDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.EnumDescriptor) { + if got != want { + t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want)) + } + } + checkMessageDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageDescriptor) { + if got != want { + t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want)) + } + } + checkMessageType := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageType) { + if got != want { + t.Errorf("type mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), typeName(got), typeName(want)) + } + } + + fds := mt.Descriptor().Fields() + m := mt.New() + for i := 0; i < fds.Len(); i++ { + fd := fds.Get(i) + switch { + case fd.IsList(): + if fd.Enum() != nil { + checkEnumDesc(i, + "mt.Enum(i).Descriptor()", "fd.Enum()", + mt.Enum(i).Descriptor(), fd.Enum()) + } + if fd.Message() != nil { + checkMessageDesc(i, + "mt.Message(i).Descriptor()", "fd.Message()", + mt.Message(i).Descriptor(), fd.Message()) + checkMessageType(i, + "mt.Message(i)", "m.NewField(fd).List().NewElement().Message().Type()", + mt.Message(i), m.NewField(fd).List().NewElement().Message().Type()) + } + case fd.IsMap(): + mti := mt.Message(i) + if m := mti.New(); m != nil { + checkMessageDesc(i, + "m.Descriptor()", "fd.Message()", + m.Descriptor(), fd.Message()) + } + if m := mti.Zero(); m != nil { + checkMessageDesc(i, + "m.Descriptor()", "fd.Message()", + m.Descriptor(), fd.Message()) + } + 
checkMessageDesc(i, + "mti.Descriptor()", "fd.Message()", + mti.Descriptor(), fd.Message()) + if mti := mti.(protoreflect.MessageFieldTypes); mti != nil { + if fd.MapValue().Enum() != nil { + checkEnumDesc(i, + "mti.Enum(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Enum()", + mti.Enum(fd.MapValue().Index()).Descriptor(), fd.MapValue().Enum()) + } + if fd.MapValue().Message() != nil { + checkMessageDesc(i, + "mti.Message(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Message()", + mti.Message(fd.MapValue().Index()).Descriptor(), fd.MapValue().Message()) + checkMessageType(i, + "mti.Message(fd.MapValue().Index())", "m.NewField(fd).Map().NewValue().Message().Type()", + mti.Message(fd.MapValue().Index()), m.NewField(fd).Map().NewValue().Message().Type()) + } + } + default: + if fd.Enum() != nil { + checkEnumDesc(i, + "mt.Enum(i).Descriptor()", "fd.Enum()", + mt.Enum(i).Descriptor(), fd.Enum()) + } + if fd.Message() != nil { + checkMessageDesc(i, + "mt.Message(i).Descriptor()", "fd.Message()", + mt.Message(i).Descriptor(), fd.Message()) + checkMessageType(i, + "mt.Message(i)", "m.NewField(fd).Message().Type()", + mt.Message(i), m.NewField(fd).Message().Type()) + } + } + } +} + +// testField exercises set/get/has/clear of a field. +func testField(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { + name := fd.FullName() + num := fd.Number() + + switch { + case fd.IsList(): + testFieldList(t, m, fd) + case fd.IsMap(): + testFieldMap(t, m, fd) + case fd.Message() != nil: + default: + if got, want := m.NewField(fd), fd.Default(); !valueEqual(got, want) { + t.Errorf("Message.NewField(%v) = %v, want default value %v", name, formatValue(got), formatValue(want)) + } + if fd.Kind() == protoreflect.FloatKind || fd.Kind() == protoreflect.DoubleKind { + testFieldFloat(t, m, fd) + } + } + + // Set to a non-zero value, the zero value, different non-zero values. 
+ for _, n := range []seed{1, 0, minVal, maxVal} { + v := newValue(m, fd, n, nil) + m.Set(fd, v) + wantHas := true + if n == 0 { + if !fd.HasPresence() { + wantHas = false + } + if fd.IsExtension() { + wantHas = true + } + if fd.Cardinality() == protoreflect.Repeated { + wantHas = false + } + if fd.ContainingOneof() != nil { + wantHas = true + } + } + if !fd.HasPresence() && fd.Cardinality() != protoreflect.Repeated && fd.ContainingOneof() == nil && fd.Kind() == protoreflect.EnumKind && v.Enum() == 0 { + wantHas = false + } + if got, want := m.Has(fd), wantHas; got != want { + t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want) + } + if got, want := m.Get(fd), v; !valueEqual(got, want) { + t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want)) + } + found := false + m.Range(func(d protoreflect.FieldDescriptor, got protoreflect.Value) bool { + if fd != d { + return true + } + found = true + if want := v; !valueEqual(got, want) { + t.Errorf("after setting %q:\nMessage.Range got value %v, want %v", name, formatValue(got), formatValue(want)) + } + return true + }) + if got, want := wantHas, found; got != want { + t.Errorf("after setting %q:\nMessageRange saw field: %v, want %v", name, got, want) + } + } + + m.Clear(fd) + if got, want := m.Has(fd), false; got != want { + t.Errorf("after clearing %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want) + } + switch { + case fd.IsList(): + if got := m.Get(fd); got.List().Len() != 0 { + t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got)) + } + case fd.IsMap(): + if got := m.Get(fd); got.Map().Len() != 0 { + t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got)) + } + case fd.Message() == nil: + if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) { + t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default 
%v", name, num, formatValue(got), formatValue(want)) + } + } + + // Set to the default value. + switch { + case fd.IsList() || fd.IsMap(): + m.Set(fd, m.Mutable(fd)) + if got, want := m.Has(fd), (fd.IsExtension() && fd.Cardinality() != protoreflect.Repeated) || fd.ContainingOneof() != nil; got != want { + t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want) + } + case fd.Message() == nil: + m.Set(fd, m.Get(fd)) + if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) { + t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want)) + } + } + m.Clear(fd) + + // Set to the wrong type. + v := protoreflect.ValueOfString("") + if fd.Kind() == protoreflect.StringKind { + v = protoreflect.ValueOfInt32(0) + } + if !panics(func() { + m.Set(fd, v) + }) { + t.Errorf("setting %v to %T succeeds, want panic", name, v.Interface()) + } +} + +// testFieldMap tests set/get/has/clear of entries in a map field. +func testFieldMap(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { + name := fd.FullName() + num := fd.Number() + + // New values. + m.Clear(fd) // start with an empty map + mapv := m.Get(fd).Map() + if mapv.IsValid() { + t.Errorf("after clearing field: message.Get(%v).IsValid() = true, want false", name) + } + if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) { + t.Errorf("message.Get(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want)) + } + if !panics(func() { + m.Set(fd, protoreflect.ValueOfMap(mapv)) + }) { + t.Errorf("message.Set(%v, ) does not panic", name) + } + if !panics(func() { + mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, nil)) + }) { + t.Errorf("message.Get(%v).Set(...) 
of invalid map does not panic", name) + } + mapv = m.Mutable(fd).Map() // mutable map + if !mapv.IsValid() { + t.Errorf("message.Mutable(%v).IsValid() = false, want true", name) + } + if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) { + t.Errorf("message.Mutable(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want)) + } + + // Add values. + want := make(testMap) + for i, n := range []seed{1, 0, minVal, maxVal} { + if got, want := m.Has(fd), i > 0; got != want { + t.Errorf("after inserting %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want) + } + + k := newMapKey(fd, n) + v := newMapValue(fd, mapv, n, nil) + mapv.Set(k, v) + want.Set(k, v) + if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) { + t.Errorf("after inserting %d elements to %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want)) + } + } + + // Set values. + want.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + nv := newMapValue(fd, mapv, 10, nil) + mapv.Set(k, nv) + want.Set(k, nv) + if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) { + t.Errorf("after setting element %v of %q:\nMessage.Get(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want)) + } + return true + }) + + // Clear values. 
+ want.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + mapv.Clear(k) + want.Clear(k) + if got, want := m.Has(fd), want.Len() > 0; got != want { + t.Errorf("after clearing elements of %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want) + } + if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) { + t.Errorf("after clearing elements of %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want)) + } + return true + }) + if mapv := m.Get(fd).Map(); mapv.IsValid() { + t.Errorf("after clearing all elements: message.Get(%v).IsValid() = true, want false %v", name, formatValue(protoreflect.ValueOfMap(mapv))) + } + + // Non-existent map keys. + missingKey := newMapKey(fd, 1) + if got, want := mapv.Has(missingKey), false; got != want { + t.Errorf("non-existent map key in %q: Map.Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want) + } + if got, want := mapv.Get(missingKey).IsValid(), false; got != want { + t.Errorf("non-existent map key in %q: Map.Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want) + } + mapv.Clear(missingKey) // noop + + // Mutable. 
+ if fd.MapValue().Message() == nil { + if !panics(func() { + mapv.Mutable(newMapKey(fd, 1)) + }) { + t.Errorf("Mutable on %q succeeds, want panic", name) + } + } else { + k := newMapKey(fd, 1) + v := mapv.Mutable(k) + if got, want := mapv.Len(), 1; got != want { + t.Errorf("after Mutable on %q, Map.Len() = %v, want %v", name, got, want) + } + populateMessage(v.Message(), 1, nil) + if !valueEqual(mapv.Get(k), v) { + t.Errorf("after Mutable on %q, changing new mutable value does not change map entry", name) + } + mapv.Clear(k) + } +} + +type testMap map[any]protoreflect.Value + +func (m testMap) Get(k protoreflect.MapKey) protoreflect.Value { return m[k.Interface()] } +func (m testMap) Set(k protoreflect.MapKey, v protoreflect.Value) { m[k.Interface()] = v } +func (m testMap) Has(k protoreflect.MapKey) bool { return m.Get(k).IsValid() } +func (m testMap) Clear(k protoreflect.MapKey) { delete(m, k.Interface()) } +func (m testMap) Mutable(k protoreflect.MapKey) protoreflect.Value { panic("unimplemented") } +func (m testMap) Len() int { return len(m) } +func (m testMap) NewValue() protoreflect.Value { panic("unimplemented") } +func (m testMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) { + for k, v := range m { + if !f(protoreflect.ValueOf(k).MapKey(), v) { + return + } + } +} +func (m testMap) IsValid() bool { return true } + +// testFieldList exercises set/get/append/truncate of values in a list. +func testFieldList(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { + name := fd.FullName() + num := fd.Number() + + m.Clear(fd) // start with an empty list + list := m.Get(fd).List() + if list.IsValid() { + t.Errorf("message.Get(%v).IsValid() = true, want false", name) + } + if !panics(func() { + m.Set(fd, protoreflect.ValueOfList(list)) + }) { + t.Errorf("message.Set(%v, ) does not panic", name) + } + if !panics(func() { + list.Append(newListElement(fd, list, 0, nil)) + }) { + t.Errorf("message.Get(%v).Append(...) 
of invalid list does not panic", name) + } + if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) { + t.Errorf("message.Get(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want)) + } + list = m.Mutable(fd).List() // mutable list + if !list.IsValid() { + t.Errorf("message.Get(%v).IsValid() = false, want true", name) + } + if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) { + t.Errorf("message.Mutable(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want)) + } + + // Append values. + var want protoreflect.List = &testList{} + for i, n := range []seed{1, 0, minVal, maxVal} { + if got, want := m.Has(fd), i > 0; got != want { + t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want) + } + v := newListElement(fd, list, n, nil) + want.Append(v) + list.Append(v) + + if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) { + t.Errorf("after appending %d elements to %q:\nMessage.Get(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want)) + } + } + + // Set values. + for i := 0; i < want.Len(); i++ { + v := newListElement(fd, list, seed(i+10), nil) + want.Set(i, v) + list.Set(i, v) + if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) { + t.Errorf("after setting element %d of %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want)) + } + } + + // Truncate. 
+ for want.Len() > 0 { + n := want.Len() - 1 + want.Truncate(n) + list.Truncate(n) + if got, want := m.Has(fd), want.Len() > 0; got != want { + t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want) + } + if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) { + t.Errorf("after truncating %q to %d:\nMessage.Get(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want)) + } + } + + // AppendMutable. + if fd.Message() == nil { + if !panics(func() { + list.AppendMutable() + }) { + t.Errorf("AppendMutable on %q succeeds, want panic", name) + } + } else { + v := list.AppendMutable() + if got, want := list.Len(), 1; got != want { + t.Errorf("after AppendMutable on %q, list.Len() = %v, want %v", name, got, want) + } + populateMessage(v.Message(), 1, nil) + if !valueEqual(list.Get(0), v) { + t.Errorf("after AppendMutable on %q, changing new mutable value does not change list item 0", name) + } + want.Truncate(0) + } +} + +type testList struct { + a []protoreflect.Value +} + +func (l *testList) Append(v protoreflect.Value) { l.a = append(l.a, v) } +func (l *testList) AppendMutable() protoreflect.Value { panic("unimplemented") } +func (l *testList) Get(n int) protoreflect.Value { return l.a[n] } +func (l *testList) Len() int { return len(l.a) } +func (l *testList) Set(n int, v protoreflect.Value) { l.a[n] = v } +func (l *testList) Truncate(n int) { l.a = l.a[:n] } +func (l *testList) NewElement() protoreflect.Value { panic("unimplemented") } +func (l *testList) IsValid() bool { return true } + +// testFieldFloat exercises some interesting floating-point scalar field values. 
+func testFieldFloat(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { + name := fd.FullName() + num := fd.Number() + + for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} { + var val protoreflect.Value + if fd.Kind() == protoreflect.FloatKind { + val = protoreflect.ValueOfFloat32(float32(v)) + } else { + val = protoreflect.ValueOfFloat64(float64(v)) + } + m.Set(fd, val) + // Note that Has is true for -0. + if got, want := m.Has(fd), true; got != want { + t.Errorf("after setting %v to %v: Message.Has(%v) = %v, want %v", name, v, num, got, want) + } + if got, want := m.Get(fd), val; !valueEqual(got, want) { + t.Errorf("after setting %v: Message.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want)) + } + } +} + +// testOneof tests the behavior of fields in a oneof. +func testOneof(t testing.TB, m protoreflect.Message, od protoreflect.OneofDescriptor) { + for _, mutable := range []bool{false, true} { + for i := 0; i < od.Fields().Len(); i++ { + fda := od.Fields().Get(i) + if mutable { + // Set fields by requesting a mutable reference. + if !fda.IsMap() && !fda.IsList() && fda.Message() == nil { + continue + } + _ = m.Mutable(fda) + } else { + // Set fields explicitly. + m.Set(fda, newValue(m, fda, 1, nil)) + } + if !od.IsSynthetic() { + // Synthetic oneofs are used to represent optional fields in + // proto3. While they show up in protoreflect, WhichOneof does + // not work on these (only on non-synthetic, explicit oneofs). 
+ if got, want := m.WhichOneof(od), fda; got != want { + t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want) + } + } + for j := 0; j < od.Fields().Len(); j++ { + fdb := od.Fields().Get(j) + if got, want := m.Has(fdb), i == j; got != want { + t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want) + } + } + } + } +} + +// testUnknown tests the behavior of unknown fields. +func testUnknown(t testing.TB, m protoreflect.Message) { + var b []byte + b = protowire.AppendTag(b, 1000, protowire.VarintType) + b = protowire.AppendVarint(b, 1001) + m.SetUnknown(protoreflect.RawFields(b)) + if got, want := []byte(m.GetUnknown()), b; !bytes.Equal(got, want) { + t.Errorf("after setting unknown fields:\nGetUnknown() = %v, want %v", got, want) + } +} + +func formatValue(v protoreflect.Value) string { + switch v := v.Interface().(type) { + case protoreflect.List: + var buf bytes.Buffer + buf.WriteString("list[") + for i := 0; i < v.Len(); i++ { + if i > 0 { + buf.WriteString(" ") + } + buf.WriteString(formatValue(v.Get(i))) + } + buf.WriteString("]") + return buf.String() + case protoreflect.Map: + var buf bytes.Buffer + buf.WriteString("map[") + var keys []protoreflect.MapKey + v.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + keys = append(keys, k) + return true + }) + sort.Slice(keys, func(i, j int) bool { + return keys[i].String() < keys[j].String() + }) + for i, k := range keys { + if i > 0 { + buf.WriteString(" ") + } + buf.WriteString(formatValue(k.Value())) + buf.WriteString(":") + buf.WriteString(formatValue(v.Get(k))) + } + buf.WriteString("]") + return buf.String() + case protoreflect.Message: + b, err := prototext.Marshal(v.Interface()) + if err != nil { + return fmt.Sprintf("<%v>", err) + } + return fmt.Sprintf("%v{%s}", v.Descriptor().FullName(), b) + case string: + return fmt.Sprintf("%q", v) + default: + return fmt.Sprint(v) + } +} + +func 
valueEqual(a, b protoreflect.Value) bool { + ai, bi := a.Interface(), b.Interface() + switch ai.(type) { + case protoreflect.Message: + return proto.Equal( + a.Message().Interface(), + b.Message().Interface(), + ) + case protoreflect.List: + lista, listb := a.List(), b.List() + if lista.Len() != listb.Len() { + return false + } + for i := 0; i < lista.Len(); i++ { + if !valueEqual(lista.Get(i), listb.Get(i)) { + return false + } + } + return true + case protoreflect.Map: + mapa, mapb := a.Map(), b.Map() + if mapa.Len() != mapb.Len() { + return false + } + equal := true + mapa.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + if !valueEqual(v, mapb.Get(k)) { + equal = false + return false + } + return true + }) + return equal + case []byte: + return bytes.Equal(a.Bytes(), b.Bytes()) + case float32: + // NaNs are equal, but must be the same NaN. + return math.Float32bits(ai.(float32)) == math.Float32bits(bi.(float32)) + case float64: + // NaNs are equal, but must be the same NaN. + return math.Float64bits(ai.(float64)) == math.Float64bits(bi.(float64)) + default: + return ai == bi + } +} + +// A seed is used to vary the content of a value. +// +// A seed of 0 is the zero value. Messages do not have a zero-value; a 0-seeded messages +// is unpopulated. +// +// A seed of minVal or maxVal is the least or greatest value of the value type. +type seed int + +const ( + minVal seed = -1 + maxVal seed = -2 +) + +// newSeed creates new seed values from a base, for example to create seeds for the +// elements in a list. If the input seed is minVal or maxVal, so is the output. +func newSeed(n seed, adjust ...int) seed { + switch n { + case minVal, maxVal: + return n + } + for _, a := range adjust { + n = 10*n + seed(a) + } + return n +} + +// newValue returns a new value assignable to a field. +// +// The stack parameter is used to avoid infinite recursion when populating circular +// data structures. 
+func newValue(m protoreflect.Message, fd protoreflect.FieldDescriptor, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { + switch { + case fd.IsList(): + if n == 0 { + return m.New().Mutable(fd) + } + list := m.NewField(fd).List() + list.Append(newListElement(fd, list, 0, stack)) + list.Append(newListElement(fd, list, minVal, stack)) + list.Append(newListElement(fd, list, maxVal, stack)) + list.Append(newListElement(fd, list, n, stack)) + return protoreflect.ValueOfList(list) + case fd.IsMap(): + if n == 0 { + return m.New().Mutable(fd) + } + mapv := m.NewField(fd).Map() + mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack)) + mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack)) + mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack)) + mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack)) + return protoreflect.ValueOfMap(mapv) + case fd.Message() != nil: + return populateMessage(m.NewField(fd).Message(), n, stack) + default: + return newScalarValue(fd, n) + } +} + +func newListElement(fd protoreflect.FieldDescriptor, list protoreflect.List, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { + if fd.Message() == nil { + return newScalarValue(fd, n) + } + return populateMessage(list.NewElement().Message(), n, stack) +} + +func newMapKey(fd protoreflect.FieldDescriptor, n seed) protoreflect.MapKey { + kd := fd.MapKey() + return newScalarValue(kd, n).MapKey() +} + +func newMapValue(fd protoreflect.FieldDescriptor, mapv protoreflect.Map, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { + vd := fd.MapValue() + if vd.Message() == nil { + return newScalarValue(vd, n) + } + return populateMessage(mapv.NewValue().Message(), n, stack) +} + +func newScalarValue(fd protoreflect.FieldDescriptor, n seed) protoreflect.Value { + switch fd.Kind() { + case protoreflect.BoolKind: + return protoreflect.ValueOfBool(n != 0) + case protoreflect.EnumKind: + vals := 
fd.Enum().Values() + var i int + switch n { + case minVal: + i = 0 + case maxVal: + i = vals.Len() - 1 + default: + i = int(n) % vals.Len() + } + return protoreflect.ValueOfEnum(vals.Get(i).Number()) + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + switch n { + case minVal: + return protoreflect.ValueOfInt32(math.MinInt32) + case maxVal: + return protoreflect.ValueOfInt32(math.MaxInt32) + default: + return protoreflect.ValueOfInt32(int32(n)) + } + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + switch n { + case minVal: + // Only use 0 for the zero value. + return protoreflect.ValueOfUint32(1) + case maxVal: + return protoreflect.ValueOfUint32(math.MaxInt32) + default: + return protoreflect.ValueOfUint32(uint32(n)) + } + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + switch n { + case minVal: + return protoreflect.ValueOfInt64(math.MinInt64) + case maxVal: + return protoreflect.ValueOfInt64(math.MaxInt64) + default: + return protoreflect.ValueOfInt64(int64(n)) + } + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + switch n { + case minVal: + // Only use 0 for the zero value. 
+ return protoreflect.ValueOfUint64(1) + case maxVal: + return protoreflect.ValueOfUint64(math.MaxInt64) + default: + return protoreflect.ValueOfUint64(uint64(n)) + } + case protoreflect.FloatKind: + switch n { + case minVal: + return protoreflect.ValueOfFloat32(math.SmallestNonzeroFloat32) + case maxVal: + return protoreflect.ValueOfFloat32(math.MaxFloat32) + default: + return protoreflect.ValueOfFloat32(1.5 * float32(n)) + } + case protoreflect.DoubleKind: + switch n { + case minVal: + return protoreflect.ValueOfFloat64(math.SmallestNonzeroFloat64) + case maxVal: + return protoreflect.ValueOfFloat64(math.MaxFloat64) + default: + return protoreflect.ValueOfFloat64(1.5 * float64(n)) + } + case protoreflect.StringKind: + if n == 0 { + return protoreflect.ValueOfString("") + } + return protoreflect.ValueOfString(fmt.Sprintf("%d", n)) + case protoreflect.BytesKind: + if n == 0 { + return protoreflect.ValueOfBytes(nil) + } + return protoreflect.ValueOfBytes([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}) + } + panic("unhandled kind") +} + +func populateMessage(m protoreflect.Message, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { + if n == 0 { + return protoreflect.ValueOfMessage(m) + } + md := m.Descriptor() + for _, x := range stack { + if md == x { + return protoreflect.ValueOfMessage(m) + } + } + stack = append(stack, md) + for i := 0; i < md.Fields().Len(); i++ { + fd := md.Fields().Get(i) + if fd.IsWeak() { + continue + } + m.Set(fd, newValue(m, fd, newSeed(n, i), stack)) + } + return protoreflect.ValueOfMessage(m) +} + +func panics(f func()) (didPanic bool) { + defer func() { + if err := recover(); err != nil { + didPanic = true + } + }() + f() + return false +} From 7ef9f702156c6e1b7580775e7b686e61d8f82892 Mon Sep 17 00:00:00 2001 From: Matan Green Date: Wed, 19 Feb 2025 11:35:25 +0200 Subject: [PATCH 2/6] Fixed DEBUG-3205, DEBUG-3230, DEBUG-3454, DEBUG-3211 + Improved Exploration Testing --- .../diconfig/binary_inspection.go 
| 12 +-
 .../diconfig/config_manager.go | 13 +-
 .../diconfig/location_expression.go | 10 +
 .../diconfig/mem_config_manager.go | 37 +-
 pkg/dynamicinstrumentation/ditypes/config.go | 6 +
 .../proctracker/proctracker.go | 6 +
 .../testutil/exploration_e2e_test.go | 383 ++++++++++++++----
 7 files changed, 365 insertions(+), 102 deletions(-)

diff --git a/pkg/dynamicinstrumentation/diconfig/binary_inspection.go b/pkg/dynamicinstrumentation/diconfig/binary_inspection.go
index 9589bfb2ecf241..c62c9975d75fe3 100644
--- a/pkg/dynamicinstrumentation/diconfig/binary_inspection.go
+++ b/pkg/dynamicinstrumentation/diconfig/binary_inspection.go
@@ -21,15 +21,17 @@ import (
 // inspectGoBinaries goes through each service and populates information about the binary
 // and the relevant parameters, and their types
 // configEvent maps service names to info about the service and their configurations
-func inspectGoBinaries(configEvent ditypes.DIProcs) error {
-	var err error
+func inspectGoBinaries(configEvent ditypes.DIProcs) bool {
+	var inspectedAtLeastOneBinary bool
 	for i := range configEvent {
-		err = AnalyzeBinary(configEvent[i])
+		err := AnalyzeBinary(configEvent[i])
 		if err != nil {
-			return fmt.Errorf("inspection of PID %d (path=%s) failed: %w", configEvent[i].PID, configEvent[i].BinaryPath, err)
+			log.Infof("inspection of PID %d (path=%s) failed: %v", configEvent[i].PID, configEvent[i].BinaryPath, err)
+		} else {
+			inspectedAtLeastOneBinary = true
 		}
 	}
-	return nil
+	return inspectedAtLeastOneBinary
 }
 
 // AnalyzeBinary reads the binary associated with the specified process and parses
diff --git a/pkg/dynamicinstrumentation/diconfig/config_manager.go b/pkg/dynamicinstrumentation/diconfig/config_manager.go
index ef64e768dfd0ed..a078f45d852c1c 100644
--- a/pkg/dynamicinstrumentation/diconfig/config_manager.go
+++ b/pkg/dynamicinstrumentation/diconfig/config_manager.go
@@ -244,14 +244,17 @@ func (cm *RCConfigManager) readConfigs(r *ringbuf.Reader, procInfo *ditypes.Proc
 func 
applyConfigUpdate(procInfo *ditypes.ProcessInfo, probe *ditypes.Probe) { log.Tracef("Applying config update: %v\n", probe) - err := AnalyzeBinary(procInfo) - if err != nil { - log.Errorf("couldn't inspect binary: %v\n", err) - return + + if procInfo.TypeMap == nil { + err := AnalyzeBinary(procInfo) + if err != nil { + log.Errorf("couldn't inspect binary: %v\n", err) + return + } } generateCompileAttach: - err = codegen.GenerateBPFParamsCode(procInfo, probe) + err := codegen.GenerateBPFParamsCode(procInfo, probe) if err != nil { log.Info("Couldn't generate BPF programs", err) if !probe.InstrumentationInfo.AttemptedRebuild { diff --git a/pkg/dynamicinstrumentation/diconfig/location_expression.go b/pkg/dynamicinstrumentation/diconfig/location_expression.go index 862166378a495f..5c576f050f4b7e 100644 --- a/pkg/dynamicinstrumentation/diconfig/location_expression.go +++ b/pkg/dynamicinstrumentation/diconfig/location_expression.go @@ -166,6 +166,11 @@ func GenerateLocationExpression(limitsInfo *ditypes.InstrumentationInfo, param * } slicePointer := elementParam.ParameterPieces[0] sliceLength := elementParam.ParameterPieces[1] + + if slicePointer == nil || sliceLength == nil { + continue + } + sliceLength.LocationExpressions = append(sliceLength.LocationExpressions, ditypes.PrintStatement("%s", "Reading the length of slice"), ) @@ -188,6 +193,11 @@ func GenerateLocationExpression(limitsInfo *ditypes.InstrumentationInfo, param * // Generate and collect the location expressions for collecting an individual // element of this slice sliceElementType := slicePointer.ParameterPieces[0] + + if sliceElementType == nil { + continue + } + sliceIdentifier := randomLabel() labelName := randomLabel() diff --git a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go index 381080eab517f3..a604920c32d83b 100644 --- a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go +++ 
b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go
@@ -24,7 +24,7 @@ import (
 type ReaderConfigManager struct {
 	sync.Mutex
 	ConfigWriter *ConfigWriter
-	procTracker  *proctracker.ProcessTracker
+	ProcTracker  *proctracker.ProcessTracker
 
 	callback configUpdateCallback
 	configs  configsByService
@@ -40,8 +40,8 @@ func NewReaderConfigManager() (*ReaderConfigManager, error) {
 		state:        ditypes.NewDIProcs(),
 	}
 
-	cm.procTracker = proctracker.NewProcessTracker(cm.updateProcessInfo)
-	err := cm.procTracker.Start()
+	cm.ProcTracker = proctracker.NewProcessTracker(cm.updateProcessInfo)
+	err := cm.ProcTracker.Start()
 	if err != nil {
 		return nil, err
 	}
@@ -63,7 +63,7 @@ func (cm *ReaderConfigManager) GetProcInfos() ditypes.DIProcs {
 // Stop causes the ReaderConfigManager to stop processing data
 func (cm *ReaderConfigManager) Stop() {
 	cm.ConfigWriter.Stop()
-	cm.procTracker.Stop()
+	cm.ProcTracker.Stop()
 }
 
 func (cm *ReaderConfigManager) update() error {
@@ -80,9 +80,9 @@ func (cm *ReaderConfigManager) update() error {
 	}
 
 	if !reflect.DeepEqual(cm.state, updatedState) {
-		err := inspectGoBinaries(updatedState)
-		if err != nil {
-			return err
+		atLeastOneBinaryAnalyzed := inspectGoBinaries(updatedState)
+		if !atLeastOneBinaryAnalyzed {
+			return fmt.Errorf("failed to inspect any tracked go binaries")
 		}
 
 		for pid, procInfo := range cm.state {
@@ -159,6 +159,10 @@ func (r *ConfigWriter) Write(p []byte) (n int, e error) {
 	return 0, nil
 }
 
+func (r *ConfigWriter) WriteSync(p []byte) error {
+	return r.parseRawConfigBytesAndTriggerCallback(p)
+}
+
 // Start initiates the ConfigWriter to start processing data
 func (r *ConfigWriter) Start() error {
 	go func() {
@@ -166,13 +170,7 @@ func (r *ConfigWriter) Start() error {
 		for {
 			select {
 			case rawConfigBytes := <-r.updateChannel:
-				conf := map[string]map[string]rcConfig{}
-				err := json.Unmarshal(rawConfigBytes, &conf)
-				if err != nil {
-					log.Errorf("invalid config read from reader: %v", err)
-					continue
-				}
-				r.configCallback(conf)
+				
r.parseRawConfigBytesAndTriggerCallback(rawConfigBytes) case <-r.stopChannel: break configUpdateLoop } @@ -181,6 +179,17 @@ func (r *ConfigWriter) Start() error { return nil } +func (r *ConfigWriter) parseRawConfigBytesAndTriggerCallback(rawConfigBytes []byte) error { + conf := map[string]map[string]rcConfig{} + err := json.Unmarshal(rawConfigBytes, &conf) + if err != nil { + log.Errorf("invalid config read from reader: %v", err) + return fmt.Errorf("invalid config read from reader: %v", err) + } + r.configCallback(conf) + return nil +} + // Stop causes the ConfigWriter to stop processing data func (r *ConfigWriter) Stop() { r.stopChannel <- true diff --git a/pkg/dynamicinstrumentation/ditypes/config.go b/pkg/dynamicinstrumentation/ditypes/config.go index 72c98e6aa809e2..9799d137ee3318 100644 --- a/pkg/dynamicinstrumentation/ditypes/config.go +++ b/pkg/dynamicinstrumentation/ditypes/config.go @@ -14,6 +14,7 @@ import ( "io" "strconv" "strings" + "sync" "github.com/DataDog/datadog-agent/pkg/util/log" @@ -138,6 +139,7 @@ type ProcessInfo struct { ProbesByID ProbesByID InstrumentationUprobes map[ProbeID]*link.Link InstrumentationObjects map[ProbeID]*ebpf.Collection + mu sync.RWMutex } // SetupConfigUprobe sets the configuration probe for the process @@ -172,12 +174,16 @@ func (pi *ProcessInfo) CloseConfigUprobe() error { // SetUprobeLink associates the uprobe link with the specified probe // in the tracked process func (pi *ProcessInfo) SetUprobeLink(probeID ProbeID, l *link.Link) { + pi.mu.Lock() + defer pi.mu.Unlock() pi.InstrumentationUprobes[probeID] = l } // CloseUprobeLink closes the probe and deletes the link for the probe // in the tracked process func (pi *ProcessInfo) CloseUprobeLink(probeID ProbeID) error { + pi.mu.Lock() + defer pi.mu.Unlock() if l, ok := pi.InstrumentationUprobes[probeID]; ok { err := (*l).Close() delete(pi.InstrumentationUprobes, probeID) diff --git a/pkg/dynamicinstrumentation/proctracker/proctracker.go 
b/pkg/dynamicinstrumentation/proctracker/proctracker.go index 663dc5608497b7..790d8bc9e01ac4 100644 --- a/pkg/dynamicinstrumentation/proctracker/proctracker.go +++ b/pkg/dynamicinstrumentation/proctracker/proctracker.go @@ -84,6 +84,12 @@ func (pt *ProcessTracker) Stop() { } } +func (pt *ProcessTracker) Test_HandleProcessStart(pid uint32) { + exePath := filepath.Join(pt.procRoot, strconv.FormatUint(uint64(pid), 10), "exe") + + pt.inspectBinary(exePath, pid) +} + func (pt *ProcessTracker) handleProcessStart(pid uint32) { exePath := filepath.Join(pt.procRoot, strconv.FormatUint(uint64(pid), 10), "exe") diff --git a/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go b/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go index d1b0b669a8f0a1..d41ce5dff9d569 100644 --- a/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go +++ b/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go @@ -10,8 +10,10 @@ package testutil import ( "bufio" "bytes" + "crypto/sha256" "debug/dwarf" "debug/elf" + "encoding/hex" "encoding/json" "fmt" "html/template" @@ -210,6 +212,15 @@ func getProcessEnv(pid int) ([]string, error) { return env, nil } +func extractDDService(env []string) (string, error) { + for _, entry := range env { + if strings.HasPrefix(entry, "DD_SERVICE=") { + return strings.TrimPrefix(entry, "DD_SERVICE="), nil + } + } + return "", fmt.Errorf("DD_SERVICE not found") +} + func hasDWARFInfo(binaryPath string) (bool, error) { f, err := elf.Open(binaryPath) if err != nil { @@ -528,7 +539,7 @@ func shouldProfileFunction(name string) bool { // return false // } -var NUMBER_OF_PROBES int = 100 +var NUMBER_OF_PROBES int = 10 func filterFunctions(funcs []FunctionInfo) []FunctionInfo { var validFuncs []FunctionInfo @@ -669,6 +680,117 @@ func hasDWARF(binaryPath string) (bool, error) { var analyzedBinaries []BinaryInfo var waitForAttach bool = true +var bufferPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +var 
g_configsAccumulator *ConfigAccumulator + +type rcConfig struct { + ID string + Version int + ProbeType string `json:"type"` + Language string + Where struct { + TypeName string `json:"typeName"` + MethodName string `json:"methodName"` + SourceFile string + Lines []string + } + Tags []string + Template string + CaptureSnapshot bool + EvaluatedAt string + Capture struct { + MaxReferenceDepth int `json:"maxReferenceDepth"` + MaxFieldCount int `json:"maxFieldCount"` + } +} + +type ConfigAccumulator struct { + configs map[string]map[string]rcConfig + tmpl *template.Template + mu sync.RWMutex +} + +func NewConfigAccumulator() (*ConfigAccumulator, error) { + tmpl, err := template.New("config_template").Parse(explorationTestConfigTemplateText) + if err != nil { + return nil, fmt.Errorf("failed to parse template: %w", err) + } + + return &ConfigAccumulator{ + configs: make(map[string]map[string]rcConfig), + tmpl: tmpl, + }, nil +} + +// fingerprintGoBinary opens an ELF binary at binaryPath, +// iterates over its sections (in a sorted order by name), +// skips known non-deterministic sections (like .note.go.buildid), +// and computes a SHA256 hash over the remaining content. +func fingerprintGoBinary(binaryPath string) (string, error) { + // Open the ELF file. + f, err := elf.Open(binaryPath) + if err != nil { + return "", err + } + defer f.Close() + + // Make a copy of the sections and sort them by name. + sections := make([]*elf.Section, len(f.Sections)) + copy(sections, f.Sections) + sort.Slice(sections, func(i, j int) bool { + return sections[i].Name < sections[j].Name + }) + + // Create a hash to accumulate the fingerprint. + hash := sha256.New() + for _, sec := range sections { + // Skip sections with no bytes in the file. + if sec.Type == elf.SHT_NOBITS { + continue + } + + // Skip the Go build ID section. + if sec.Name == ".note.go.buildid" { + continue + } + + // Write the section name to the hash. 
+ if _, err := io.WriteString(hash, sec.Name); err != nil { + return "", err + } + + // Read the section data. + data, err := sec.Data() + if err != nil { + return "", err + } + if _, err := hash.Write(data); err != nil { + return "", err + } + } + + return hex.EncodeToString(hash.Sum(nil)), nil +} + +// HaveISeenItBefore uses a simple in-memory map to record fingerprints. +var seenBinaries = make(map[string]struct{}) + +func isAlreadyProcessed(binaryPath string) (bool, error) { + fingerprint, err := fingerprintGoBinary(binaryPath) + if err != nil { + return false, err + } + if _, exists := seenBinaries[fingerprint]; exists { + return true, nil + } + seenBinaries[fingerprint] = struct{}{} + return false, nil +} func InspectBinary(t *testing.T, binaryPath string, pid int) error { // // check that we can analyse the binary without targeting a specific function @@ -690,6 +812,18 @@ func InspectBinary(t *testing.T, binaryPath string, pid int) error { // return nil // } + //processed, err := isAlreadyProcessed(binaryPath) + // + //if err != nil { + // LogDebug(t, "Failed to determine if `binaryPath` is already processed args: %v, binaryPath: %s", err, binaryPath) + // // Don't fail the entire processing + //} + // + //if processed { + // LogDebug(t, "Already processed %s, skipping.", binaryPath) + // return nil + //} + allFuncs, err := listAllFunctions(binaryPath) if err != nil { analyzedBinaries = append(analyzedBinaries, BinaryInfo{ @@ -700,8 +834,8 @@ func InspectBinary(t *testing.T, binaryPath string, pid int) error { return nil } - // targets := filterFunctions(allFuncs) - targets := allFuncs + targets := filterFunctions(allFuncs) + //targets := allFuncs // Get process arguments args, err := getProcessArgs(pid) @@ -716,13 +850,20 @@ func InspectBinary(t *testing.T, binaryPath string, pid int) error { } // // Get process environment variables - // env, err := getProcessEnv(pid) - // if err != nil { - // return fmt.Errorf("Failed to get Env: %v", err) - // } + env, 
err := getProcessEnv(pid) + if err != nil { + return fmt.Errorf("Failed to get Env: %v", err) + } + + serviceName, err := extractDDService(env) + if err != nil { + return fmt.Errorf("Failed to get Env: %v, binaryPath: %s", err, binaryPath) + } LogDebug(t, "\n=======================================") + LogDebug(t, "πŸ” SERVICE NAME: %s", serviceName) LogDebug(t, "πŸ” ANALYZING BINARY: %s", binaryPath) + LogDebug(t, "πŸ” ENV: %v", env) LogDebug(t, "πŸ” ARGS: %v", args) LogDebug(t, "πŸ” CWD: %s", cwd) LogDebug(t, "πŸ” Elected %d target functions:", len(targets)) @@ -777,62 +918,98 @@ func InspectBinary(t *testing.T, binaryPath string, pid int) error { LogDebug(t, "βœ… Analysis complete for: %s", binaryPath) LogDebug(t, "=======================================\n") + // Notify the ConfigManager that a new process has arrived + g_ConfigManager.ProcTracker.Test_HandleProcessStart(uint32(pid)) + t.Logf("About to request instrumentations for binary: %s, pid: %d.", binaryPath, pid) - cfgTemplate, err := template.New("config_template").Parse(explorationTestConfigTemplateText) - require.NoError(t, err) + if err := g_configsAccumulator.AddTargets(targets, serviceName); err != nil { + t.Logf("Error adding target: %v, binaryPath: %s", err, binaryPath) + return fmt.Errorf("add targets failed: %v, binary: %s", err, binaryPath) + } - b := []byte{} - var buf *bytes.Buffer + if err = g_configsAccumulator.WriteConfigs(); err != nil { + t.Logf("Error writing configs: %v, binaryPath: %s", err, binaryPath) + return fmt.Errorf("error adding configs: %v, binary: %s", err, binaryPath) + } - // if waitForAttach { - // pid := os.Getpid() - // t.Logf("(1) Waiting to attach for PID: %d", pid) - // time.Sleep(30 * time.Second) - // waitForAttach = false - // } + //cfgTemplate, err := template.New("config_template").Parse(explorationTestConfigTemplateText) + //require.NoError(t, err) + // + //buf := bufferPool.Get().(*bytes.Buffer) + //buf.Reset() + //defer bufferPool.Put(buf) + // + //if 
err = cfgTemplate.Execute(buf, targets); err != nil { + // return fmt.Errorf("template execution failed: %w", err) + //} + // + //_, err = g_ConfigManager.ConfigWriter.Write(buf.Bytes()) + // + //if err != nil { + // return fmt.Errorf("config writing failed: %v, binary: %s", err, binaryPΖ’ath) + //} + + time.Sleep(2 * time.Second) + + t.Logf("Requested to instrument %d functions for binary: %s, pid: %d.", len(targets), binaryPath, pid) - requesterdFuncs := 0 for _, f := range targets { + t.Logf(" -> requested instrumentation for %v", f) + } - // if !strings.Contains(f.FullName, "blabla_blabla") { - // continue - // } + //b := []byte{} + //var buf *bytes.Buffer - if !strings.Contains(f.FullName, "FullName") { - continue - } + if waitForAttach && os.Getenv("DEBUG") == "true" { + pid := os.Getpid() + t.Logf("(1) Waiting to attach for PID: %d", pid) + time.Sleep(30 * time.Second) + waitForAttach = false + } - // if f.FullName != "regexp.(*bitState).shouldVisit" { - // continue - // } + /* + requesterdFuncs := 0 + for _, f := range targets { - // if f.FullName != "google.golang.org/protobuf/encoding/protodelim_test.(*notBufioReader).UnreadRune" { - // continue - // } + // if !strings.Contains(f.FullName, "blabla_blabla") { + // continue + // } - buf = bytes.NewBuffer(b) - err = cfgTemplate.Execute(buf, f) - if err != nil { - continue - } + // if !strings.Contains(f.FullName, "FullName") { + // continue + // } - // LogDebug(t, "Requesting instrumentation for %v", f) - t.Logf("Requesting instrumentation for %v", f) - _, err := g_ConfigManager.ConfigWriter.Write(buf.Bytes()) + // if f.FullName != "regexp.(*bitState).shouldVisit" { + // continue + // } - if err != nil { - continue - } + // if f.FullName != "google.golang.org/protobuf/encoding/protodelim_test.(*notBufioReader).UnreadRune" { + // continue + // } - requesterdFuncs++ - } + buf = bytes.NewBuffer(b) + err = cfgTemplate.Execute(buf, f) + if err != nil { + continue + } + + // LogDebug(t, "Requesting instrumentation 
for %v", f) + t.Logf("Requesting instrumentation for %v", f) + _, err := g_ConfigManager.ConfigWriter.Write(buf.Bytes()) + + if err != nil { + continue + } - if !waitForAttach { + requesterdFuncs++ + } + */ + /*if !waitForAttach { time.Sleep(100 * time.Second) - } + }*/ - if requesterdFuncs > 0 { + /*if requesterdFuncs > 0 { // if waitForAttach { // pid := os.Getpid() // t.Logf("(2) Waiting to attach for PID: %d", pid) @@ -844,11 +1021,57 @@ func InspectBinary(t *testing.T, binaryPath string, pid int) error { time.Sleep(2 * time.Second) t.Logf("Requested to instrument %d functions for binary: %s, pid: %d.", requesterdFuncs, binaryPath, pid) + }*/ + + return nil +} + +func (ca *ConfigAccumulator) AddTargets(targets []FunctionInfo, serviceName string) error { + ca.mu.Lock() + defer ca.mu.Unlock() + + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufferPool.Put(buf) + + buf.WriteString("{") + if err := ca.tmpl.Execute(buf, targets); err != nil { + return fmt.Errorf("failed to execute template: %w", err) + } + buf.WriteString("}") + + var newConfigs map[string]rcConfig + if err := json.NewDecoder(buf).Decode(&newConfigs); err != nil { + return fmt.Errorf("failed to decode generated configs: %w", err) + } + + if ca.configs[serviceName] == nil { + ca.configs[serviceName] = make(map[string]rcConfig) + } + + for probeID, config := range newConfigs { + ca.configs[serviceName][probeID] = config } return nil } +func (ca *ConfigAccumulator) WriteConfigs() error { + ca.mu.RLock() + defer ca.mu.RUnlock() + + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufferPool.Put(buf) + + // Marshal the full config structure (service name -> probe configs) + if err := json.NewEncoder(buf).Encode(ca.configs); err != nil { + return fmt.Errorf("failed to marshal configs: %w", err) + } + + return g_ConfigManager.ConfigWriter.WriteSync(buf.Bytes()) +} + func (pt *ProcessTracker) addProcess(pid int, parentPID int) *ProcessInfo { pt.mu.Lock() defer pt.mu.Unlock() @@ 
-1296,7 +1519,7 @@ func (pt *ProcessTracker) logProcessTree() { } } -var DEBUG bool = false +var DEBUG bool = true var TRACE bool = false func (pt *ProcessTracker) LogTrace(format string, args ...any) { @@ -1348,6 +1571,11 @@ func TestExplorationGoDI(t *testing.T) { } g_ConfigManager = cm + g_configsAccumulator, err = NewConfigAccumulator() + + if err != nil { + t.Fatal("Failed to create ConfigAccumulator") + } tempDir := initializeTempDir(t, "/tmp/protobuf-integration-1060272402") modulePath := filepath.Join(tempDir, "src", "google.golang.org", "protobuf") @@ -1584,42 +1812,41 @@ func copyFile(srcFile, dstFile string) error { } var explorationTestConfigTemplateText = ` -{ - "go-di-exploration-test-service": { - "{{.ProbeId}}": { - "id": "{{.ProbeId}}", - "version": 0, - "type": "LOG_PROBE", - "language": "go", - "where": { - "typeName": "{{.PackageName}}", - "methodName": "{{.FunctionName}}" + {{- range $index, $target := .}} + {{- if $index}},{{end}} + "{{$target.ProbeId}}": { + "id": "{{$target.ProbeId}}", + "version": 0, + "type": "LOG_PROBE", + "language": "go", + "where": { + "typeName": "{{$target.PackageName}}", + "methodName": "{{$target.FunctionName}}" + }, + "tags": [], + "template": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took {@duration}ms", + "segments": [ + { + "str": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took " }, - "tags": [], - "template": "Executed {{.PackageName}}.{{.FunctionName}}, it took {@duration}ms", - "segments": [ - { - "str": "Executed {{.PackageName}}.{{.FunctionName}}, it took " - }, - { + { "dsl": "@duration", "json": { "ref": "@duration" } - }, - { - "str": "ms" - } - ], - "captureSnapshot": false, - "capture": { - "maxReferenceDepth": 10 }, - "sampling": { - "snapshotsPerSecond": 5000 - }, - "evaluateAt": "EXIT" - } + { + "str": "ms" + } + ], + "captureSnapshot": false, + "capture": { + "maxReferenceDepth": 10 + }, + "sampling": { + "snapshotsPerSecond": 5000 + }, + "evaluateAt": 
"EXIT" } -} + {{- end}} ` From 56f74d1f721b90994315c87dcf4d209f1058c5c5 Mon Sep 17 00:00:00 2001 From: Matan Green Date: Wed, 19 Feb 2025 18:25:42 +0200 Subject: [PATCH 3/6] Exploration Testing --- .../diconfig/binary_inspection.go | 12 +- .../diconfig/mem_config_manager.go | 6 +- .../proctracker/proctracker.go | 2 +- .../testutil/exploration_e2e_test.go | 1404 +++++------------ .../patches/protobuf/integration_test.go | 13 +- 5 files changed, 403 insertions(+), 1034 deletions(-) diff --git a/pkg/dynamicinstrumentation/diconfig/binary_inspection.go b/pkg/dynamicinstrumentation/diconfig/binary_inspection.go index c62c9975d75fe3..3608124eda7761 100644 --- a/pkg/dynamicinstrumentation/diconfig/binary_inspection.go +++ b/pkg/dynamicinstrumentation/diconfig/binary_inspection.go @@ -21,17 +21,23 @@ import ( // inspectGoBinaries goes through each service and populates information about the binary // and the relevant parameters, and their types // configEvent maps service names to info about the service and their configurations -func inspectGoBinaries(configEvent ditypes.DIProcs) bool { +func inspectGoBinaries(configEvent ditypes.DIProcs) error { + var err error var inspectedAtLeastOneBinary bool for i := range configEvent { - err := AnalyzeBinary(configEvent[i]) + err = AnalyzeBinary(configEvent[i]) if err != nil { log.Info("inspection of PID %d (path=%s) failed: %w", configEvent[i].PID, configEvent[i].BinaryPath, err) } else { inspectedAtLeastOneBinary = true } } - return inspectedAtLeastOneBinary + + if !inspectedAtLeastOneBinary { + return fmt.Errorf("failed to inspect all tracked go binaries") + } + + return nil } // AnalyzeBinary reads the binary associated with the specified process and parses diff --git a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go index a604920c32d83b..07ad2d0df262ee 100644 --- a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go +++ 
b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go @@ -80,9 +80,9 @@ func (cm *ReaderConfigManager) update() error { } if !reflect.DeepEqual(cm.state, updatedState) { - atLeastOneBinaryAnalyzed := inspectGoBinaries(updatedState) - if !atLeastOneBinaryAnalyzed { - return fmt.Errorf("failed to inspect all tracked go binaries.") + err := inspectGoBinaries(updatedState) + if err != nil { + return err } for pid, procInfo := range cm.state { diff --git a/pkg/dynamicinstrumentation/proctracker/proctracker.go b/pkg/dynamicinstrumentation/proctracker/proctracker.go index 790d8bc9e01ac4..85ef944cd74a93 100644 --- a/pkg/dynamicinstrumentation/proctracker/proctracker.go +++ b/pkg/dynamicinstrumentation/proctracker/proctracker.go @@ -84,7 +84,7 @@ func (pt *ProcessTracker) Stop() { } } -func (pt *ProcessTracker) Test_HandleProcessStart(pid uint32) { +func (pt *ProcessTracker) HandleProcessStartSync(pid uint32) { exePath := filepath.Join(pt.procRoot, strconv.FormatUint(uint64(pid), 10), "exe") pt.inspectBinary(exePath, pid) diff --git a/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go b/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go index d41ce5dff9d569..9f93fa23402af6 100644 --- a/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go +++ b/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go @@ -93,88 +93,135 @@ type ProbeManager struct { mu sync.Mutex } -func NewProbeManager(t *testing.T) *ProbeManager { - return &ProbeManager{ - t: t, - } +type BinaryInfo struct { + path string + hasDebug bool } -func (pm *ProbeManager) Install(pid int, function string) error { - pm.mu.Lock() - defer pm.mu.Unlock() - - // Get or create the map of installed probes for this PID - v, _ := pm.installedProbes.LoadOrStore(pid, make(map[string]struct{})) - probes := v.(map[string]struct{}) - - // Install the probe - probes[function] = struct{}{} - pm.t.Logf("πŸ”§ Installing probe: PID=%d Function=%s", pid, function) - - // Your actual probe installation 
logic here using GoDI - // Example: - // err := pm.godi.InstallProbe(pid, function) - return nil +type FunctionInfo struct { + PackageName string + FunctionName string + FullName string + ProbeId string } -func (pm *ProbeManager) Remove(pid int, function string) error { - pm.mu.Lock() - defer pm.mu.Unlock() - - if v, ok := pm.installedProbes.Load(pid); ok { - probes := v.(map[string]struct{}) - delete(probes, function) - pm.t.Logf("πŸ”§ Removing probe: PID=%d Function=%s", pid, function) +func NewFunctionInfo(packageName, functionName, fullName string) FunctionInfo { + return FunctionInfo{ + PackageName: packageName, + FunctionName: functionName, + FullName: fullName, + ProbeId: uuid.NewString(), + } +} - // Your actual probe removal logic here +type rcConfig struct { + ID string + Version int + ProbeType string `json:"type"` + Language string + Where struct { + TypeName string `json:"typeName"` + MethodName string `json:"methodName"` + SourceFile string + Lines []string + } + Tags []string + Template string + CaptureSnapshot bool + EvaluatedAt string + Capture struct { + MaxReferenceDepth int `json:"maxReferenceDepth"` + MaxFieldCount int `json:"maxFieldCount"` } - return nil } -func (pm *ProbeManager) CollectData(pid int, function string) (bool, error) { - // Check if we've received data for this probe - // This is where you'd check your actual data collection mechanism +type ConfigAccumulator struct { + configs map[string]map[string]rcConfig + tmpl *template.Template + mu sync.RWMutex +} - // For testing, let's simulate data collection - // In reality, you'd check if your probe has published any data - if v, ok := pm.dataReceived.Load(pid); ok { - dataMap := v.(map[string]bool) - return dataMap[function], nil - } - return false, nil +type RepoInfo struct { + Packages map[string]bool // Package names found in repo + RepoPath string // Path to the repo + CommitHash string // Current commit hash (optional) } -func NewProcessTracker(t *testing.T) *ProcessTracker { - 
return &ProcessTracker{ - t: t, - processes: make(map[int]*ProcessInfo), - stopChan: make(chan struct{}), - analyzedBinaries: make(map[string]bool), - analyzedPIDs: make(map[int]bool), - done: make(chan struct{}), - } +type explorationEventOutputTestWriter struct { + t *testing.T + expectedResult map[string]*ditypes.CapturedValue } -func (pt *ProcessTracker) markAnalyzed(pid int, path string) { - pt.mu.Lock() - defer pt.mu.Unlock() - pt.analyzedPIDs[pid] = true - pt.analyzedBinaries[path] = true +func (e *explorationEventOutputTestWriter) Write(p []byte) (n int, err error) { + var snapshot ditypes.SnapshotUpload + if err := json.Unmarshal(p, &snapshot); err != nil { + e.t.Error("failed to unmarshal snapshot", err) + } + funcName := snapshot.Debugger.ProbeInSnapshot.Type + "." + snapshot.Debugger.ProbeInSnapshot.Method + e.t.Logf("Received snapshot for function: %s", funcName) + return len(p), nil } -func getProcessArgs(pid int) ([]string, error) { - // Construct the path to the /proc//cmdline file - procFile := fmt.Sprintf("/proc/%d/cmdline", pid) +var ( + analyzedBinaries []BinaryInfo + waitForAttach bool = true + bufferPool = sync.Pool{New: func() interface{} { return new(bytes.Buffer) }} + seenBinaries = make(map[string]struct{}) + g_configsAccumulator *ConfigAccumulator + g_RepoInfo *RepoInfo + g_ConfigManager *diconfig.ReaderConfigManager + g_cmd *exec.Cmd + DEBUG bool = true + TRACE bool = false + NUMBER_OF_PROBES int = 10 + + explorationTestConfigTemplateText = ` + {{- range $index, $target := .}} + {{- if $index}},{{end}} + "{{$target.ProbeId}}": { + "id": "{{$target.ProbeId}}", + "version": 0, + "type": "LOG_PROBE", + "language": "go", + "where": { + "typeName": "{{$target.PackageName}}", + "methodName": "{{$target.FunctionName}}" + }, + "tags": [], + "template": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took {@duration}ms", + "segments": [ + { + "str": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took " + }, + { + 
"dsl": "@duration", + "json": { + "ref": "@duration" + } + }, + { + "str": "ms" + } + ], + "captureSnapshot": false, + "capture": { + "maxReferenceDepth": 10 + }, + "sampling": { + "snapshotsPerSecond": 5000 + }, + "evaluateAt": "EXIT" + } + {{- end}} +` +) - // Read the file content - data, err := os.ReadFile(procFile) +func getProcessArgs(pid int) ([]string, error) { + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) if err != nil { return nil, err } - - // The arguments are null-byte separated, split them args := strings.Split(string(data), "\x00") - // Remove any trailing empty string caused by the trailing null byte if len(args) > 0 && args[len(args)-1] == "" { args = args[:len(args)-1] } @@ -182,11 +229,7 @@ func getProcessArgs(pid int) ([]string, error) { } func getProcessCwd(pid int) (string, error) { - // Construct the path to the /proc//cwd symlink - procFile := fmt.Sprintf("/proc/%d/cwd", pid) - - // Read the symlink to find the current working directory - cwd, err := os.Readlink(procFile) + cwd, err := os.Readlink(fmt.Sprintf("/proc/%d/cwd", pid)) if err != nil { return "", err } @@ -194,18 +237,11 @@ func getProcessCwd(pid int) (string, error) { } func getProcessEnv(pid int) ([]string, error) { - // Construct the path to the /proc//environ file - procFile := fmt.Sprintf("/proc/%d/environ", pid) - - // Open and read the file - data, err := os.ReadFile(procFile) + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/environ", pid)) if err != nil { return nil, err } - - // The environment variables are null-byte separated, split them env := strings.Split(string(data), "\x00") - // Remove any trailing empty string caused by the trailing null byte if len(env) > 0 && env[len(env)-1] == "" { env = env[:len(env)-1] } @@ -228,7 +264,6 @@ func hasDWARFInfo(binaryPath string) (bool, error) { } defer f.Close() - // Try both approaches: section lookup and DWARF data reading debugSections := false for _, section := range f.Sections { if 
strings.HasPrefix(section.Name, ".debug_") { @@ -237,13 +272,11 @@ func hasDWARFInfo(binaryPath string) (bool, error) { } } - // Try to actually read DWARF data dwarfData, err := f.DWARF() if err != nil { return debugSections, fmt.Errorf("DWARF read error: %w", err) } - // Verify we can read some DWARF data reader := dwarfData.Reader() entry, err := reader.Next() if err != nil { @@ -253,59 +286,9 @@ func hasDWARFInfo(binaryPath string) (bool, error) { fmt.Printf("Found DWARF entry of type: %v\n", entry.Tag) return true, nil } - return false, nil } -type BinaryInfo struct { - path string - hasDebug bool -} - -type FunctionInfo struct { - PackageName string - FunctionName string - FullName string - ProbeId string -} - -func NewFunctionInfo(packageName, functionName, fullName string) FunctionInfo { - return FunctionInfo{ - PackageName: packageName, - FunctionName: functionName, - FullName: fullName, - ProbeId: uuid.NewString(), - } -} - -func extractPackageAndFunction(fullName string) FunctionInfo { - // Handle empty input - if fullName == "" { - return FunctionInfo{} - } - - // First, find the last index of "." 
before any parentheses - parenIndex := strings.Index(fullName, "(") - lastDot := -1 - if parenIndex != -1 { - // If we have parentheses, look for the last dot before them - lastDot = strings.LastIndex(fullName[:parenIndex], ".") - } else { - // If no parentheses, just find the last dot - lastDot = strings.LastIndex(fullName, ".") - } - - if lastDot == -1 { - return FunctionInfo{} - } - - // Split into package and function parts - pkgPath := fullName[:lastDot] - funcPart := fullName[lastDot+1:] - - return NewFunctionInfo(pkgPath, funcPart, fullName) -} - func listAllFunctions(filePath string) ([]FunctionInfo, error) { var functions []FunctionInfo var errors []string @@ -322,7 +305,6 @@ func listAllFunctions(filePath string) ([]FunctionInfo, error) { } reader := dwarfData.Reader() - for { entry, err := reader.Next() if err != nil { @@ -331,19 +313,16 @@ func listAllFunctions(filePath string) ([]FunctionInfo, error) { if entry == nil { break } - if entry.Tag == dwarf.TagSubprogram { funcName, ok := entry.Val(dwarf.AttrName).(string) if !ok || funcName == "" { continue } - info := extractPackageAndFunction(funcName) if info.FunctionName == "" { errors = append(errors, fmt.Sprintf("could not extract function name from %q", funcName)) continue } - functions = append(functions, info) } } @@ -354,272 +333,86 @@ func listAllFunctions(filePath string) ([]FunctionInfo, error) { } return nil, fmt.Errorf("no functions found in the binary") } - return functions, nil } -// func isStandardPackage(pkg string) bool { -// // List of common standard library packages that might be nested -// stdPkgs := map[string]bool{ -// "encoding/json": true, -// "compress/flate": true, -// "compress/gzip": true, -// "encoding/base64": true, -// // Add more as needed -// } -// return stdPkgs[pkg] -// } - -// func listAllFunctions(filePath string) ([]FunctionInfo, error) { -// var functions []FunctionInfo - -// // Open the ELF file -// ef, err := elf.Open(filePath) -// if err != nil { -// return nil, 
fmt.Errorf("failed to open file: %v", err) -// } -// defer ef.Close() - -// // Retrieve symbols from the ELF file -// symbols, err := ef.Symbols() -// if err != nil { -// return nil, fmt.Errorf("failed to read symbols: %v", err) -// } - -// // Iterate over symbols and filter function symbols -// for _, sym := range symbols { -// if elf.ST_TYPE(sym.Info) == elf.STT_FUNC { -// // Extract function name -// functionName := sym.Name - -// // Extract package name from section index (if applicable) -// // DWARF data or additional analysis can refine this -// packageName := "" - -// // Add to result -// functions = append(functions, FunctionInfo{ -// PackageName: packageName, -// FunctionName: functionName, -// }) -// } -// } -// return functions, nil -// } +func extractPackageAndFunction(fullName string) FunctionInfo { + if fullName == "" { + return FunctionInfo{} + } + parenIndex := strings.Index(fullName, "(") + lastDot := -1 + if parenIndex != -1 { + lastDot = strings.LastIndex(fullName[:parenIndex], ".") + } else { + lastDot = strings.LastIndex(fullName, ".") + } + if lastDot == -1 { + return FunctionInfo{} + } + pkgPath := fullName[:lastDot] + funcPart := fullName[lastDot+1:] + return NewFunctionInfo(pkgPath, funcPart, fullName) +} func shouldProfileFunction(name string) bool { - // First, immediately reject known system/internal functions - if strings.HasPrefix(name, "*ZN") || // Sanitizer/LLVM functions - strings.HasPrefix(name, "_") || // Internal functions + if strings.HasPrefix(name, "*ZN") || + strings.HasPrefix(name, "_") || strings.Contains(name, "_sanitizer") || strings.Contains(name, "runtime.") { return false } - - // Extract package from function name parts := strings.Split(name, ".") if len(parts) < 2 { return false } - pkgPath := parts[0] if len(parts) > 2 { pkgPath = strings.Join(parts[:len(parts)-1], "/") } - - // Check if it's in our repository packages for repoPkg := range g_RepoInfo.Packages { if strings.Contains(pkgPath, repoPkg) { return true } } 
- return false } -// func shouldProfileFunction(name string) bool { -// // Skip standard library packages -// stdlibPrefixes := []string{ -// "bufio.", -// "bytes.", -// "context.", -// "crypto.", -// "compress/", -// "database/", -// "debug/", -// "encoding/", -// "errors.", -// "flag.", -// "fmt.", -// "io.", -// "log.", -// "math.", -// "net.", -// "os.", -// "path.", -// "reflect.", -// "regexp.", -// "runtime.", -// "sort.", -// "strconv.", -// "strings.", -// "sync.", -// "syscall.", -// "time.", -// "unicode.", -// } - -// // Definitely skip these system internals -// skipPrefixes := []string{ -// "runtime.", -// "runtime/race", -// "*ZN", // LLVM/Clang internals -// "type..", // Go type metadata -// "gc.", // Garbage collector -// "gosb.", // Go sandbox -// "_rt.", // Runtime helpers -// "reflect.", // Reflection internals -// } - -// skipContains := []string{ -// "_sanitizer", -// "_tsan", -// ".constprop.", // Compiler generated constants -// ".isra.", // LLVM optimized functions -// ".part.", // Partial functions from compiler -// "__gcc_", // GCC internals -// "_cgo_", // CGO generated code -// "goexit", // Go runtime exit handlers -// "gcproc", // GC procedures -// ".loc.", // Location metadata -// "runtimeΒ·", // Runtime internals (different dot) -// } - -// // Quick reject for standard library and system functions -// for _, prefix := range append(stdlibPrefixes, skipPrefixes...) 
{ -// if strings.HasPrefix(name, prefix) { -// return false -// } -// } - -// for _, substr := range skipContains { -// if strings.Contains(name, substr) { -// return false -// } -// } - -// // High priority user functions - definitely profile these -// priorityPrefixes := []string{ -// "main.", -// "cmd.", -// "github.com/", -// "golang.org/x/", -// "google.golang.org/", -// "k8s.io/", -// } - -// for _, prefix := range priorityPrefixes { -// if strings.HasPrefix(name, prefix) { -// return true -// } -// } - -// // Function looks like a normal Go function (CapitalizedName) -// if len(name) > 0 && unicode.IsUpper(rune(name[0])) { -// return true -// } - -// // If it contains a dot and doesn't look like a compiler-generated name -// if strings.Contains(name, ".") && -// !strings.Contains(name, "$") && -// !strings.Contains(name, "__") { -// return true -// } - -// // If we get here, it's probably a system function -// return false -// } - -var NUMBER_OF_PROBES int = 10 - func filterFunctions(funcs []FunctionInfo) []FunctionInfo { var validFuncs []FunctionInfo - - // First pass: collect only functions from our packages for _, f := range funcs { - // Combine package and function name for filtering fullName := fmt.Sprintf("%s.%s", f.PackageName, f.FunctionName) if shouldProfileFunction(fullName) { validFuncs = append(validFuncs, f) } } - - // If we have no valid functions, return empty list if len(validFuncs) == 0 { return nil } - - // Sort valid functions for consistent ordering sort.Slice(validFuncs, func(i, j int) bool { - // Sort alphabetically by full name (package + function) fullNameI := fmt.Sprintf("%s.%s", validFuncs[i].PackageName, validFuncs[i].FunctionName) fullNameJ := fmt.Sprintf("%s.%s", validFuncs[j].PackageName, validFuncs[j].FunctionName) return fullNameI < fullNameJ }) - - // Return all if we have 10 or fewer if len(validFuncs) <= NUMBER_OF_PROBES { return validFuncs } - - // Only take first 10 if we have more return validFuncs[:NUMBER_OF_PROBES] } 
-// func filterFunctions(funcs []string) []string { -// var validFuncs []string - -// // First pass: collect only functions from our packages -// for _, f := range funcs { -// if shouldProfileFunction(f) { -// validFuncs = append(validFuncs, f) -// } -// } - -// // If we have no valid functions, return empty list -// if len(validFuncs) == 0 { -// return nil -// } - -// // Sort for consistent ordering -// sort.Strings(validFuncs) - -// // Return all if we have 10 or fewer -// if len(validFuncs) <= NUMBER_OF_PROBES { -// return validFuncs -// } - -// // Only take first 10 if we have more -// return validFuncs[:NUMBER_OF_PROBES] -// } - func ExtractFunctions(binaryPath string) ([]FunctionInfo, error) { - // Open the binary file, err := elf.Open(binaryPath) if err != nil { return nil, fmt.Errorf("failed to open binary: %v", err) } defer file.Close() - // Get DWARF data dwarfData, err := file.DWARF() if err != nil { return nil, fmt.Errorf("failed to load DWARF data: %v", err) } - // Prepare result var functions []FunctionInfo - - // Iterate over DWARF entries reader := dwarfData.Reader() for { entry, err := reader.Next() @@ -627,21 +420,14 @@ func ExtractFunctions(binaryPath string) ([]FunctionInfo, error) { return nil, fmt.Errorf("error reading DWARF: %v", err) } if entry == nil { - break // End of entries + break } - - // Check for subprogram (function) entries if entry.Tag == dwarf.TagSubprogram { - // Extract function name funcName, _ := entry.Val(dwarf.AttrName).(string) - - // Extract package/module name (if available) var packageName string if compDir, ok := entry.Val(dwarf.AttrCompDir).(string); ok { packageName = compDir } - - // Add to the result if funcName != "" { functions = append(functions, FunctionInfo{ PackageName: packageName, @@ -650,121 +436,47 @@ func ExtractFunctions(binaryPath string) ([]FunctionInfo, error) { } } } - return functions, nil } -// hasDWARF checks if the given binary contains DWARF debug information. 
func hasDWARF(binaryPath string) (bool, error) { - // Open the binary file file, err := elf.Open(binaryPath) if err != nil { return false, fmt.Errorf("failed to open binary: %v", err) } defer file.Close() - // Check if DWARF data exists _, err = file.DWARF() if err != nil { - // Check if the error indicates missing DWARF information if err.Error() == "no DWARF data" { return false, nil } - // Otherwise, propagate the error return false, fmt.Errorf("failed to check DWARF data: %v", err) } - - // DWARF data exists return true, nil } -var analyzedBinaries []BinaryInfo -var waitForAttach bool = true -var bufferPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, -} +func fingerprintGoBinary(binaryPath string) (string, error) { + f, err := elf.Open(binaryPath) + if err != nil { + return "", err + } + defer f.Close() -var g_configsAccumulator *ConfigAccumulator + sections := make([]*elf.Section, len(f.Sections)) + copy(sections, f.Sections) + sort.Slice(sections, func(i, j int) bool { + return sections[i].Name < sections[j].Name + }) -type rcConfig struct { - ID string - Version int - ProbeType string `json:"type"` - Language string - Where struct { - TypeName string `json:"typeName"` - MethodName string `json:"methodName"` - SourceFile string - Lines []string - } - Tags []string - Template string - CaptureSnapshot bool - EvaluatedAt string - Capture struct { - MaxReferenceDepth int `json:"maxReferenceDepth"` - MaxFieldCount int `json:"maxFieldCount"` - } -} - -type ConfigAccumulator struct { - configs map[string]map[string]rcConfig - tmpl *template.Template - mu sync.RWMutex -} - -func NewConfigAccumulator() (*ConfigAccumulator, error) { - tmpl, err := template.New("config_template").Parse(explorationTestConfigTemplateText) - if err != nil { - return nil, fmt.Errorf("failed to parse template: %w", err) - } - - return &ConfigAccumulator{ - configs: make(map[string]map[string]rcConfig), - tmpl: tmpl, - }, nil -} - -// fingerprintGoBinary opens an 
ELF binary at binaryPath, -// iterates over its sections (in a sorted order by name), -// skips known non-deterministic sections (like .note.go.buildid), -// and computes a SHA256 hash over the remaining content. -func fingerprintGoBinary(binaryPath string) (string, error) { - // Open the ELF file. - f, err := elf.Open(binaryPath) - if err != nil { - return "", err - } - defer f.Close() - - // Make a copy of the sections and sort them by name. - sections := make([]*elf.Section, len(f.Sections)) - copy(sections, f.Sections) - sort.Slice(sections, func(i, j int) bool { - return sections[i].Name < sections[j].Name - }) - - // Create a hash to accumulate the fingerprint. hash := sha256.New() for _, sec := range sections { - // Skip sections with no bytes in the file. - if sec.Type == elf.SHT_NOBITS { + if sec.Type == elf.SHT_NOBITS || sec.Name == ".note.go.buildid" { continue } - - // Skip the Go build ID section. - if sec.Name == ".note.go.buildid" { - continue - } - - // Write the section name to the hash. if _, err := io.WriteString(hash, sec.Name); err != nil { return "", err } - - // Read the section data. data, err := sec.Data() if err != nil { return "", err @@ -773,13 +485,9 @@ func fingerprintGoBinary(binaryPath string) (string, error) { return "", err } } - return hex.EncodeToString(hash.Sum(nil)), nil } -// HaveISeenItBefore uses a simple in-memory map to record fingerprints. 
-var seenBinaries = make(map[string]struct{}) - func isAlreadyProcessed(binaryPath string) (bool, error) { fingerprint, err := fingerprintGoBinary(binaryPath) if err != nil { @@ -792,284 +500,102 @@ func isAlreadyProcessed(binaryPath string) (bool, error) { return false, nil } -func InspectBinary(t *testing.T, binaryPath string, pid int) error { - // // check that we can analyse the binary without targeting a specific function - // err := diconfig.AnalyzeBinary(&ditypes.ProcessInfo{BinaryPath: binaryPath}) - // if err != nil { - // // log.Fatalln("Failed to analyze", binaryPath, "--", err) - // return nil - // } - - // targets, err := ExtractFunctions(binaryPath) - // if err != nil { - // // log.Fatalf("Error extracting functions: %v", err) - // return nil - // } - - // hasDwarf, err := hasDWARF(binaryPath) - // if err != nil || !hasDwarf { - // // log.Fatalf("Error checking for DWARF info: %v", err) - // return nil - // } - - //processed, err := isAlreadyProcessed(binaryPath) - // - //if err != nil { - // LogDebug(t, "Failed to determine if `binaryPath` is already processed args: %v, binaryPath: %s", err, binaryPath) - // // Don't fail the entire processing - //} - // - //if processed { - // LogDebug(t, "Already processed %s, skipping.", binaryPath) - // return nil - //} - - allFuncs, err := listAllFunctions(binaryPath) - if err != nil { - analyzedBinaries = append(analyzedBinaries, BinaryInfo{ - path: binaryPath, - hasDebug: false, - }) - - return nil - } - - targets := filterFunctions(allFuncs) - //targets := allFuncs - - // Get process arguments - args, err := getProcessArgs(pid) - if err != nil { - return fmt.Errorf("Failed to process args: %v", err) - } - - // Get process current working directory - cwd, err := getProcessCwd(pid) +func getBinaryPath(pid int) string { + path, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pid)) if err != nil { - return fmt.Errorf("Failed to get Cwd: %v", err) + return "" } - - // // Get process environment variables - env, err 
:= getProcessEnv(pid) - if err != nil { - return fmt.Errorf("Failed to get Env: %v", err) + realPath, err := filepath.EvalSymlinks(path) + if err == nil { + path = realPath } + return path +} - serviceName, err := extractDDService(env) +func getParentPID(pid int) int { + ppidStr, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) if err != nil { - return fmt.Errorf("Failed to get Env: %v, binaryPath: %s", err, binaryPath) - } - - LogDebug(t, "\n=======================================") - LogDebug(t, "πŸ” SERVICE NAME: %s", serviceName) - LogDebug(t, "πŸ” ANALYZING BINARY: %s", binaryPath) - LogDebug(t, "πŸ” ENV: %v", env) - LogDebug(t, "πŸ” ARGS: %v", args) - LogDebug(t, "πŸ” CWD: %s", cwd) - LogDebug(t, "πŸ” Elected %d target functions:", len(targets)) - for _, f := range targets { - LogDebug(t, " β†’ Package: %s, Function: %s, FullName: %s", f.PackageName, f.FunctionName, f.FullName) - } - - // hasDWARF, dwarfErr := hasDWARFInfo(binaryPath) - // if dwarfErr != nil { - // log.Printf("Error checking DWARF info: %v", dwarfErr) - // } else { - // log.Printf("Binary has DWARF info: %v", hasDWARF) - // } - // LogDebug(t, "πŸ” ENV: %v", env) - LogDebug(t, "=======================================") - - // Check if the binary exists - if _, err := os.Stat(binaryPath); err != nil { - return fmt.Errorf("(1) binary inspection failed: %v", err) - } - - analyzedBinaries = append(analyzedBinaries, BinaryInfo{ - path: binaryPath, - hasDebug: len(targets) > 0, - }) - - // i := 0 - // // Re-check binary existence - // for { - // if _, err := os.Stat(binaryPath); err != nil { - // time.Sleep(10 * time.Hour) - // return fmt.Errorf("(2) binary inspection failed: %v", err) - // } - - // // if strings.HasSuffix(binaryPath, "generate-protos") { - // // break - // // } - - // if strings.HasSuffix(binaryPath, "conformance.test") { - // time.Sleep(10 * time.Second) - // break - // } - - // i++ - // if i > 11 { - // break - // } - - // // time.Sleep(100 * time.Millisecond) - // } 
- - LogDebug(t, "βœ… Analysis complete for: %s", binaryPath) - LogDebug(t, "=======================================\n") - - // Notify the ConfigManager that a new process has arrived - g_ConfigManager.ProcTracker.Test_HandleProcessStart(uint32(pid)) - - t.Logf("About to request instrumentations for binary: %s, pid: %d.", binaryPath, pid) - - if err := g_configsAccumulator.AddTargets(targets, serviceName); err != nil { - t.Logf("Error adding target: %v, binaryPath: %s", err, binaryPath) - return fmt.Errorf("add targets failed: %v, binary: %s", err, binaryPath) + return 0 } - - if err = g_configsAccumulator.WriteConfigs(); err != nil { - t.Logf("Error writing configs: %v, binaryPath: %s", err, binaryPath) - return fmt.Errorf("error adding configs: %v, binary: %s", err, binaryPath) + fields := strings.Fields(string(ppidStr)) + if len(fields) < 4 { + return 0 } + ppid, _ := strconv.Atoi(fields[3]) + return ppid +} - //cfgTemplate, err := template.New("config_template").Parse(explorationTestConfigTemplateText) - //require.NoError(t, err) - // - //buf := bufferPool.Get().(*bytes.Buffer) - //buf.Reset() - //defer bufferPool.Put(buf) - // - //if err = cfgTemplate.Execute(buf, targets); err != nil { - // return fmt.Errorf("template execution failed: %w", err) - //} - // - //_, err = g_ConfigManager.ConfigWriter.Write(buf.Bytes()) - // - //if err != nil { - // return fmt.Errorf("config writing failed: %v, binary: %s", err, binaryPΖ’ath) - //} - - time.Sleep(2 * time.Second) - - t.Logf("Requested to instrument %d functions for binary: %s, pid: %d.", len(targets), binaryPath, pid) - - for _, f := range targets { - t.Logf(" -> requested instrumentation for %v", f) +func findAncestors(pid int, tree map[int]bool) { + for pid > 1 { + if tree[pid] { + return + } + tree[pid] = true + ppid := getParentPID(pid) + if ppid <= 1 { + return + } + pid = ppid } +} - //b := []byte{} - //var buf *bytes.Buffer - - if waitForAttach && os.Getenv("DEBUG") == "true" { - pid := os.Getpid() - 
t.Logf("(1) Waiting to attach for PID: %d", pid) - time.Sleep(30 * time.Second) - waitForAttach = false +func LogDebug(t *testing.T, format string, args ...any) { + if DEBUG { + t.Logf(format, args...) } +} - /* - requesterdFuncs := 0 - for _, f := range targets { - - // if !strings.Contains(f.FullName, "blabla_blabla") { - // continue - // } - - // if !strings.Contains(f.FullName, "FullName") { - // continue - // } - - // if f.FullName != "regexp.(*bitState).shouldVisit" { - // continue - // } - - // if f.FullName != "google.golang.org/protobuf/encoding/protodelim_test.(*notBufioReader).UnreadRune" { - // continue - // } - - buf = bytes.NewBuffer(b) - err = cfgTemplate.Execute(buf, f) - if err != nil { - continue - } - - // LogDebug(t, "Requesting instrumentation for %v", f) - t.Logf("Requesting instrumentation for %v", f) - _, err := g_ConfigManager.ConfigWriter.Write(buf.Bytes()) - - if err != nil { - continue - } - - requesterdFuncs++ - } - */ - /*if !waitForAttach { - time.Sleep(100 * time.Second) - }*/ - - /*if requesterdFuncs > 0 { - // if waitForAttach { - // pid := os.Getpid() - // t.Logf("(2) Waiting to attach for PID: %d", pid) - // time.Sleep(30 * time.Second) - // waitForAttach = false - // } - - // Wait for probes to be instrumented - time.Sleep(2 * time.Second) +func NewProbeManager(t *testing.T) *ProbeManager { + return &ProbeManager{t: t} +} - t.Logf("Requested to instrument %d functions for binary: %s, pid: %d.", requesterdFuncs, binaryPath, pid) - }*/ +func (pm *ProbeManager) Install(pid int, function string) error { + pm.mu.Lock() + defer pm.mu.Unlock() + v, _ := pm.installedProbes.LoadOrStore(pid, make(map[string]struct{})) + probes := v.(map[string]struct{}) + probes[function] = struct{}{} + pm.t.Logf("Installing probe: PID=%d Function=%s", pid, function) return nil } -func (ca *ConfigAccumulator) AddTargets(targets []FunctionInfo, serviceName string) error { - ca.mu.Lock() - defer ca.mu.Unlock() - - buf := bufferPool.Get().(*bytes.Buffer) - 
buf.Reset() - defer bufferPool.Put(buf) - - buf.WriteString("{") - if err := ca.tmpl.Execute(buf, targets); err != nil { - return fmt.Errorf("failed to execute template: %w", err) - } - buf.WriteString("}") - - var newConfigs map[string]rcConfig - if err := json.NewDecoder(buf).Decode(&newConfigs); err != nil { - return fmt.Errorf("failed to decode generated configs: %w", err) - } - - if ca.configs[serviceName] == nil { - ca.configs[serviceName] = make(map[string]rcConfig) - } +func (pm *ProbeManager) Remove(pid int, function string) error { + pm.mu.Lock() + defer pm.mu.Unlock() - for probeID, config := range newConfigs { - ca.configs[serviceName][probeID] = config + if v, ok := pm.installedProbes.Load(pid); ok { + probes := v.(map[string]struct{}) + delete(probes, function) + pm.t.Logf("Removing probe: PID=%d Function=%s", pid, function) } - return nil } -func (ca *ConfigAccumulator) WriteConfigs() error { - ca.mu.RLock() - defer ca.mu.RUnlock() - - buf := bufferPool.Get().(*bytes.Buffer) - buf.Reset() - defer bufferPool.Put(buf) +func (pm *ProbeManager) CollectData(pid int, function string) (bool, error) { + if v, ok := pm.dataReceived.Load(pid); ok { + dataMap := v.(map[string]bool) + return dataMap[function], nil + } + return false, nil +} - // Marshal the full config structure (service name -> probe configs) - if err := json.NewEncoder(buf).Encode(ca.configs); err != nil { - return fmt.Errorf("failed to marshal configs: %w", err) +func NewProcessTracker(t *testing.T) *ProcessTracker { + return &ProcessTracker{ + t: t, + processes: make(map[int]*ProcessInfo), + stopChan: make(chan struct{}), + analyzedBinaries: make(map[string]bool), + analyzedPIDs: make(map[int]bool), + done: make(chan struct{}), } +} - return g_ConfigManager.ConfigWriter.WriteSync(buf.Bytes()) +func (pt *ProcessTracker) markAnalyzed(pid int, path string) { + pt.mu.Lock() + defer pt.mu.Unlock() + pt.analyzedPIDs[pid] = true + pt.analyzedBinaries[path] = true } func (pt *ProcessTracker) 
addProcess(pid int, parentPID int) *ProcessInfo { @@ -1089,72 +615,34 @@ func (pt *ProcessTracker) addProcess(pid int, parentPID int) *ProcessInfo { StartTime: time.Now(), Analyzed: false, } - pt.processes[pid] = proc - - // Add to parent's children if parent exists if parent, exists := pt.processes[parentPID]; exists { parent.Children = append(parent.Children, proc) } - - pt.LogTrace("πŸ‘Ά New process: PID=%d, Parent=%d, Binary=%s", pid, parentPID, binaryPath) + pt.LogTrace("New process: PID=%d, Parent=%d, Binary=%s", pid, parentPID, binaryPath) return proc } -func getBinaryPath(pid int) string { - path, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pid)) - if err != nil { - return "" - } - - // Resolve any symlinks - realPath, err := filepath.EvalSymlinks(path) - if err == nil { - path = realPath - } - - return path -} - func (pt *ProcessTracker) analyzeBinary(pid int, info *ProcessInfo) error { if info == nil { return fmt.Errorf("nil process info") } - pt.mu.Lock() info.State = StateAnalyzing pt.mu.Unlock() - // pt.LogTrace("πŸ”Ž Analyzing binary PID=%d Path=%s", pid, info.BinaryPath) - - // Perform analysis if err := InspectBinary(pt.t, info.BinaryPath, pid); err != nil { pt.mu.Lock() info.State = StateNew pt.mu.Unlock() return fmt.Errorf("binary analysis failed: %v", err) } - pt.mu.Lock() info.State = StateRunning pt.mu.Unlock() - return nil } -func getParentPID(pid int) int { - ppidStr, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) - if err != nil { - return 0 - } - fields := strings.Fields(string(ppidStr)) - if len(fields) < 4 { - return 0 - } - ppid, _ := strconv.Atoi(fields[3]) - return ppid -} - func (pt *ProcessTracker) scanProcessTree() error { if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGSTOP); err != nil { if err != unix.ESRCH { @@ -1162,30 +650,13 @@ func (pt *ProcessTracker) scanProcessTree() error { } return nil } - - // pt.profiler.OnProcessesPaused() - defer func() { if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGCONT); 
err != nil { if err != unix.ESRCH { pt.LogTrace("⚠️ Failed to resume PID %d: %v", -g_cmd.Process.Pid, err) } - } else { - // pt.LogTrace("▢️ Resumed process: PID=%d", -g_cmd.Process.Pid) } - - // pt.profiler.OnProcessesResumed() - - // if err := unix.Kill(pid, unix.SIGCONT); err != nil { - // if err != unix.ESRCH { - // pt.LogTrace("⚠️ Failed to resume PID %d: %v", pid, err) - // } - // } else { - // pt.LogTrace("▢️ Resumed process: PID=%d", pid) - // } }() - - // Get all processes allPids := make(map[int]bool) if entries, err := os.ReadDir("/proc"); err == nil { for _, entry := range entries { @@ -1194,8 +665,6 @@ func (pt *ProcessTracker) scanProcessTree() error { } } } - - // Record our own process tree for exclusion ourProcessTree := make(map[int]bool) ourPid := os.Getpid() findAncestors(ourPid, ourProcessTree) @@ -1205,42 +674,25 @@ func (pt *ProcessTracker) scanProcessTree() error { path string ppid int } - - // Check each PID for pid := range allPids { - // Skip if already analyzed pt.mu.RLock() if pt.analyzedPIDs[pid] { pt.mu.RUnlock() continue } pt.mu.RUnlock() - - // Skip if in our process tree if ourProcessTree[pid] { continue } - - // Get process path binaryPath := getBinaryPath(pid) if binaryPath == "" { continue } - - // Get parent PID ppid := getParentPID(pid) - - // Skip if parent is in our tree if ourProcessTree[ppid] { continue } - - // Always analyze: - // 1. Test binaries (.test) - // 2. Go build executables in /tmp - // 3. 
Children of test binaries shouldAnalyze := false - if strings.HasSuffix(binaryPath, ".test") { shouldAnalyze = true pt.LogTrace("Found test binary: %s (PID=%d)", binaryPath, pid) @@ -1248,185 +700,107 @@ func (pt *ProcessTracker) scanProcessTree() error { shouldAnalyze = true pt.LogTrace("Found build binary: %s (PID=%d)", binaryPath, pid) } else { - // Check if parent is a test binary parentPath := getBinaryPath(ppid) if strings.HasSuffix(parentPath, ".test") { shouldAnalyze = true pt.LogTrace("Found child of test: %s (PID=%d, Parent=%d)", binaryPath, pid, ppid) } } - if shouldAnalyze { - // Verify process still exists if _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)); err == nil { toAnalyze = append(toAnalyze, struct { pid int path string ppid int }{pid, binaryPath, ppid}) - - // Add to process tree if pt.processes[pid] == nil { pt.addProcess(pid, ppid) } } } } - if len(toAnalyze) > 0 { - pt.LogTrace("\nπŸ” Found %d processes to analyze:", len(toAnalyze)) + pt.LogTrace("πŸ” Found %d processes to analyze:", len(toAnalyze)) for _, p := range toAnalyze { pt.LogTrace(" PID=%d PPID=%d Path=%s", p.pid, p.ppid, p.path) } } - var activePids []int for _, p := range toAnalyze { activePids = append(activePids, p.pid) } - - // if pt.profiler!= nil { - // pt.profiler.OnTick(activePids) - // } - - // Process in small batches batchSize := 2 for i := 0; i < len(toAnalyze); i += batchSize { end := i + batchSize if end > len(toAnalyze) { end = len(toAnalyze) } - var wg sync.WaitGroup for _, p := range toAnalyze[i:end] { wg.Add(1) go func(pid int, path string) { defer wg.Done() - - // Verify process still exists if _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)); err != nil { return } - - pt.LogTrace("πŸ” Stopping process for analysis: PID=%d Path=%s", pid, path) - - // Get process info + pt.LogTrace("Stopping process for analysis: PID=%d Path=%s", pid, path) pt.mu.RLock() proc := pt.processes[pid] pt.mu.RUnlock() - if proc == nil { return } - - // Stop process - // if err := 
syscall.Kill(-g_cmd.Process.Pid, syscall.SIGSTOP); err != nil { - // if err != unix.ESRCH { - // pt.LogTrace("⚠️ Failed to stop PID %d: %v", pid, err) - // } - // return - // } - - // if err := unix.Kill(pid, unix.SIGSTOP); err != nil { - // if err != unix.ESRCH { - // pt.LogTrace("⚠️ Failed to stop PID %d: %v", pid, err) - // } - // return - // } - - // Ensure process gets resumed - // defer func() { - // if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGCONT); err != nil { - // if err != unix.ESRCH { - // pt.LogTrace("⚠️ Failed to resume PID %d: %v", pid, err) - // } - // } else { - // pt.LogTrace("▢️ Resumed process: PID=%d", pid) - // } - - // // if err := unix.Kill(pid, unix.SIGCONT); err != nil { - // // if err != unix.ESRCH { - // // pt.LogTrace("⚠️ Failed to resume PID %d: %v", pid, err) - // // } - // // } else { - // // pt.LogTrace("▢️ Resumed process: PID=%d", pid) - // // } - // }() - - // Wait a bit after stopping - // time.Sleep(1 * time.Millisecond) - - // Analyze with timeout if err := pt.analyzeBinary(pid, proc); err != nil { - pt.LogTrace("⚠️ Analysis failed: %v", err) + pt.LogTrace("Analysis failed: %v", err) } else { proc.Analyzed = true pt.markAnalyzed(pid, path) - // pt.LogTrace("βœ… Analysis complete: PID=%d", pid) } - - // go func() { - // if err := pt.analyzeBinary(pid, proc); err != nil { - // pt.LogTrace("⚠️ Analysis failed: %v", err) - // done <- false - // return - // } - - // proc.Analyzed = true - // pt.markAnalyzed(pid, path) - // pt.LogTrace("βœ… Analysis complete: PID=%d", pid) - // done <- true - // }() }(p.pid, p.path) } wg.Wait() - - // Wait between batches time.Sleep(10 * time.Microsecond) } - return nil } -func (pt *ProcessTracker) Cleanup() { -} - -// Helper to record process tree starting from a PID -func findAncestors(pid int, tree map[int]bool) { - for pid > 1 { - if tree[pid] { - return // Already visited +func (pt *ProcessTracker) logProcessTree() { + pt.mu.RLock() + defer pt.mu.RUnlock() + pt.t.Log("\n🌳 Process 
Tree:") + var printNode func(proc *ProcessInfo, prefix string) + printNode = func(proc *ProcessInfo, prefix string) { + state := "➑️" + switch proc.State { + case StateAnalyzing: + state = "πŸ”" + case StateRunning: + state = "▢️" + case StateExited: + state = "⏹️" } - tree[pid] = true - - // Get parent - ppid := getParentPID(pid) - if ppid <= 1 { - return + analyzed := "" + if proc.Analyzed { + analyzed = "βœ“" } - pid = ppid + pt.LogTrace("%s%s [PID=%d] %s%s (Parent=%d)", + prefix, state, proc.PID, filepath.Base(proc.BinaryPath), analyzed, proc.ParentPID) + for _, child := range proc.Children { + printNode(child, prefix+" ") + } + } + if main, exists := pt.processes[pt.mainPID]; exists { + printNode(main, "") } } -var g_cmd *exec.Cmd - func (pt *ProcessTracker) StartTracking(command string, args []string, dir string) error { - // ctx, cancel := context.WithCancel(context.Background()) - // defer cancel() - - // if err := pt.profiler.Start(ctx); err != nil { - // return fmt.Errorf("failed to start profiler: %w", err) - // } - // defer pt.profiler.Stop() - cmd := exec.Command(command, args...) 
g_cmd = cmd - if dir != "" { cmd.Dir = dir } - cmd.Env = append( - os.Environ(), + cmd.Env = append(os.Environ(), "PWD="+dir, "DD_DYNAMIC_INSTRUMENTATION_ENABLED=true", "DD_SERVICE=go-di-exploration-test-service") @@ -1441,23 +815,11 @@ func (pt *ProcessTracker) StartTracking(command string, args []string, dir strin pt.mainPID = cmd.Process.Pid pt.addProcess(pt.mainPID, os.Getpid()) - // Start scanning with high frequency initially go func() { - // Initial high-frequency scanning initialTicker := time.NewTicker(1 * time.Millisecond) defer initialTicker.Stop() - - // After initial period, reduce frequency slightly - // time.AfterFunc(5*time.Second, func() { - // initialTicker.Stop() - // }) - - // regularTicker := time.NewTicker(10 * time.Millisecond) - // defer regularTicker.Stop() - logTicker := time.NewTicker(10 * time.Second) defer logTicker.Stop() - for { select { case <-pt.stopChan: @@ -1467,86 +829,174 @@ func (pt *ProcessTracker) StartTracking(command string, args []string, dir strin pt.LogTrace("⚠️ Error scanning: %v", err) } case <-logTicker.C: - // pt.logProcessTree() } } }() err := cmd.Wait() close(pt.stopChan) - pt.LogTrace("Analyzed %d binaries.", len(analyzedBinaries)) - for _, binary := range analyzedBinaries { pt.LogTrace("Analyzed %s (debug info: %v)", binary.path, binary.hasDebug) } - return err } -func (pt *ProcessTracker) logProcessTree() { - pt.mu.RLock() - defer pt.mu.RUnlock() +func (pt *ProcessTracker) Cleanup() { + // Cleanup logic if needed. +} - pt.t.Log("\n🌳 Process Tree:") - var printNode func(proc *ProcessInfo, prefix string) - printNode = func(proc *ProcessInfo, prefix string) { - state := "➑️" - switch proc.State { - case StateAnalyzing: - state = "πŸ”" - case StateRunning: - state = "▢️" - case StateExited: - state = "⏹️" - } +func (pt *ProcessTracker) LogTrace(format string, args ...any) { + if TRACE { + pt.t.Logf(format, args...) 
+ } +} - analyzed := "" - if proc.Analyzed { - analyzed = "βœ“" - } +func NewConfigAccumulator() (*ConfigAccumulator, error) { + tmpl, err := template.New("config_template").Parse(explorationTestConfigTemplateText) + if err != nil { + return nil, fmt.Errorf("failed to parse template: %w", err) + } + return &ConfigAccumulator{ + configs: make(map[string]map[string]rcConfig), + tmpl: tmpl, + }, nil +} - pt.LogTrace("%s%s [PID=%d] %s%s (Parent=%d)", - prefix, state, proc.PID, filepath.Base(proc.BinaryPath), analyzed, proc.ParentPID) +func (ca *ConfigAccumulator) AddTargets(targets []FunctionInfo, serviceName string) error { + ca.mu.Lock() + defer ca.mu.Unlock() - for _, child := range proc.Children { - printNode(child, prefix+" ") - } + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufferPool.Put(buf) + + buf.WriteString("{") + if err := ca.tmpl.Execute(buf, targets); err != nil { + return fmt.Errorf("failed to execute template: %w", err) } + buf.WriteString("}") - if main, exists := pt.processes[pt.mainPID]; exists { - printNode(main, "") + var newConfigs map[string]rcConfig + if err := json.NewDecoder(buf).Decode(&newConfigs); err != nil { + return fmt.Errorf("failed to decode generated configs: %w", err) + } + if ca.configs[serviceName] == nil { + ca.configs[serviceName] = make(map[string]rcConfig) } + for probeID, config := range newConfigs { + ca.configs[serviceName][probeID] = config + } + return nil } -var DEBUG bool = true -var TRACE bool = false +func (ca *ConfigAccumulator) WriteConfigs() error { + ca.mu.RLock() + defer ca.mu.RUnlock() -func (pt *ProcessTracker) LogTrace(format string, args ...any) { - if TRACE { - pt.t.Logf(format, args...) 
+ buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufferPool.Put(buf) + + if err := json.NewEncoder(buf).Encode(ca.configs); err != nil { + return fmt.Errorf("failed to marshal configs: %w", err) } + return g_ConfigManager.ConfigWriter.WriteSync(buf.Bytes()) } -func LogDebug(t *testing.T, format string, args ...any) { - if DEBUG { - t.Logf(format, args...) +func InspectBinary(t *testing.T, binaryPath string, pid int) error { + allFuncs, err := listAllFunctions(binaryPath) + if err != nil { + analyzedBinaries = append(analyzedBinaries, BinaryInfo{ + path: binaryPath, + hasDebug: false, + }) + return nil } -} -var g_RepoInfo *RepoInfo -var g_ConfigManager *diconfig.ReaderConfigManager + targets := filterFunctions(allFuncs) + args, err := getProcessArgs(pid) + if err != nil { + return fmt.Errorf("failed to process args: %v", err) + } + cwd, err := getProcessCwd(pid) + if err != nil { + return fmt.Errorf("failed to get Cwd: %v", err) + } + env, err := getProcessEnv(pid) + if err != nil { + return fmt.Errorf("failed to get Env: %v", err) + } + serviceName, err := extractDDService(env) + if err != nil { + return fmt.Errorf("failed to get Env: %v, binaryPath: %s", err, binaryPath) + } + + LogDebug(t, "\n=======================================") + LogDebug(t, "πŸ” SERVICE NAME: %s", serviceName) + LogDebug(t, "πŸ” ANALYZING BINARY: %s", binaryPath) + LogDebug(t, "πŸ” ENV: %v", env) + LogDebug(t, "πŸ” ARGS: %v", args) + LogDebug(t, "πŸ” CWD: %s", cwd) + LogDebug(t, "πŸ” Elected %d target functions:", len(targets)) + for _, f := range targets { + LogDebug(t, " β†’ Package: %s, Function: %s, FullName: %s", f.PackageName, f.FunctionName, f.FullName) + } + LogDebug(t, "=======================================") + + if _, err := os.Stat(binaryPath); err != nil { + return fmt.Errorf("(1) binary inspection failed: %v", err) + } + + analyzedBinaries = append(analyzedBinaries, BinaryInfo{ + path: binaryPath, + hasDebug: len(targets) > 0, + }) + LogDebug(t, "βœ… 
Analysis complete for: %s", binaryPath) + LogDebug(t, "=======================================\n") + + g_ConfigManager.ProcTracker.HandleProcessStartSync(uint32(pid)) + t.Logf("About to request instrumentations for binary: %s, pid: %d.", binaryPath, pid) + + if err := g_configsAccumulator.AddTargets(targets, serviceName); err != nil { + t.Logf("Error adding target: %v, binaryPath: %s", err, binaryPath) + return fmt.Errorf("add targets failed: %v, binary: %s", err, binaryPath) + } + if err = g_configsAccumulator.WriteConfigs(); err != nil { + t.Logf("Error writing configs: %v, binaryPath: %s", err, binaryPath) + return fmt.Errorf("error adding configs: %v, binary: %s", err, binaryPath) + } + time.Sleep(2 * time.Second) + t.Logf("Requested to instrument %d functions for binary: %s, pid: %d.", len(targets), binaryPath, pid) + for _, f := range targets { + t.Logf(" -> requested instrumentation for %v", f) + } + if waitForAttach && os.Getenv("DEBUG") == "true" { + pid := os.Getpid() + t.Logf("Waiting to attach for PID: %d", pid) + time.Sleep(30 * time.Second) + waitForAttach = false + } + return nil +} +// TestExplorationGoDI is the entrypoint of the integration test of Go DI. The idea is to +// test Go DI systematically and in exploratory manner. In high level, here are the steps this test takes: +// 1. Clones protobuf and applies patches. +// 2. Figuring out the 1st party packages involved with the cloned project (to avoid 3rd party/std libs) +// 3. Compiles the test +// 4. Runs the test in a supervised environment, spawning processes as a group. +// 5. Periodically pauses and resumes the process group to analyze each binary unique. +// 6. Invoke Go DI to put probes in top X functions defined by `NUMBER_OF_RROBES` const. +// +// The goal is to exercise as many code paths as possible of the Go DI system. 
func TestExplorationGoDI(t *testing.T) { require.NoError(t, rlimit.RemoveMemlock(), "Failed to remove memlock limit") if features.HaveMapType(ebpf.RingBuf) != nil { t.Skip("Ringbuffers not supported on this kernel") } - eventOutputWriter := &explorationEventOutputTestWriter{ - t: t, - } - + eventOutputWriter := &explorationEventOutputTestWriter{t: t} opts := &dynamicinstrumentation.DIOptions{ RateLimitPerProbePerSecond: 0.0, ReaderWriterOptions: dynamicinstrumentation.ReaderWriterOptions{ @@ -1560,7 +1010,6 @@ func TestExplorationGoDI(t *testing.T) { GoDI *dynamicinstrumentation.GoDI err error ) - GoDI, err = dynamicinstrumentation.RunDynamicInstrumentation(opts) require.NoError(t, err) t.Cleanup(GoDI.Close) @@ -1569,10 +1018,9 @@ func TestExplorationGoDI(t *testing.T) { if !ok { t.Fatal("Config manager is of wrong type") } - g_ConfigManager = cm - g_configsAccumulator, err = NewConfigAccumulator() + g_configsAccumulator, err = NewConfigAccumulator() if err != nil { t.Fatal("Failed to create ConfigAccumulator") } @@ -1590,23 +1038,6 @@ func TestExplorationGoDI(t *testing.T) { require.NoError(t, err) } -type explorationEventOutputTestWriter struct { - t *testing.T - expectedResult map[string]*ditypes.CapturedValue -} - -func (e *explorationEventOutputTestWriter) Write(p []byte) (n int, err error) { - var snapshot ditypes.SnapshotUpload - if err := json.Unmarshal(p, &snapshot); err != nil { - e.t.Error("failed to unmarshal snapshot", err) - } - - funcName := snapshot.Debugger.ProbeInSnapshot.Type + "." 
+ snapshot.Debugger.ProbeInSnapshot.Method - e.t.Logf("Received snapshot for function: %s", funcName) - - return len(p), nil -} - func initializeTempDir(t *testing.T, predefinedTempDir string) string { if predefinedTempDir != "" { return predefinedTempDir @@ -1618,32 +1049,20 @@ func initializeTempDir(t *testing.T, predefinedTempDir string) string { return tempDir } -// RepoInfo holds scanned repository package information -type RepoInfo struct { - Packages map[string]bool // Package names found in repo - RepoPath string // Path to the repo - CommitHash string // Current commit hash (optional) -} - func ScanRepoPackages(repoPath string) (*RepoInfo, error) { info := &RepoInfo{ Packages: make(map[string]bool), RepoPath: repoPath, } - - // Get git hash if available if _, err := os.Stat(filepath.Join(repoPath, ".git")); err == nil { if hash, err := exec.Command("git", "-C", repoPath, "rev-parse", "HEAD").Output(); err == nil { info.CommitHash = strings.TrimSpace(string(hash)) } } - err := filepath.Walk(repoPath, func(path string, f os.FileInfo, err error) error { if err != nil { return nil } - - // Skip certain directories if f.IsDir() { dirname := filepath.Base(path) if dirname == ".git" || @@ -1656,29 +1075,21 @@ func ScanRepoPackages(repoPath string) (*RepoInfo, error) { } return nil } - - // Only process .go files if !strings.HasSuffix(path, ".go") { return nil } - - // Skip test files and generated files if strings.HasSuffix(path, "_test.go") || strings.HasSuffix(path, ".pb.go") { return nil } - - // Ensure the file is within the repo (not in .cache etc) relPath, err := filepath.Rel(repoPath, path) if err != nil || strings.Contains(relPath, "..") { return nil } - content, err := os.ReadFile(path) if err != nil { return nil } - scanner := bufio.NewScanner(bytes.NewReader(content)) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -1692,11 +1103,9 @@ func ScanRepoPackages(repoPath string) (*RepoInfo, error) { } return nil }) - if len(info.Packages) 
== 0 { return nil, fmt.Errorf("no packages found in repository at %s", repoPath) } - return info, err } @@ -1707,18 +1116,14 @@ func cloneProtobufRepo(t *testing.T, modulePath string, commitHash string) *Repo cmd.Stderr = os.Stderr require.NoError(t, cmd.Run(), "Failed to clone repository") } - if commitHash != "" { cmd := exec.Command("git", "checkout", commitHash) cmd.Dir = modulePath require.NoError(t, cmd.Run(), "Failed to checkout commit hash") } - - // Scan packages after clone/checkout info, err := ScanRepoPackages(modulePath) require.NoError(t, err, "Failed to scan repo packages") - // Log the organized package information var pkgs []string for pkg := range info.Packages { if strings.Contains(pkg, "/tmp") { @@ -1727,10 +1132,8 @@ func cloneProtobufRepo(t *testing.T, modulePath string, commitHash string) *Repo pkgs = append(pkgs, pkg) } sort.Strings(pkgs) - t.Logf("πŸ“¦ Found %d packages in protobuf repo:", len(pkgs)) - // Group packages by their top-level directory groups := make(map[string][]string) for _, pkg := range pkgs { parts := strings.SplitN(pkg, "/", 2) @@ -1738,20 +1141,17 @@ func cloneProtobufRepo(t *testing.T, modulePath string, commitHash string) *Repo groups[topLevel] = append(groups[topLevel], pkg) } - // Print grouped packages var topLevels []string for k := range groups { topLevels = append(topLevels, k) } sort.Strings(topLevels) - for _, topLevel := range topLevels { t.Logf(" %s/", topLevel) for _, pkg := range groups[topLevel] { t.Logf(" β†’ %s", pkg) } } - return info } @@ -1767,16 +1167,13 @@ func copyDir(src, dst string) error { if err := os.MkdirAll(dst, 0755); err != nil { return err } - for _, entry := range entries { srcPath := filepath.Join(src, entry.Name()) dstPath := filepath.Join(dst, entry.Name()) - info, err := entry.Info() if err != nil { return err } - if info.IsDir() { if err = copyDir(srcPath, dstPath); err != nil { return err @@ -1796,57 +1193,14 @@ func copyFile(srcFile, dstFile string) error { return err } defer 
src.Close() - if err = os.MkdirAll(filepath.Dir(dstFile), 0755); err != nil { return err } - dst, err := os.Create(dstFile) if err != nil { return err } defer dst.Close() - _, err = io.Copy(dst, src) return err } - -var explorationTestConfigTemplateText = ` - {{- range $index, $target := .}} - {{- if $index}},{{end}} - "{{$target.ProbeId}}": { - "id": "{{$target.ProbeId}}", - "version": 0, - "type": "LOG_PROBE", - "language": "go", - "where": { - "typeName": "{{$target.PackageName}}", - "methodName": "{{$target.FunctionName}}" - }, - "tags": [], - "template": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took {@duration}ms", - "segments": [ - { - "str": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took " - }, - { - "dsl": "@duration", - "json": { - "ref": "@duration" - } - }, - { - "str": "ms" - } - ], - "captureSnapshot": false, - "capture": { - "maxReferenceDepth": 10 - }, - "sampling": { - "snapshotsPerSecond": 5000 - }, - "evaluateAt": "EXIT" - } - {{- end}} -` diff --git a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go index ad9b259b8a85d4..cb8ef847f40b34 100644 --- a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go +++ b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "io/fs" + "math/rand" "net/http" "os" "os/exec" @@ -536,7 +537,12 @@ func (c command) mustRun(t *testing.T, args ...string) string { for i, arg := range args { cmdArgs = append(cmdArgs, arg) if i == 1 { // right after "test" - cmdArgs = append(cmdArgs, "-ldflags=-w=false -s=false", "-count=1", "-timeout=30m") + cmdArgs = append(cmdArgs, + "-ldflags=-w=false -s=false", + "-gcflags=all=-l", + "-count=1", + "-timeout=30m", + ) } } } else { @@ -552,7 +558,10 @@ func (c command) mustRun(t *testing.T, args 
...string) string { if c.Env != nil { cmd.Env = c.Env } - cmd.Env = append(cmd.Env, "PWD="+cmd.Dir) + cmd.Env = append(cmd.Env, + fmt.Sprintf("PWD=%s", cmd.Dir), + fmt.Sprintf("DD_SERVICE=go-di-exploration-test-%d", rand.Int()), + ) cmd.Stdout = stdout cmd.Stderr = stderr From 5f602f206a9303337c371efabdcae09147496db2 Mon Sep 17 00:00:00 2001 From: Matan Green Date: Tue, 25 Feb 2025 16:21:25 +0200 Subject: [PATCH 4/6] Deleted exploration tests - to be added in a separate PR --- .../diconfig/mem_config_manager.go | 31 +- .../proctracker/proctracker.go | 6 - .../testutil/exploration_e2e_test.go | 1206 ----------------- .../patches/protobuf/integration_test.go | 595 -------- .../patches/protobuf/test.bash | 7 - .../protobuf/testing/prototest/message.go | 911 ------------- 6 files changed, 11 insertions(+), 2745 deletions(-) delete mode 100644 pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go delete mode 100644 pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go delete mode 100644 pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash delete mode 100644 pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go diff --git a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go index 07ad2d0df262ee..381080eab517f3 100644 --- a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go +++ b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go @@ -24,7 +24,7 @@ import ( type ReaderConfigManager struct { sync.Mutex ConfigWriter *ConfigWriter - ProcTracker *proctracker.ProcessTracker + procTracker *proctracker.ProcessTracker callback configUpdateCallback configs configsByService @@ -40,8 +40,8 @@ func NewReaderConfigManager() (*ReaderConfigManager, error) { state: ditypes.NewDIProcs(), } - cm.ProcTracker = proctracker.NewProcessTracker(cm.updateProcessInfo) - err := cm.ProcTracker.Start() + 
cm.procTracker = proctracker.NewProcessTracker(cm.updateProcessInfo) + err := cm.procTracker.Start() if err != nil { return nil, err } @@ -63,7 +63,7 @@ func (cm *ReaderConfigManager) GetProcInfos() ditypes.DIProcs { // Stop causes the ReaderConfigManager to stop processing data func (cm *ReaderConfigManager) Stop() { cm.ConfigWriter.Stop() - cm.ProcTracker.Stop() + cm.procTracker.Stop() } func (cm *ReaderConfigManager) update() error { @@ -159,10 +159,6 @@ func (r *ConfigWriter) Write(p []byte) (n int, e error) { return 0, nil } -func (r *ConfigWriter) WriteSync(p []byte) error { - return r.parseRawConfigBytesAndTriggerCallback(p) -} - // Start initiates the ConfigWriter to start processing data func (r *ConfigWriter) Start() error { go func() { @@ -170,7 +166,13 @@ func (r *ConfigWriter) Start() error { for { select { case rawConfigBytes := <-r.updateChannel: - r.parseRawConfigBytesAndTriggerCallback(rawConfigBytes) + conf := map[string]map[string]rcConfig{} + err := json.Unmarshal(rawConfigBytes, &conf) + if err != nil { + log.Errorf("invalid config read from reader: %v", err) + continue + } + r.configCallback(conf) case <-r.stopChannel: break configUpdateLoop } @@ -179,17 +181,6 @@ func (r *ConfigWriter) Start() error { return nil } -func (r *ConfigWriter) parseRawConfigBytesAndTriggerCallback(rawConfigBytes []byte) error { - conf := map[string]map[string]rcConfig{} - err := json.Unmarshal(rawConfigBytes, &conf) - if err != nil { - log.Errorf("invalid config read from reader: %v", err) - return fmt.Errorf("invalid config read from reader: %v", err) - } - r.configCallback(conf) - return nil -} - // Stop causes the ConfigWriter to stop processing data func (r *ConfigWriter) Stop() { r.stopChannel <- true diff --git a/pkg/dynamicinstrumentation/proctracker/proctracker.go b/pkg/dynamicinstrumentation/proctracker/proctracker.go index 85ef944cd74a93..663dc5608497b7 100644 --- a/pkg/dynamicinstrumentation/proctracker/proctracker.go +++ 
b/pkg/dynamicinstrumentation/proctracker/proctracker.go @@ -84,12 +84,6 @@ func (pt *ProcessTracker) Stop() { } } -func (pt *ProcessTracker) HandleProcessStartSync(pid uint32) { - exePath := filepath.Join(pt.procRoot, strconv.FormatUint(uint64(pid), 10), "exe") - - pt.inspectBinary(exePath, pid) -} - func (pt *ProcessTracker) handleProcessStart(pid uint32) { exePath := filepath.Join(pt.procRoot, strconv.FormatUint(uint64(pid), 10), "exe") diff --git a/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go b/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go deleted file mode 100644 index 9f93fa23402af6..00000000000000 --- a/pkg/dynamicinstrumentation/testutil/exploration_e2e_test.go +++ /dev/null @@ -1,1206 +0,0 @@ -// Unless explicitly stated otherwise all files in this repository are licensed -// under the Apache License Version 2.0. -// This product includes software developed at Datadog (https://www.datadoghq.com/). -// Copyright 2016-present Datadog, Inc. - -//go:build linux_bpf - -package testutil - -import ( - "bufio" - "bytes" - "crypto/sha256" - "debug/dwarf" - "debug/elf" - "encoding/hex" - "encoding/json" - "fmt" - "html/template" - "io" - "os" - "os/exec" - "path/filepath" - "sort" - "strconv" - "strings" - "sync" - "syscall" - "testing" - "time" - - "github.com/cilium/ebpf" - "github.com/cilium/ebpf/features" - "github.com/cilium/ebpf/rlimit" - "github.com/google/uuid" - "github.com/stretchr/testify/require" - "golang.org/x/sys/unix" - - "github.com/DataDog/datadog-agent/pkg/dynamicinstrumentation" - "github.com/DataDog/datadog-agent/pkg/dynamicinstrumentation/diconfig" - "github.com/DataDog/datadog-agent/pkg/dynamicinstrumentation/ditypes" -) - -type ProcessState int - -const ( - StateNew ProcessState = iota - StateAnalyzing - StateRunning - StateExited -) - -func (s ProcessState) String() string { - switch s { - case StateNew: - return "NEW" - case StateAnalyzing: - return "ANALYZING" - case StateRunning: - return "RUNNING" - case 
StateExited: - return "EXITED" - default: - return "UNKNOWN" - } -} - -type ProcessInfo struct { - PID int - BinaryPath string - ParentPID int - State ProcessState - Children []*ProcessInfo - StartTime time.Time - Analyzed bool -} - -type ProcessTracker struct { - t *testing.T - mu sync.RWMutex - processes map[int]*ProcessInfo - mainPID int - stopChan chan struct{} - analyzedBinaries map[string]bool - analyzedPIDs map[int]bool - done chan struct{} -} - -type ProbeManager struct { - t *testing.T - installedProbes sync.Map // maps pid -> map[string]struct{} - dataReceived sync.Map // maps pid -> map[string]bool - mu sync.Mutex -} - -type BinaryInfo struct { - path string - hasDebug bool -} - -type FunctionInfo struct { - PackageName string - FunctionName string - FullName string - ProbeId string -} - -func NewFunctionInfo(packageName, functionName, fullName string) FunctionInfo { - return FunctionInfo{ - PackageName: packageName, - FunctionName: functionName, - FullName: fullName, - ProbeId: uuid.NewString(), - } -} - -type rcConfig struct { - ID string - Version int - ProbeType string `json:"type"` - Language string - Where struct { - TypeName string `json:"typeName"` - MethodName string `json:"methodName"` - SourceFile string - Lines []string - } - Tags []string - Template string - CaptureSnapshot bool - EvaluatedAt string - Capture struct { - MaxReferenceDepth int `json:"maxReferenceDepth"` - MaxFieldCount int `json:"maxFieldCount"` - } -} - -type ConfigAccumulator struct { - configs map[string]map[string]rcConfig - tmpl *template.Template - mu sync.RWMutex -} - -type RepoInfo struct { - Packages map[string]bool // Package names found in repo - RepoPath string // Path to the repo - CommitHash string // Current commit hash (optional) -} - -type explorationEventOutputTestWriter struct { - t *testing.T - expectedResult map[string]*ditypes.CapturedValue -} - -func (e *explorationEventOutputTestWriter) Write(p []byte) (n int, err error) { - var snapshot 
ditypes.SnapshotUpload - if err := json.Unmarshal(p, &snapshot); err != nil { - e.t.Error("failed to unmarshal snapshot", err) - } - funcName := snapshot.Debugger.ProbeInSnapshot.Type + "." + snapshot.Debugger.ProbeInSnapshot.Method - e.t.Logf("Received snapshot for function: %s", funcName) - return len(p), nil -} - -var ( - analyzedBinaries []BinaryInfo - waitForAttach bool = true - bufferPool = sync.Pool{New: func() interface{} { return new(bytes.Buffer) }} - seenBinaries = make(map[string]struct{}) - g_configsAccumulator *ConfigAccumulator - g_RepoInfo *RepoInfo - g_ConfigManager *diconfig.ReaderConfigManager - g_cmd *exec.Cmd - DEBUG bool = true - TRACE bool = false - NUMBER_OF_PROBES int = 10 - - explorationTestConfigTemplateText = ` - {{- range $index, $target := .}} - {{- if $index}},{{end}} - "{{$target.ProbeId}}": { - "id": "{{$target.ProbeId}}", - "version": 0, - "type": "LOG_PROBE", - "language": "go", - "where": { - "typeName": "{{$target.PackageName}}", - "methodName": "{{$target.FunctionName}}" - }, - "tags": [], - "template": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took {@duration}ms", - "segments": [ - { - "str": "Executed {{$target.PackageName}}.{{$target.FunctionName}}, it took " - }, - { - "dsl": "@duration", - "json": { - "ref": "@duration" - } - }, - { - "str": "ms" - } - ], - "captureSnapshot": false, - "capture": { - "maxReferenceDepth": 10 - }, - "sampling": { - "snapshotsPerSecond": 5000 - }, - "evaluateAt": "EXIT" - } - {{- end}} -` -) - -func getProcessArgs(pid int) ([]string, error) { - data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid)) - if err != nil { - return nil, err - } - args := strings.Split(string(data), "\x00") - if len(args) > 0 && args[len(args)-1] == "" { - args = args[:len(args)-1] - } - return args, nil -} - -func getProcessCwd(pid int) (string, error) { - cwd, err := os.Readlink(fmt.Sprintf("/proc/%d/cwd", pid)) - if err != nil { - return "", err - } - return cwd, nil -} - -func 
getProcessEnv(pid int) ([]string, error) { - data, err := os.ReadFile(fmt.Sprintf("/proc/%d/environ", pid)) - if err != nil { - return nil, err - } - env := strings.Split(string(data), "\x00") - if len(env) > 0 && env[len(env)-1] == "" { - env = env[:len(env)-1] - } - return env, nil -} - -func extractDDService(env []string) (string, error) { - for _, entry := range env { - if strings.HasPrefix(entry, "DD_SERVICE=") { - return strings.TrimPrefix(entry, "DD_SERVICE="), nil - } - } - return "", fmt.Errorf("DD_SERVICE not found") -} - -func hasDWARFInfo(binaryPath string) (bool, error) { - f, err := elf.Open(binaryPath) - if err != nil { - return false, fmt.Errorf("failed to open binary: %w", err) - } - defer f.Close() - - debugSections := false - for _, section := range f.Sections { - if strings.HasPrefix(section.Name, ".debug_") { - fmt.Printf("Found debug section: %s (size: %d)\n", section.Name, section.Size) - debugSections = true - } - } - - dwarfData, err := f.DWARF() - if err != nil { - return debugSections, fmt.Errorf("DWARF read error: %w", err) - } - - reader := dwarfData.Reader() - entry, err := reader.Next() - if err != nil { - return debugSections, fmt.Errorf("DWARF entry read error: %w", err) - } - if entry != nil { - fmt.Printf("Found DWARF entry of type: %v\n", entry.Tag) - return true, nil - } - return false, nil -} - -func listAllFunctions(filePath string) ([]FunctionInfo, error) { - var functions []FunctionInfo - var errors []string - - ef, err := elf.Open(filePath) - if err != nil { - return nil, fmt.Errorf("failed to open file: %v", err) - } - defer ef.Close() - - dwarfData, err := ef.DWARF() - if err != nil { - return nil, fmt.Errorf("failed to load DWARF data: %v", err) - } - - reader := dwarfData.Reader() - for { - entry, err := reader.Next() - if err != nil { - return nil, fmt.Errorf("error reading DWARF entry: %v", err) - } - if entry == nil { - break - } - if entry.Tag == dwarf.TagSubprogram { - funcName, ok := 
entry.Val(dwarf.AttrName).(string) - if !ok || funcName == "" { - continue - } - info := extractPackageAndFunction(funcName) - if info.FunctionName == "" { - errors = append(errors, fmt.Sprintf("could not extract function name from %q", funcName)) - continue - } - functions = append(functions, info) - } - } - - if len(functions) == 0 { - if len(errors) > 0 { - return nil, fmt.Errorf("failed to extract any functions. Errors: %s", strings.Join(errors, "; ")) - } - return nil, fmt.Errorf("no functions found in the binary") - } - return functions, nil -} - -func extractPackageAndFunction(fullName string) FunctionInfo { - if fullName == "" { - return FunctionInfo{} - } - parenIndex := strings.Index(fullName, "(") - lastDot := -1 - if parenIndex != -1 { - lastDot = strings.LastIndex(fullName[:parenIndex], ".") - } else { - lastDot = strings.LastIndex(fullName, ".") - } - if lastDot == -1 { - return FunctionInfo{} - } - pkgPath := fullName[:lastDot] - funcPart := fullName[lastDot+1:] - return NewFunctionInfo(pkgPath, funcPart, fullName) -} - -func shouldProfileFunction(name string) bool { - if strings.HasPrefix(name, "*ZN") || - strings.HasPrefix(name, "_") || - strings.Contains(name, "_sanitizer") || - strings.Contains(name, "runtime.") { - return false - } - parts := strings.Split(name, ".") - if len(parts) < 2 { - return false - } - pkgPath := parts[0] - if len(parts) > 2 { - pkgPath = strings.Join(parts[:len(parts)-1], "/") - } - for repoPkg := range g_RepoInfo.Packages { - if strings.Contains(pkgPath, repoPkg) { - return true - } - } - return false -} - -func filterFunctions(funcs []FunctionInfo) []FunctionInfo { - var validFuncs []FunctionInfo - for _, f := range funcs { - fullName := fmt.Sprintf("%s.%s", f.PackageName, f.FunctionName) - if shouldProfileFunction(fullName) { - validFuncs = append(validFuncs, f) - } - } - if len(validFuncs) == 0 { - return nil - } - sort.Slice(validFuncs, func(i, j int) bool { - fullNameI := fmt.Sprintf("%s.%s", 
validFuncs[i].PackageName, validFuncs[i].FunctionName) - fullNameJ := fmt.Sprintf("%s.%s", validFuncs[j].PackageName, validFuncs[j].FunctionName) - return fullNameI < fullNameJ - }) - if len(validFuncs) <= NUMBER_OF_PROBES { - return validFuncs - } - return validFuncs[:NUMBER_OF_PROBES] -} - -func ExtractFunctions(binaryPath string) ([]FunctionInfo, error) { - file, err := elf.Open(binaryPath) - if err != nil { - return nil, fmt.Errorf("failed to open binary: %v", err) - } - defer file.Close() - - dwarfData, err := file.DWARF() - if err != nil { - return nil, fmt.Errorf("failed to load DWARF data: %v", err) - } - - var functions []FunctionInfo - reader := dwarfData.Reader() - for { - entry, err := reader.Next() - if err != nil { - return nil, fmt.Errorf("error reading DWARF: %v", err) - } - if entry == nil { - break - } - if entry.Tag == dwarf.TagSubprogram { - funcName, _ := entry.Val(dwarf.AttrName).(string) - var packageName string - if compDir, ok := entry.Val(dwarf.AttrCompDir).(string); ok { - packageName = compDir - } - if funcName != "" { - functions = append(functions, FunctionInfo{ - PackageName: packageName, - FunctionName: funcName, - }) - } - } - } - return functions, nil -} - -func hasDWARF(binaryPath string) (bool, error) { - file, err := elf.Open(binaryPath) - if err != nil { - return false, fmt.Errorf("failed to open binary: %v", err) - } - defer file.Close() - - _, err = file.DWARF() - if err != nil { - if err.Error() == "no DWARF data" { - return false, nil - } - return false, fmt.Errorf("failed to check DWARF data: %v", err) - } - return true, nil -} - -func fingerprintGoBinary(binaryPath string) (string, error) { - f, err := elf.Open(binaryPath) - if err != nil { - return "", err - } - defer f.Close() - - sections := make([]*elf.Section, len(f.Sections)) - copy(sections, f.Sections) - sort.Slice(sections, func(i, j int) bool { - return sections[i].Name < sections[j].Name - }) - - hash := sha256.New() - for _, sec := range sections { - if 
sec.Type == elf.SHT_NOBITS || sec.Name == ".note.go.buildid" { - continue - } - if _, err := io.WriteString(hash, sec.Name); err != nil { - return "", err - } - data, err := sec.Data() - if err != nil { - return "", err - } - if _, err := hash.Write(data); err != nil { - return "", err - } - } - return hex.EncodeToString(hash.Sum(nil)), nil -} - -func isAlreadyProcessed(binaryPath string) (bool, error) { - fingerprint, err := fingerprintGoBinary(binaryPath) - if err != nil { - return false, err - } - if _, exists := seenBinaries[fingerprint]; exists { - return true, nil - } - seenBinaries[fingerprint] = struct{}{} - return false, nil -} - -func getBinaryPath(pid int) string { - path, err := os.Readlink(fmt.Sprintf("/proc/%d/exe", pid)) - if err != nil { - return "" - } - realPath, err := filepath.EvalSymlinks(path) - if err == nil { - path = realPath - } - return path -} - -func getParentPID(pid int) int { - ppidStr, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) - if err != nil { - return 0 - } - fields := strings.Fields(string(ppidStr)) - if len(fields) < 4 { - return 0 - } - ppid, _ := strconv.Atoi(fields[3]) - return ppid -} - -func findAncestors(pid int, tree map[int]bool) { - for pid > 1 { - if tree[pid] { - return - } - tree[pid] = true - ppid := getParentPID(pid) - if ppid <= 1 { - return - } - pid = ppid - } -} - -func LogDebug(t *testing.T, format string, args ...any) { - if DEBUG { - t.Logf(format, args...) 
- } -} - -func NewProbeManager(t *testing.T) *ProbeManager { - return &ProbeManager{t: t} -} - -func (pm *ProbeManager) Install(pid int, function string) error { - pm.mu.Lock() - defer pm.mu.Unlock() - - v, _ := pm.installedProbes.LoadOrStore(pid, make(map[string]struct{})) - probes := v.(map[string]struct{}) - probes[function] = struct{}{} - pm.t.Logf("Installing probe: PID=%d Function=%s", pid, function) - return nil -} - -func (pm *ProbeManager) Remove(pid int, function string) error { - pm.mu.Lock() - defer pm.mu.Unlock() - - if v, ok := pm.installedProbes.Load(pid); ok { - probes := v.(map[string]struct{}) - delete(probes, function) - pm.t.Logf("Removing probe: PID=%d Function=%s", pid, function) - } - return nil -} - -func (pm *ProbeManager) CollectData(pid int, function string) (bool, error) { - if v, ok := pm.dataReceived.Load(pid); ok { - dataMap := v.(map[string]bool) - return dataMap[function], nil - } - return false, nil -} - -func NewProcessTracker(t *testing.T) *ProcessTracker { - return &ProcessTracker{ - t: t, - processes: make(map[int]*ProcessInfo), - stopChan: make(chan struct{}), - analyzedBinaries: make(map[string]bool), - analyzedPIDs: make(map[int]bool), - done: make(chan struct{}), - } -} - -func (pt *ProcessTracker) markAnalyzed(pid int, path string) { - pt.mu.Lock() - defer pt.mu.Unlock() - pt.analyzedPIDs[pid] = true - pt.analyzedBinaries[path] = true -} - -func (pt *ProcessTracker) addProcess(pid int, parentPID int) *ProcessInfo { - pt.mu.Lock() - defer pt.mu.Unlock() - - if proc, exists := pt.processes[pid]; exists { - return proc - } - - binaryPath := getBinaryPath(pid) - proc := &ProcessInfo{ - PID: pid, - ParentPID: parentPID, - BinaryPath: binaryPath, - State: StateNew, - StartTime: time.Now(), - Analyzed: false, - } - pt.processes[pid] = proc - if parent, exists := pt.processes[parentPID]; exists { - parent.Children = append(parent.Children, proc) - } - pt.LogTrace("New process: PID=%d, Parent=%d, Binary=%s", pid, parentPID, 
binaryPath) - return proc -} - -func (pt *ProcessTracker) analyzeBinary(pid int, info *ProcessInfo) error { - if info == nil { - return fmt.Errorf("nil process info") - } - pt.mu.Lock() - info.State = StateAnalyzing - pt.mu.Unlock() - - if err := InspectBinary(pt.t, info.BinaryPath, pid); err != nil { - pt.mu.Lock() - info.State = StateNew - pt.mu.Unlock() - return fmt.Errorf("binary analysis failed: %v", err) - } - pt.mu.Lock() - info.State = StateRunning - pt.mu.Unlock() - return nil -} - -func (pt *ProcessTracker) scanProcessTree() error { - if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGSTOP); err != nil { - if err != unix.ESRCH { - pt.LogTrace("⚠️ Failed to stop PID %d: %v", -g_cmd.Process.Pid, err) - } - return nil - } - defer func() { - if err := syscall.Kill(-g_cmd.Process.Pid, syscall.SIGCONT); err != nil { - if err != unix.ESRCH { - pt.LogTrace("⚠️ Failed to resume PID %d: %v", -g_cmd.Process.Pid, err) - } - } - }() - allPids := make(map[int]bool) - if entries, err := os.ReadDir("/proc"); err == nil { - for _, entry := range entries { - if pid, err := strconv.Atoi(entry.Name()); err == nil { - allPids[pid] = true - } - } - } - ourProcessTree := make(map[int]bool) - ourPid := os.Getpid() - findAncestors(ourPid, ourProcessTree) - - var toAnalyze []struct { - pid int - path string - ppid int - } - for pid := range allPids { - pt.mu.RLock() - if pt.analyzedPIDs[pid] { - pt.mu.RUnlock() - continue - } - pt.mu.RUnlock() - if ourProcessTree[pid] { - continue - } - binaryPath := getBinaryPath(pid) - if binaryPath == "" { - continue - } - ppid := getParentPID(pid) - if ourProcessTree[ppid] { - continue - } - shouldAnalyze := false - if strings.HasSuffix(binaryPath, ".test") { - shouldAnalyze = true - pt.LogTrace("Found test binary: %s (PID=%d)", binaryPath, pid) - } else if strings.Contains(binaryPath, "/go-build") && strings.Contains(binaryPath, "/exe/") { - shouldAnalyze = true - pt.LogTrace("Found build binary: %s (PID=%d)", binaryPath, pid) - } else { - 
parentPath := getBinaryPath(ppid) - if strings.HasSuffix(parentPath, ".test") { - shouldAnalyze = true - pt.LogTrace("Found child of test: %s (PID=%d, Parent=%d)", binaryPath, pid, ppid) - } - } - if shouldAnalyze { - if _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)); err == nil { - toAnalyze = append(toAnalyze, struct { - pid int - path string - ppid int - }{pid, binaryPath, ppid}) - if pt.processes[pid] == nil { - pt.addProcess(pid, ppid) - } - } - } - } - if len(toAnalyze) > 0 { - pt.LogTrace("πŸ” Found %d processes to analyze:", len(toAnalyze)) - for _, p := range toAnalyze { - pt.LogTrace(" PID=%d PPID=%d Path=%s", p.pid, p.ppid, p.path) - } - } - var activePids []int - for _, p := range toAnalyze { - activePids = append(activePids, p.pid) - } - batchSize := 2 - for i := 0; i < len(toAnalyze); i += batchSize { - end := i + batchSize - if end > len(toAnalyze) { - end = len(toAnalyze) - } - var wg sync.WaitGroup - for _, p := range toAnalyze[i:end] { - wg.Add(1) - go func(pid int, path string) { - defer wg.Done() - if _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)); err != nil { - return - } - pt.LogTrace("Stopping process for analysis: PID=%d Path=%s", pid, path) - pt.mu.RLock() - proc := pt.processes[pid] - pt.mu.RUnlock() - if proc == nil { - return - } - if err := pt.analyzeBinary(pid, proc); err != nil { - pt.LogTrace("Analysis failed: %v", err) - } else { - proc.Analyzed = true - pt.markAnalyzed(pid, path) - } - }(p.pid, p.path) - } - wg.Wait() - time.Sleep(10 * time.Microsecond) - } - return nil -} - -func (pt *ProcessTracker) logProcessTree() { - pt.mu.RLock() - defer pt.mu.RUnlock() - pt.t.Log("\n🌳 Process Tree:") - var printNode func(proc *ProcessInfo, prefix string) - printNode = func(proc *ProcessInfo, prefix string) { - state := "➑️" - switch proc.State { - case StateAnalyzing: - state = "πŸ”" - case StateRunning: - state = "▢️" - case StateExited: - state = "⏹️" - } - analyzed := "" - if proc.Analyzed { - analyzed = "βœ“" - } - pt.LogTrace("%s%s 
[PID=%d] %s%s (Parent=%d)", - prefix, state, proc.PID, filepath.Base(proc.BinaryPath), analyzed, proc.ParentPID) - for _, child := range proc.Children { - printNode(child, prefix+" ") - } - } - if main, exists := pt.processes[pt.mainPID]; exists { - printNode(main, "") - } -} - -func (pt *ProcessTracker) StartTracking(command string, args []string, dir string) error { - cmd := exec.Command(command, args...) - g_cmd = cmd - if dir != "" { - cmd.Dir = dir - } - cmd.Env = append(os.Environ(), - "PWD="+dir, - "DD_DYNAMIC_INSTRUMENTATION_ENABLED=true", - "DD_SERVICE=go-di-exploration-test-service") - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start command: %v", err) - } - - pt.mainPID = cmd.Process.Pid - pt.addProcess(pt.mainPID, os.Getpid()) - - go func() { - initialTicker := time.NewTicker(1 * time.Millisecond) - defer initialTicker.Stop() - logTicker := time.NewTicker(10 * time.Second) - defer logTicker.Stop() - for { - select { - case <-pt.stopChan: - return - case <-initialTicker.C: - if err := pt.scanProcessTree(); err != nil { - pt.LogTrace("⚠️ Error scanning: %v", err) - } - case <-logTicker.C: - } - } - }() - - err := cmd.Wait() - close(pt.stopChan) - pt.LogTrace("Analyzed %d binaries.", len(analyzedBinaries)) - for _, binary := range analyzedBinaries { - pt.LogTrace("Analyzed %s (debug info: %v)", binary.path, binary.hasDebug) - } - return err -} - -func (pt *ProcessTracker) Cleanup() { - // Cleanup logic if needed. -} - -func (pt *ProcessTracker) LogTrace(format string, args ...any) { - if TRACE { - pt.t.Logf(format, args...) 
- } -} - -func NewConfigAccumulator() (*ConfigAccumulator, error) { - tmpl, err := template.New("config_template").Parse(explorationTestConfigTemplateText) - if err != nil { - return nil, fmt.Errorf("failed to parse template: %w", err) - } - return &ConfigAccumulator{ - configs: make(map[string]map[string]rcConfig), - tmpl: tmpl, - }, nil -} - -func (ca *ConfigAccumulator) AddTargets(targets []FunctionInfo, serviceName string) error { - ca.mu.Lock() - defer ca.mu.Unlock() - - buf := bufferPool.Get().(*bytes.Buffer) - buf.Reset() - defer bufferPool.Put(buf) - - buf.WriteString("{") - if err := ca.tmpl.Execute(buf, targets); err != nil { - return fmt.Errorf("failed to execute template: %w", err) - } - buf.WriteString("}") - - var newConfigs map[string]rcConfig - if err := json.NewDecoder(buf).Decode(&newConfigs); err != nil { - return fmt.Errorf("failed to decode generated configs: %w", err) - } - if ca.configs[serviceName] == nil { - ca.configs[serviceName] = make(map[string]rcConfig) - } - for probeID, config := range newConfigs { - ca.configs[serviceName][probeID] = config - } - return nil -} - -func (ca *ConfigAccumulator) WriteConfigs() error { - ca.mu.RLock() - defer ca.mu.RUnlock() - - buf := bufferPool.Get().(*bytes.Buffer) - buf.Reset() - defer bufferPool.Put(buf) - - if err := json.NewEncoder(buf).Encode(ca.configs); err != nil { - return fmt.Errorf("failed to marshal configs: %w", err) - } - return g_ConfigManager.ConfigWriter.WriteSync(buf.Bytes()) -} - -func InspectBinary(t *testing.T, binaryPath string, pid int) error { - allFuncs, err := listAllFunctions(binaryPath) - if err != nil { - analyzedBinaries = append(analyzedBinaries, BinaryInfo{ - path: binaryPath, - hasDebug: false, - }) - return nil - } - - targets := filterFunctions(allFuncs) - args, err := getProcessArgs(pid) - if err != nil { - return fmt.Errorf("failed to process args: %v", err) - } - cwd, err := getProcessCwd(pid) - if err != nil { - return fmt.Errorf("failed to get Cwd: %v", err) - 
} - env, err := getProcessEnv(pid) - if err != nil { - return fmt.Errorf("failed to get Env: %v", err) - } - serviceName, err := extractDDService(env) - if err != nil { - return fmt.Errorf("failed to get Env: %v, binaryPath: %s", err, binaryPath) - } - - LogDebug(t, "\n=======================================") - LogDebug(t, "πŸ” SERVICE NAME: %s", serviceName) - LogDebug(t, "πŸ” ANALYZING BINARY: %s", binaryPath) - LogDebug(t, "πŸ” ENV: %v", env) - LogDebug(t, "πŸ” ARGS: %v", args) - LogDebug(t, "πŸ” CWD: %s", cwd) - LogDebug(t, "πŸ” Elected %d target functions:", len(targets)) - for _, f := range targets { - LogDebug(t, " β†’ Package: %s, Function: %s, FullName: %s", f.PackageName, f.FunctionName, f.FullName) - } - LogDebug(t, "=======================================") - - if _, err := os.Stat(binaryPath); err != nil { - return fmt.Errorf("(1) binary inspection failed: %v", err) - } - - analyzedBinaries = append(analyzedBinaries, BinaryInfo{ - path: binaryPath, - hasDebug: len(targets) > 0, - }) - LogDebug(t, "βœ… Analysis complete for: %s", binaryPath) - LogDebug(t, "=======================================\n") - - g_ConfigManager.ProcTracker.HandleProcessStartSync(uint32(pid)) - t.Logf("About to request instrumentations for binary: %s, pid: %d.", binaryPath, pid) - - if err := g_configsAccumulator.AddTargets(targets, serviceName); err != nil { - t.Logf("Error adding target: %v, binaryPath: %s", err, binaryPath) - return fmt.Errorf("add targets failed: %v, binary: %s", err, binaryPath) - } - if err = g_configsAccumulator.WriteConfigs(); err != nil { - t.Logf("Error writing configs: %v, binaryPath: %s", err, binaryPath) - return fmt.Errorf("error adding configs: %v, binary: %s", err, binaryPath) - } - time.Sleep(2 * time.Second) - t.Logf("Requested to instrument %d functions for binary: %s, pid: %d.", len(targets), binaryPath, pid) - for _, f := range targets { - t.Logf(" -> requested instrumentation for %v", f) - } - if waitForAttach && os.Getenv("DEBUG") == 
"true" { - pid := os.Getpid() - t.Logf("Waiting to attach for PID: %d", pid) - time.Sleep(30 * time.Second) - waitForAttach = false - } - return nil -} - -// TestExplorationGoDI is the entrypoint of the integration test of Go DI. The idea is to -// test Go DI systematically and in exploratory manner. In high level, here are the steps this test takes: -// 1. Clones protobuf and applies patches. -// 2. Figuring out the 1st party packages involved with the cloned project (to avoid 3rd party/std libs) -// 3. Compiles the test -// 4. Runs the test in a supervised environment, spawning processes as a group. -// 5. Periodically pauses and resumes the process group to analyze each binary unique. -// 6. Invoke Go DI to put probes in top X functions defined by `NUMBER_OF_RROBES` const. -// -// The goal is to exercise as many code paths as possible of the Go DI system. -func TestExplorationGoDI(t *testing.T) { - require.NoError(t, rlimit.RemoveMemlock(), "Failed to remove memlock limit") - if features.HaveMapType(ebpf.RingBuf) != nil { - t.Skip("Ringbuffers not supported on this kernel") - } - - eventOutputWriter := &explorationEventOutputTestWriter{t: t} - opts := &dynamicinstrumentation.DIOptions{ - RateLimitPerProbePerSecond: 0.0, - ReaderWriterOptions: dynamicinstrumentation.ReaderWriterOptions{ - CustomReaderWriters: true, - SnapshotWriter: eventOutputWriter, - DiagnosticWriter: os.Stderr, - }, - } - - var ( - GoDI *dynamicinstrumentation.GoDI - err error - ) - GoDI, err = dynamicinstrumentation.RunDynamicInstrumentation(opts) - require.NoError(t, err) - t.Cleanup(GoDI.Close) - - cm, ok := GoDI.ConfigManager.(*diconfig.ReaderConfigManager) - if !ok { - t.Fatal("Config manager is of wrong type") - } - g_ConfigManager = cm - - g_configsAccumulator, err = NewConfigAccumulator() - if err != nil { - t.Fatal("Failed to create ConfigAccumulator") - } - - tempDir := initializeTempDir(t, "/tmp/protobuf-integration-1060272402") - modulePath := filepath.Join(tempDir, "src", 
"google.golang.org", "protobuf") - - t.Log("Setting up test environment...") - g_RepoInfo = cloneProtobufRepo(t, modulePath, "30f628eeb303f2c29be7a381bf78aa3e3aabd317") - copyPatches(t, "exploration_tests/patches/protobuf", modulePath) - - t.Log("Starting process tracking...") - tracker := NewProcessTracker(t) - err = tracker.StartTracking("./test.bash", nil, modulePath) - require.NoError(t, err) -} - -func initializeTempDir(t *testing.T, predefinedTempDir string) string { - if predefinedTempDir != "" { - return predefinedTempDir - } - tempDir, err := os.MkdirTemp("", "protobuf-integration-") - require.NoError(t, err) - require.NoError(t, os.Chmod(tempDir, 0755)) - t.Log("tempDir:", tempDir) - return tempDir -} - -func ScanRepoPackages(repoPath string) (*RepoInfo, error) { - info := &RepoInfo{ - Packages: make(map[string]bool), - RepoPath: repoPath, - } - if _, err := os.Stat(filepath.Join(repoPath, ".git")); err == nil { - if hash, err := exec.Command("git", "-C", repoPath, "rev-parse", "HEAD").Output(); err == nil { - info.CommitHash = strings.TrimSpace(string(hash)) - } - } - err := filepath.Walk(repoPath, func(path string, f os.FileInfo, err error) error { - if err != nil { - return nil - } - if f.IsDir() { - dirname := filepath.Base(path) - if dirname == ".git" || - dirname == ".cache" || - dirname == "vendor" || - dirname == "testdata" || - strings.HasPrefix(dirname, ".") || - strings.HasPrefix(dirname, "tmp") { - return filepath.SkipDir - } - return nil - } - if !strings.HasSuffix(path, ".go") { - return nil - } - if strings.HasSuffix(path, "_test.go") || - strings.HasSuffix(path, ".pb.go") { - return nil - } - relPath, err := filepath.Rel(repoPath, path) - if err != nil || strings.Contains(relPath, "..") { - return nil - } - content, err := os.ReadFile(path) - if err != nil { - return nil - } - scanner := bufio.NewScanner(bytes.NewReader(content)) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "package ") { - 
pkgDir := filepath.Dir(relPath) - if pkgDir != "." { - info.Packages[pkgDir] = true - } - break - } - } - return nil - }) - if len(info.Packages) == 0 { - return nil, fmt.Errorf("no packages found in repository at %s", repoPath) - } - return info, err -} - -func cloneProtobufRepo(t *testing.T, modulePath string, commitHash string) *RepoInfo { - if _, err := os.Stat(modulePath); os.IsNotExist(err) { - cmd := exec.Command("git", "clone", "https://github.com/protocolbuffers/protobuf-go", modulePath) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - require.NoError(t, cmd.Run(), "Failed to clone repository") - } - if commitHash != "" { - cmd := exec.Command("git", "checkout", commitHash) - cmd.Dir = modulePath - require.NoError(t, cmd.Run(), "Failed to checkout commit hash") - } - info, err := ScanRepoPackages(modulePath) - require.NoError(t, err, "Failed to scan repo packages") - - var pkgs []string - for pkg := range info.Packages { - if strings.Contains(pkg, "/tmp") { - continue - } - pkgs = append(pkgs, pkg) - } - sort.Strings(pkgs) - t.Logf("πŸ“¦ Found %d packages in protobuf repo:", len(pkgs)) - - groups := make(map[string][]string) - for _, pkg := range pkgs { - parts := strings.SplitN(pkg, "/", 2) - topLevel := parts[0] - groups[topLevel] = append(groups[topLevel], pkg) - } - - var topLevels []string - for k := range groups { - topLevels = append(topLevels, k) - } - sort.Strings(topLevels) - for _, topLevel := range topLevels { - t.Logf(" %s/", topLevel) - for _, pkg := range groups[topLevel] { - t.Logf(" β†’ %s", pkg) - } - } - return info -} - -func copyPatches(t *testing.T, src, dst string) { - require.NoError(t, copyDir(src, dst), "Failed to copy patches") -} - -func copyDir(src, dst string) error { - entries, err := os.ReadDir(src) - if err != nil { - return err - } - if err := os.MkdirAll(dst, 0755); err != nil { - return err - } - for _, entry := range entries { - srcPath := filepath.Join(src, entry.Name()) - dstPath := filepath.Join(dst, entry.Name()) 
- info, err := entry.Info() - if err != nil { - return err - } - if info.IsDir() { - if err = copyDir(srcPath, dstPath); err != nil { - return err - } - } else { - if err = copyFile(srcPath, dstPath); err != nil { - return err - } - } - } - return nil -} - -func copyFile(srcFile, dstFile string) error { - src, err := os.Open(srcFile) - if err != nil { - return err - } - defer src.Close() - if err = os.MkdirAll(filepath.Dir(dstFile), 0755); err != nil { - return err - } - dst, err := os.Create(dstFile) - if err != nil { - return err - } - defer dst.Close() - _, err = io.Copy(dst, src) - return err -} diff --git a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go deleted file mode 100644 index cb8ef847f40b34..00000000000000 --- a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/integration_test.go +++ /dev/null @@ -1,595 +0,0 @@ -// Copyright 2019 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import ( - "archive/tar" - "archive/zip" - "bytes" - "compress/gzip" - "crypto/sha256" - "flag" - "fmt" - "io" - "io/fs" - "math/rand" - "net/http" - "os" - "os/exec" - "path/filepath" - "regexp" - "runtime" - "runtime/debug" - "strings" - "sync" - "testing" - "time" - - "google.golang.org/protobuf/internal/version" -) - -var ( - regenerate = flag.Bool("regenerate", false, "regenerate files") - buildRelease = flag.Bool("buildRelease", false, "build release binaries") - - protobufVersion = "27.0" - - golangVersions = func() []string { - // Version policy: oldest supported version of Go, plus the version before that. 
- // This matches the version policy of the Google Cloud Client Libraries: - // https://cloud.google.com/go/getting-started/supported-go-versions - return []string{ - "1.21.13", - "1.22.6", - "1.23.0", - } - }() - golangLatest = golangVersions[len(golangVersions)-1] - - staticcheckVersion = "2024.1.1" - staticcheckSHA256s = map[string]string{ - "darwin/amd64": "b67380b84b81d5765b478b7ad888dd7ce53b2c0861103bafa946ac84dc9244ce", - "darwin/arm64": "09cb10e4199f7c6356c2ed5dc45e877c3087ef775d84d39338b52e1a94866074", - "linux/386": "0225fd8b5cf6c762f9c0aedf1380ed4df576d1d54fb68691be895889e10faf0b", - "linux/amd64": "6e9398fcaff2b36e1d15e84a647a3a14733b7c2dd41187afa2c182a4c3b32180", - } - - // purgeTimeout determines the maximum age of unused sub-directories. - purgeTimeout = 30 * 24 * time.Hour // 1 month - - // Variables initialized by mustInitDeps. - modulePath string - protobufPath string -) - -func TestIntegration(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - if os.Getenv("GO_BUILDER_NAME") != "" { - // To start off, run on longtest builders, not longtest-race ones. - if race() { - t.Skip("skipping integration test in race mode on builders") - } - // When on a builder, run even if it's not explicitly requested - // provided our caller isn't already running it. - if os.Getenv("GO_PROTOBUF_INTEGRATION_TEST_RUNNING") == "1" { - t.Skip("protobuf integration test is already running, skipping nested invocation") - } - os.Setenv("GO_PROTOBUF_INTEGRATION_TEST_RUNNING", "1") - } else if flag.Lookup("test.run").Value.String() != "^TestIntegration$" { - t.Skip("not running integration test if not explicitly requested via test.bash") - } - - mustInitDeps(t) - mustHandleFlags(t) - - // Report dirt in the working tree quickly, rather than after - // going through all the presubmits. - // - // Fail the test late, so we can test uncommitted changes with -failfast. 
- // gitDiff := mustRunCommand(t, "git", "diff", "HEAD") - // if strings.TrimSpace(gitDiff) != "" { - // fmt.Printf("WARNING: working tree contains uncommitted changes:\n%v\n", gitDiff) - // } - // gitUntracked := mustRunCommand(t, "git", "ls-files", "--others", "--exclude-standard") - // if strings.TrimSpace(gitUntracked) != "" { - // fmt.Printf("WARNING: working tree contains untracked files:\n%v\n", gitUntracked) - // } - - // Do the relatively fast checks up-front. - t.Run("GeneratedGoFiles", func(t *testing.T) { - diff := mustRunCommand(t, "go", "run", "-tags", "protolegacy", "./internal/cmd/generate-types") - if strings.TrimSpace(diff) != "" { - t.Fatalf("stale generated files:\n%v", diff) - } - diff = mustRunCommand(t, "go", "run", "-tags", "protolegacy", "./internal/cmd/generate-protos") - if strings.TrimSpace(diff) != "" { - t.Fatalf("stale generated files:\n%v", diff) - } - }) - t.Run("FormattedGoFiles", func(t *testing.T) { - files := strings.Split(strings.TrimSpace(mustRunCommand(t, "git", "ls-files", "*.go")), "\n") - diff := mustRunCommand(t, append([]string{"gofmt", "-d"}, files...)...) - if strings.TrimSpace(diff) != "" { - t.Fatalf("unformatted source files:\n%v", diff) - } - }) - t.Run("CopyrightHeaders", func(t *testing.T) { - files := strings.Split(strings.TrimSpace(mustRunCommand(t, "git", "ls-files", "*.go", "*.proto")), "\n") - mustHaveCopyrightHeader(t, files) - }) - - var wg sync.WaitGroup - sema := make(chan bool, (runtime.NumCPU()+1)/2) - for i := range golangVersions { - goVersion := golangVersions[i] - goLabel := "Go" + goVersion - runGo := func(label string, cmd command, args ...string) { - wg.Add(1) - sema <- true - go func() { - defer wg.Done() - defer func() { <-sema }() - t.Run(goLabel+"/"+label, func(t *testing.T) { - args[0] += goVersion - cmd.mustRun(t, args...) 
- }) - }() - } - - runGo("Normal", command{}, "go", "test", "-race", "./...") - runGo("Reflect", command{}, "go", "test", "-race", "-tags", "protoreflect", "./...") - if goVersion == golangLatest { - runGo("ProtoLegacyRace", command{}, "go", "test", "-race", "-tags", "protolegacy", "./...") - runGo("ProtoLegacy", command{}, "go", "test", "-tags", "protolegacy", "./...") - runGo("ProtocGenGo", command{Dir: "cmd/protoc-gen-go/testdata"}, "go", "test") - runGo("Conformance", command{Dir: "internal/conformance"}, "go", "test", "-execute") - - // Only run the 32-bit compatibility tests for Linux; - // avoid Darwin since 10.15 dropped support i386 code execution. - // if runtime.GOOS == "linux" { - // runGo("Arch32Bit", command{Env: append(os.Environ(), "GOARCH=386")}, "go", "test", "./...") - // } - } - } - wg.Wait() - - t.Run("GoStaticCheck", func(t *testing.T) { - checks := []string{ - "all", // start with all checks enabled - "-SA1019", // disable deprecated usage check - "-S*", // disable code simplification checks - "-ST*", // disable coding style checks - "-U*", // disable unused declaration checks - } - out := mustRunCommand(t, "staticcheck", "-checks="+strings.Join(checks, ","), "-fail=none", "./...") - - // Filter out findings from certain paths. 
- var findings []string - for _, finding := range strings.Split(strings.TrimSpace(out), "\n") { - switch { - case strings.HasPrefix(finding, "internal/testprotos/legacy/"): - default: - findings = append(findings, finding) - } - } - if len(findings) > 0 { - t.Fatalf("staticcheck findings:\n%v", strings.Join(findings, "\n")) - } - }) - // t.Run("CommittedGitChanges", func(t *testing.T) { - // if strings.TrimSpace(gitDiff) != "" { - // t.Fatalf("uncommitted changes") - // } - // }) - // t.Run("TrackedGitFiles", func(t *testing.T) { - // if strings.TrimSpace(gitUntracked) != "" { - // t.Fatalf("untracked files") - // } - // }) -} - -func mustInitDeps(t *testing.T) { - check := func(err error) { - t.Helper() - if err != nil { - t.Fatal(err) - } - } - - // Determine the directory to place the test directory. - repoRoot, err := os.Getwd() - check(err) - testDir := filepath.Join(repoRoot, ".cache") - check(os.MkdirAll(testDir, 0775)) - - // Delete the current directory if non-empty, - // which only occurs if a dependency failed to initialize properly. - var workingDir string - finishedDirs := map[string]bool{} - defer func() { - if workingDir != "" { - os.RemoveAll(workingDir) // best-effort - } - }() - startWork := func(name string) string { - workingDir = filepath.Join(testDir, name) - return workingDir - } - finishWork := func() { - finishedDirs[workingDir] = true - workingDir = "" - } - - // Delete other sub-directories that are no longer relevant. - defer func() { - now := time.Now() - fis, _ := os.ReadDir(testDir) - for _, fi := range fis { - dir := filepath.Join(testDir, fi.Name()) - if finishedDirs[dir] { - os.Chtimes(dir, now, now) // best-effort - continue - } - fii, err := fi.Info() - check(err) - if now.Sub(fii.ModTime()) < purgeTimeout { - continue - } - fmt.Printf("delete %v\n", fi.Name()) - os.RemoveAll(dir) // best-effort - } - }() - - // The bin directory contains symlinks to each tool by version. 
- // It is safe to delete this directory and run the test script from scratch. - binPath := startWork("bin") - check(os.RemoveAll(binPath)) - check(os.Mkdir(binPath, 0775)) - check(os.Setenv("PATH", binPath+":"+os.Getenv("PATH"))) - registerBinary := func(name, path string) { - check(os.Symlink(path, filepath.Join(binPath, name))) - } - finishWork() - - // Get the protobuf toolchain. - protobufPath = startWork("protobuf-" + protobufVersion) - if _, err := os.Stat(protobufPath); err != nil { - fmt.Printf("download %v\n", filepath.Base(protobufPath)) - checkoutVersion := protobufVersion - if isCommit := strings.Trim(protobufVersion, "0123456789abcdef") == ""; !isCommit { - // release tags have "v" prefix - checkoutVersion = "v" + protobufVersion - } - command{Dir: testDir}.mustRun(t, "git", "clone", "https://github.com/protocolbuffers/protobuf", "protobuf-"+protobufVersion) - command{Dir: protobufPath}.mustRun(t, "git", "checkout", checkoutVersion) - - if os.Getenv("GO_BUILDER_NAME") != "" { - // If this is running on the Go build infrastructure, - // use pre-built versions of these binaries that the - // builders are configured to provide in $PATH. - protocPath, err := exec.LookPath("protoc") - check(err) - confTestRunnerPath, err := exec.LookPath("conformance_test_runner") - check(err) - check(os.MkdirAll(filepath.Join(protobufPath, "bazel-bin", "conformance"), 0775)) - check(os.Symlink(protocPath, filepath.Join(protobufPath, "bazel-bin", "protoc"))) - check(os.Symlink(confTestRunnerPath, filepath.Join(protobufPath, "bazel-bin", "conformance", "conformance_test_runner"))) - } else { - // In other environments, download and build the protobuf toolchain. - // We avoid downloading the pre-compiled binaries since they do not contain - // the conformance test runner. 
- fmt.Printf("build %v\n", filepath.Base(protobufPath)) - env := os.Environ() - args := []string{ - "bazel", "build", - ":protoc", - "//conformance:conformance_test_runner", - } - if runtime.GOOS == "darwin" { - // Adding this environment variable appears to be necessary for macOS builds. - env = append(env, "CC=clang") - // And this flag. - args = append(args, - "--macos_minimum_os=13.0", - "--host_macos_minimum_os=13.0", - ) - } - command{ - Dir: protobufPath, - Env: env, - }.mustRun(t, args...) - } - } - check(os.Setenv("PROTOBUF_ROOT", protobufPath)) // for generate-protos - registerBinary("conform-test-runner", filepath.Join(protobufPath, "bazel-bin", "conformance", "conformance_test_runner")) - registerBinary("protoc", filepath.Join(protobufPath, "bazel-bin", "protoc")) - finishWork() - - // Download each Go toolchain version. - for _, v := range golangVersions { - goDir := startWork("go" + v) - if _, err := os.Stat(goDir); err != nil { - fmt.Printf("download %v\n", filepath.Base(goDir)) - url := fmt.Sprintf("https://dl.google.com/go/go%v.%v-%v.tar.gz", v, runtime.GOOS, runtime.GOARCH) - downloadArchive(check, goDir, url, "go", "") // skip SHA256 check as we fetch over https from a trusted domain - } - registerBinary("go"+v, filepath.Join(goDir, "bin", "go")) - finishWork() - } - registerBinary("go", filepath.Join(testDir, "go"+golangLatest, "bin", "go")) - registerBinary("gofmt", filepath.Join(testDir, "go"+golangLatest, "bin", "gofmt")) - - // Download the staticcheck tool. 
- checkDir := startWork("staticcheck-" + staticcheckVersion) - if _, err := os.Stat(checkDir); err != nil { - fmt.Printf("download %v\n", filepath.Base(checkDir)) - url := fmt.Sprintf("https://github.com/dominikh/go-tools/releases/download/%v/staticcheck_%v_%v.tar.gz", staticcheckVersion, runtime.GOOS, runtime.GOARCH) - downloadArchive(check, checkDir, url, "staticcheck", staticcheckSHA256s[runtime.GOOS+"/"+runtime.GOARCH]) - } - registerBinary("staticcheck", filepath.Join(checkDir, "staticcheck")) - finishWork() - - // GitHub actions sets GOROOT, which confuses invocations of the Go toolchain. - // Explicitly clear GOROOT, so each toolchain uses their default GOROOT. - check(os.Unsetenv("GOROOT")) - - // Set a cache directory outside the test directory. - check(os.Setenv("GOCACHE", filepath.Join(repoRoot, ".gocache"))) -} - -func downloadFile(check func(error), dstPath, srcURL string, perm fs.FileMode) { - resp, err := http.Get(srcURL) - check(err) - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) - check(fmt.Errorf("GET %q: non-200 OK status code: %v body: %q", srcURL, resp.Status, body)) - } - - check(os.MkdirAll(filepath.Dir(dstPath), 0775)) - f, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) - check(err) - - _, err = io.Copy(f, resp.Body) - check(err) - - check(f.Close()) -} - -func downloadArchive(check func(error), dstPath, srcURL, skipPrefix, wantSHA256 string) { - check(os.RemoveAll(dstPath)) - - resp, err := http.Get(srcURL) - check(err) - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) - check(fmt.Errorf("GET %q: non-200 OK status code: %v body: %q", srcURL, resp.Status, body)) - } - - var r io.Reader = resp.Body - if wantSHA256 != "" { - b, err := io.ReadAll(resp.Body) - check(err) - r = bytes.NewReader(b) - - if gotSHA256 := fmt.Sprintf("%x", sha256.Sum256(b)); gotSHA256 != wantSHA256 
{ - check(fmt.Errorf("checksum validation error:\ngot %v\nwant %v", gotSHA256, wantSHA256)) - } - } - - zr, err := gzip.NewReader(r) - check(err) - - tr := tar.NewReader(zr) - for { - h, err := tr.Next() - if err == io.EOF { - return - } - check(err) - - // Skip directories or files outside the prefix directory. - if len(skipPrefix) > 0 { - if !strings.HasPrefix(h.Name, skipPrefix) { - continue - } - if len(h.Name) > len(skipPrefix) && h.Name[len(skipPrefix)] != '/' { - continue - } - } - - path := strings.TrimPrefix(strings.TrimPrefix(h.Name, skipPrefix), "/") - path = filepath.Join(dstPath, filepath.FromSlash(path)) - mode := os.FileMode(h.Mode & 0777) - switch h.Typeflag { - case tar.TypeReg: - b, err := io.ReadAll(tr) - check(err) - check(os.WriteFile(path, b, mode)) - case tar.TypeDir: - check(os.Mkdir(path, mode)) - } - } -} - -func mustHandleFlags(t *testing.T) { - if *regenerate { - t.Run("Generate", func(t *testing.T) { - fmt.Print(mustRunCommand(t, "go", "generate", "./internal/cmd/generate-types")) - fmt.Print(mustRunCommand(t, "go", "generate", "./internal/cmd/generate-protos")) - files := strings.Split(strings.TrimSpace(mustRunCommand(t, "git", "ls-files", "*.go")), "\n") - mustRunCommand(t, append([]string{"gofmt", "-w"}, files...)...) - }) - } - if *buildRelease { - t.Run("BuildRelease", func(t *testing.T) { - v := version.String() - for _, goos := range []string{"linux", "darwin", "windows"} { - for _, goarch := range []string{"386", "amd64", "arm64"} { - // Avoid Darwin since 10.15 dropped support for i386. - if goos == "darwin" && goarch == "386" { - continue - } - - binPath := filepath.Join("bin", fmt.Sprintf("protoc-gen-go.%v.%v.%v", v, goos, goarch)) - - // Build the binary. - cmd := command{Env: append(os.Environ(), "GOOS="+goos, "GOARCH="+goarch)} - cmd.mustRun(t, "go", "build", "-trimpath", "-ldflags", "-s -w -buildid=", "-o", binPath, "./cmd/protoc-gen-go") - - // Archive and compress the binary. 
- in, err := os.ReadFile(binPath) - if err != nil { - t.Fatal(err) - } - out := new(bytes.Buffer) - suffix := "" - comment := fmt.Sprintf("protoc-gen-go VERSION=%v GOOS=%v GOARCH=%v", v, goos, goarch) - switch goos { - case "windows": - suffix = ".zip" - zw := zip.NewWriter(out) - zw.SetComment(comment) - fw, _ := zw.Create("protoc-gen-go.exe") - fw.Write(in) - zw.Close() - default: - suffix = ".tar.gz" - gz, _ := gzip.NewWriterLevel(out, gzip.BestCompression) - gz.Comment = comment - tw := tar.NewWriter(gz) - tw.WriteHeader(&tar.Header{ - Name: "protoc-gen-go", - Mode: int64(0775), - Size: int64(len(in)), - }) - tw.Write(in) - tw.Close() - gz.Close() - } - if err := os.WriteFile(binPath+suffix, out.Bytes(), 0664); err != nil { - t.Fatal(err) - } - } - } - }) - } - if *regenerate || *buildRelease { - t.SkipNow() - } -} - -var copyrightRegex = []*regexp.Regexp{ - regexp.MustCompile(`^// Copyright \d\d\d\d The Go Authors\. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file\. -`), - // Generated .pb.go files from main protobuf repo. - regexp.MustCompile(`^// Protocol Buffers - Google's data interchange format -// Copyright \d\d\d\d Google Inc\. All rights reserved\. -`), -} - -func mustHaveCopyrightHeader(t *testing.T, files []string) { - var bad []string -File: - for _, file := range files { - if strings.HasSuffix(file, "internal/testprotos/conformance/editions/test_messages_edition2023.pb.go") { - // TODO(lassefolger) the underlying proto file is checked into - // the protobuf repo without a copyright header. Fix is pending but - // might require a release. 
- continue - } - b, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - for _, re := range copyrightRegex { - if loc := re.FindIndex(b); loc != nil && loc[0] == 0 { - continue File - } - } - bad = append(bad, file) - } - if len(bad) > 0 { - t.Fatalf("files with missing/bad copyright headers:\n %v", strings.Join(bad, "\n ")) - } -} - -// Add in command struct: -type command struct { - Dir string - Env []string -} - -func (c command) mustRun(t *testing.T, args ...string) string { - t.Helper() - stdout := new(bytes.Buffer) - stderr := new(bytes.Buffer) - - var cmdArgs []string - if len(args) > 1 && strings.HasPrefix(args[0], "go") && args[1] == "test" { - for i, arg := range args { - cmdArgs = append(cmdArgs, arg) - if i == 1 { // right after "test" - cmdArgs = append(cmdArgs, - "-ldflags=-w=false -s=false", - "-gcflags=all=-l", - "-count=1", - "-timeout=30m", - ) - } - } - } else { - cmdArgs = args - } - - cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) - cmd.Dir = "." - if c.Dir != "" { - cmd.Dir = c.Dir - } - cmd.Env = os.Environ() - if c.Env != nil { - cmd.Env = c.Env - } - cmd.Env = append(cmd.Env, - fmt.Sprintf("PWD=%s", cmd.Dir), - fmt.Sprintf("DD_SERVICE=go-di-exploration-test-%d", rand.Int()), - ) - cmd.Stdout = stdout - cmd.Stderr = stderr - - if err := cmd.Run(); err != nil { - t.Fatalf("executing (%v): %v\n%s%s", strings.Join(args, " "), err, stdout.String(), stderr.String()) - } - - return stdout.String() -} - -func mustRunCommand(t *testing.T, args ...string) string { - t.Helper() - return command{}.mustRun(t, args...) -} - -// race is an approximation of whether the race detector is on. -// It's used to skip the integration test on builders, without -// preventing the integration test from running under the race -// detector as a '//go:build !race' build constraint would. 
-func race() bool { - bi, ok := debug.ReadBuildInfo() - if !ok { - return false - } - for _, setting := range bi.Settings { - if setting.Key == "-race" { - return setting.Value == "true" - } - } - return false -} diff --git a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash deleted file mode 100644 index aec89522a116a5..00000000000000 --- a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/test.bash +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -# Copyright 2018 The Go Authors. All rights reserved. -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file. - -go test google.golang.org/protobuf -run='^TestIntegration$' -v -timeout=60m -count=1 -failfast "$@" -exit $? \ No newline at end of file diff --git a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go b/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go deleted file mode 100644 index 8f9af17e604744..00000000000000 --- a/pkg/dynamicinstrumentation/testutil/exploration_tests/patches/protobuf/testing/prototest/message.go +++ /dev/null @@ -1,911 +0,0 @@ -// Copyright 2019 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package prototest exercises protobuf reflection. -package prototest - -import ( - "bytes" - "fmt" - "math" - "reflect" - "sort" - "strings" - "testing" - - "google.golang.org/protobuf/encoding/prototext" - "google.golang.org/protobuf/encoding/protowire" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/reflect/protoregistry" -) - -// TODO: Test invalid field descriptors or oneof descriptors. 
-// TODO: This should test the functionality that can be provided by fast-paths. - -// Message tests a message implementation. -type Message struct { - // Resolver is used to determine the list of extension fields to test with. - // If nil, this defaults to using protoregistry.GlobalTypes. - Resolver interface { - FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) - FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) - RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) - } - - // UnmarshalOptions are respected for every Unmarshal call this package - // does. The Resolver and AllowPartial fields are overridden. - UnmarshalOptions proto.UnmarshalOptions -} - -//nolint:all -//go:noinline -func blabla_blabla(x bool) {} - -// Test performs tests on a [protoreflect.MessageType] implementation. -func (test Message) Test(t testing.TB, mt protoreflect.MessageType) { - testType(t, mt) - - // for { - // blabla_blabla(true) - // time.Sleep(1 * time.Second) - // } - - md := mt.Descriptor() - m1 := mt.New() - for i := 0; i < md.Fields().Len(); i++ { - fd := md.Fields().Get(i) - testField(t, m1, fd) - } - if test.Resolver == nil { - test.Resolver = protoregistry.GlobalTypes - } - var extTypes []protoreflect.ExtensionType - test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e protoreflect.ExtensionType) bool { - extTypes = append(extTypes, e) - return true - }) - for _, xt := range extTypes { - testField(t, m1, xt.TypeDescriptor()) - } - for i := 0; i < md.Oneofs().Len(); i++ { - testOneof(t, m1, md.Oneofs().Get(i)) - } - testUnknown(t, m1) - - // Test round-trip marshal/unmarshal. 
- m2 := mt.New().Interface() - populateMessage(m2.ProtoReflect(), 1, nil) - for _, xt := range extTypes { - m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil)) - } - b, err := proto.MarshalOptions{ - AllowPartial: true, - }.Marshal(m2) - if err != nil { - t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2)) - } - m3 := mt.New().Interface() - unmarshalOpts := test.UnmarshalOptions - unmarshalOpts.AllowPartial = true - unmarshalOpts.Resolver = test.Resolver - if err := unmarshalOpts.Unmarshal(b, m3); err != nil { - t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2)) - } - if !proto.Equal(m2, m3) { - t.Errorf("round-trip marshal/unmarshal did not preserve message\nOriginal:\n%v\nNew:\n%v", prototext.Format(m2), prototext.Format(m3)) - } -} - -func testType(t testing.TB, mt protoreflect.MessageType) { - m := mt.New().Interface() - want := reflect.TypeOf(m) - if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want { - t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want) - } - if got := reflect.TypeOf(m.ProtoReflect().New().Interface()); got != want { - t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().New().Interface()): %v != %v", got, want) - } - if got := reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()); got != want { - t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()): %v != %v", got, want) - } - if mt, ok := mt.(protoreflect.MessageFieldTypes); ok { - testFieldTypes(t, mt) - } -} - -func testFieldTypes(t testing.TB, mt protoreflect.MessageFieldTypes) { - descName := func(d protoreflect.Descriptor) protoreflect.FullName { - if d == nil { - return "" - } - return d.FullName() - } - typeName := func(mt protoreflect.MessageType) protoreflect.FullName { - if mt == nil { - return "" - } - return mt.Descriptor().FullName() - } - 
adjustExpr := func(idx int, expr string) string { - expr = strings.Replace(expr, "fd.", "md.Fields().Get(i).", -1) - expr = strings.Replace(expr, "(fd)", "(md.Fields().Get(i))", -1) - expr = strings.Replace(expr, "mti.", "mt.Message(i).", -1) - expr = strings.Replace(expr, "(i)", fmt.Sprintf("(%d)", idx), -1) - return expr - } - checkEnumDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.EnumDescriptor) { - if got != want { - t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want)) - } - } - checkMessageDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageDescriptor) { - if got != want { - t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want)) - } - } - checkMessageType := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageType) { - if got != want { - t.Errorf("type mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), typeName(got), typeName(want)) - } - } - - fds := mt.Descriptor().Fields() - m := mt.New() - for i := 0; i < fds.Len(); i++ { - fd := fds.Get(i) - switch { - case fd.IsList(): - if fd.Enum() != nil { - checkEnumDesc(i, - "mt.Enum(i).Descriptor()", "fd.Enum()", - mt.Enum(i).Descriptor(), fd.Enum()) - } - if fd.Message() != nil { - checkMessageDesc(i, - "mt.Message(i).Descriptor()", "fd.Message()", - mt.Message(i).Descriptor(), fd.Message()) - checkMessageType(i, - "mt.Message(i)", "m.NewField(fd).List().NewElement().Message().Type()", - mt.Message(i), m.NewField(fd).List().NewElement().Message().Type()) - } - case fd.IsMap(): - mti := mt.Message(i) - if m := mti.New(); m != nil { - checkMessageDesc(i, - "m.Descriptor()", "fd.Message()", - m.Descriptor(), fd.Message()) - } - if m := mti.Zero(); m != nil { - checkMessageDesc(i, - "m.Descriptor()", "fd.Message()", - m.Descriptor(), fd.Message()) - } - 
checkMessageDesc(i, - "mti.Descriptor()", "fd.Message()", - mti.Descriptor(), fd.Message()) - if mti := mti.(protoreflect.MessageFieldTypes); mti != nil { - if fd.MapValue().Enum() != nil { - checkEnumDesc(i, - "mti.Enum(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Enum()", - mti.Enum(fd.MapValue().Index()).Descriptor(), fd.MapValue().Enum()) - } - if fd.MapValue().Message() != nil { - checkMessageDesc(i, - "mti.Message(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Message()", - mti.Message(fd.MapValue().Index()).Descriptor(), fd.MapValue().Message()) - checkMessageType(i, - "mti.Message(fd.MapValue().Index())", "m.NewField(fd).Map().NewValue().Message().Type()", - mti.Message(fd.MapValue().Index()), m.NewField(fd).Map().NewValue().Message().Type()) - } - } - default: - if fd.Enum() != nil { - checkEnumDesc(i, - "mt.Enum(i).Descriptor()", "fd.Enum()", - mt.Enum(i).Descriptor(), fd.Enum()) - } - if fd.Message() != nil { - checkMessageDesc(i, - "mt.Message(i).Descriptor()", "fd.Message()", - mt.Message(i).Descriptor(), fd.Message()) - checkMessageType(i, - "mt.Message(i)", "m.NewField(fd).Message().Type()", - mt.Message(i), m.NewField(fd).Message().Type()) - } - } - } -} - -// testField exercises set/get/has/clear of a field. -func testField(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { - name := fd.FullName() - num := fd.Number() - - switch { - case fd.IsList(): - testFieldList(t, m, fd) - case fd.IsMap(): - testFieldMap(t, m, fd) - case fd.Message() != nil: - default: - if got, want := m.NewField(fd), fd.Default(); !valueEqual(got, want) { - t.Errorf("Message.NewField(%v) = %v, want default value %v", name, formatValue(got), formatValue(want)) - } - if fd.Kind() == protoreflect.FloatKind || fd.Kind() == protoreflect.DoubleKind { - testFieldFloat(t, m, fd) - } - } - - // Set to a non-zero value, the zero value, different non-zero values. 
- for _, n := range []seed{1, 0, minVal, maxVal} { - v := newValue(m, fd, n, nil) - m.Set(fd, v) - wantHas := true - if n == 0 { - if !fd.HasPresence() { - wantHas = false - } - if fd.IsExtension() { - wantHas = true - } - if fd.Cardinality() == protoreflect.Repeated { - wantHas = false - } - if fd.ContainingOneof() != nil { - wantHas = true - } - } - if !fd.HasPresence() && fd.Cardinality() != protoreflect.Repeated && fd.ContainingOneof() == nil && fd.Kind() == protoreflect.EnumKind && v.Enum() == 0 { - wantHas = false - } - if got, want := m.Has(fd), wantHas; got != want { - t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want) - } - if got, want := m.Get(fd), v; !valueEqual(got, want) { - t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want)) - } - found := false - m.Range(func(d protoreflect.FieldDescriptor, got protoreflect.Value) bool { - if fd != d { - return true - } - found = true - if want := v; !valueEqual(got, want) { - t.Errorf("after setting %q:\nMessage.Range got value %v, want %v", name, formatValue(got), formatValue(want)) - } - return true - }) - if got, want := wantHas, found; got != want { - t.Errorf("after setting %q:\nMessageRange saw field: %v, want %v", name, got, want) - } - } - - m.Clear(fd) - if got, want := m.Has(fd), false; got != want { - t.Errorf("after clearing %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want) - } - switch { - case fd.IsList(): - if got := m.Get(fd); got.List().Len() != 0 { - t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got)) - } - case fd.IsMap(): - if got := m.Get(fd); got.Map().Len() != 0 { - t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got)) - } - case fd.Message() == nil: - if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) { - t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default 
%v", name, num, formatValue(got), formatValue(want)) - } - } - - // Set to the default value. - switch { - case fd.IsList() || fd.IsMap(): - m.Set(fd, m.Mutable(fd)) - if got, want := m.Has(fd), (fd.IsExtension() && fd.Cardinality() != protoreflect.Repeated) || fd.ContainingOneof() != nil; got != want { - t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want) - } - case fd.Message() == nil: - m.Set(fd, m.Get(fd)) - if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) { - t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want)) - } - } - m.Clear(fd) - - // Set to the wrong type. - v := protoreflect.ValueOfString("") - if fd.Kind() == protoreflect.StringKind { - v = protoreflect.ValueOfInt32(0) - } - if !panics(func() { - m.Set(fd, v) - }) { - t.Errorf("setting %v to %T succeeds, want panic", name, v.Interface()) - } -} - -// testFieldMap tests set/get/has/clear of entries in a map field. -func testFieldMap(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { - name := fd.FullName() - num := fd.Number() - - // New values. - m.Clear(fd) // start with an empty map - mapv := m.Get(fd).Map() - if mapv.IsValid() { - t.Errorf("after clearing field: message.Get(%v).IsValid() = true, want false", name) - } - if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) { - t.Errorf("message.Get(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want)) - } - if !panics(func() { - m.Set(fd, protoreflect.ValueOfMap(mapv)) - }) { - t.Errorf("message.Set(%v, ) does not panic", name) - } - if !panics(func() { - mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, nil)) - }) { - t.Errorf("message.Get(%v).Set(...) 
of invalid map does not panic", name) - } - mapv = m.Mutable(fd).Map() // mutable map - if !mapv.IsValid() { - t.Errorf("message.Mutable(%v).IsValid() = false, want true", name) - } - if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) { - t.Errorf("message.Mutable(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want)) - } - - // Add values. - want := make(testMap) - for i, n := range []seed{1, 0, minVal, maxVal} { - if got, want := m.Has(fd), i > 0; got != want { - t.Errorf("after inserting %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want) - } - - k := newMapKey(fd, n) - v := newMapValue(fd, mapv, n, nil) - mapv.Set(k, v) - want.Set(k, v) - if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) { - t.Errorf("after inserting %d elements to %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want)) - } - } - - // Set values. - want.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { - nv := newMapValue(fd, mapv, 10, nil) - mapv.Set(k, nv) - want.Set(k, nv) - if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) { - t.Errorf("after setting element %v of %q:\nMessage.Get(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want)) - } - return true - }) - - // Clear values. 
- want.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { - mapv.Clear(k) - want.Clear(k) - if got, want := m.Has(fd), want.Len() > 0; got != want { - t.Errorf("after clearing elements of %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want) - } - if got, want := m.Get(fd), protoreflect.ValueOfMap(want); !valueEqual(got, want) { - t.Errorf("after clearing elements of %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want)) - } - return true - }) - if mapv := m.Get(fd).Map(); mapv.IsValid() { - t.Errorf("after clearing all elements: message.Get(%v).IsValid() = true, want false %v", name, formatValue(protoreflect.ValueOfMap(mapv))) - } - - // Non-existent map keys. - missingKey := newMapKey(fd, 1) - if got, want := mapv.Has(missingKey), false; got != want { - t.Errorf("non-existent map key in %q: Map.Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want) - } - if got, want := mapv.Get(missingKey).IsValid(), false; got != want { - t.Errorf("non-existent map key in %q: Map.Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want) - } - mapv.Clear(missingKey) // noop - - // Mutable. 
- if fd.MapValue().Message() == nil { - if !panics(func() { - mapv.Mutable(newMapKey(fd, 1)) - }) { - t.Errorf("Mutable on %q succeeds, want panic", name) - } - } else { - k := newMapKey(fd, 1) - v := mapv.Mutable(k) - if got, want := mapv.Len(), 1; got != want { - t.Errorf("after Mutable on %q, Map.Len() = %v, want %v", name, got, want) - } - populateMessage(v.Message(), 1, nil) - if !valueEqual(mapv.Get(k), v) { - t.Errorf("after Mutable on %q, changing new mutable value does not change map entry", name) - } - mapv.Clear(k) - } -} - -type testMap map[any]protoreflect.Value - -func (m testMap) Get(k protoreflect.MapKey) protoreflect.Value { return m[k.Interface()] } -func (m testMap) Set(k protoreflect.MapKey, v protoreflect.Value) { m[k.Interface()] = v } -func (m testMap) Has(k protoreflect.MapKey) bool { return m.Get(k).IsValid() } -func (m testMap) Clear(k protoreflect.MapKey) { delete(m, k.Interface()) } -func (m testMap) Mutable(k protoreflect.MapKey) protoreflect.Value { panic("unimplemented") } -func (m testMap) Len() int { return len(m) } -func (m testMap) NewValue() protoreflect.Value { panic("unimplemented") } -func (m testMap) Range(f func(protoreflect.MapKey, protoreflect.Value) bool) { - for k, v := range m { - if !f(protoreflect.ValueOf(k).MapKey(), v) { - return - } - } -} -func (m testMap) IsValid() bool { return true } - -// testFieldList exercises set/get/append/truncate of values in a list. -func testFieldList(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { - name := fd.FullName() - num := fd.Number() - - m.Clear(fd) // start with an empty list - list := m.Get(fd).List() - if list.IsValid() { - t.Errorf("message.Get(%v).IsValid() = true, want false", name) - } - if !panics(func() { - m.Set(fd, protoreflect.ValueOfList(list)) - }) { - t.Errorf("message.Set(%v, ) does not panic", name) - } - if !panics(func() { - list.Append(newListElement(fd, list, 0, nil)) - }) { - t.Errorf("message.Get(%v).Append(...) 
of invalid list does not panic", name) - } - if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) { - t.Errorf("message.Get(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want)) - } - list = m.Mutable(fd).List() // mutable list - if !list.IsValid() { - t.Errorf("message.Get(%v).IsValid() = false, want true", name) - } - if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) { - t.Errorf("message.Mutable(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want)) - } - - // Append values. - var want protoreflect.List = &testList{} - for i, n := range []seed{1, 0, minVal, maxVal} { - if got, want := m.Has(fd), i > 0; got != want { - t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want) - } - v := newListElement(fd, list, n, nil) - want.Append(v) - list.Append(v) - - if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) { - t.Errorf("after appending %d elements to %q:\nMessage.Get(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want)) - } - } - - // Set values. - for i := 0; i < want.Len(); i++ { - v := newListElement(fd, list, seed(i+10), nil) - want.Set(i, v) - list.Set(i, v) - if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) { - t.Errorf("after setting element %d of %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want)) - } - } - - // Truncate. 
- for want.Len() > 0 { - n := want.Len() - 1 - want.Truncate(n) - list.Truncate(n) - if got, want := m.Has(fd), want.Len() > 0; got != want { - t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want) - } - if got, want := m.Get(fd), protoreflect.ValueOfList(want); !valueEqual(got, want) { - t.Errorf("after truncating %q to %d:\nMessage.Get(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want)) - } - } - - // AppendMutable. - if fd.Message() == nil { - if !panics(func() { - list.AppendMutable() - }) { - t.Errorf("AppendMutable on %q succeeds, want panic", name) - } - } else { - v := list.AppendMutable() - if got, want := list.Len(), 1; got != want { - t.Errorf("after AppendMutable on %q, list.Len() = %v, want %v", name, got, want) - } - populateMessage(v.Message(), 1, nil) - if !valueEqual(list.Get(0), v) { - t.Errorf("after AppendMutable on %q, changing new mutable value does not change list item 0", name) - } - want.Truncate(0) - } -} - -type testList struct { - a []protoreflect.Value -} - -func (l *testList) Append(v protoreflect.Value) { l.a = append(l.a, v) } -func (l *testList) AppendMutable() protoreflect.Value { panic("unimplemented") } -func (l *testList) Get(n int) protoreflect.Value { return l.a[n] } -func (l *testList) Len() int { return len(l.a) } -func (l *testList) Set(n int, v protoreflect.Value) { l.a[n] = v } -func (l *testList) Truncate(n int) { l.a = l.a[:n] } -func (l *testList) NewElement() protoreflect.Value { panic("unimplemented") } -func (l *testList) IsValid() bool { return true } - -// testFieldFloat exercises some interesting floating-point scalar field values. 
-func testFieldFloat(t testing.TB, m protoreflect.Message, fd protoreflect.FieldDescriptor) { - name := fd.FullName() - num := fd.Number() - - for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} { - var val protoreflect.Value - if fd.Kind() == protoreflect.FloatKind { - val = protoreflect.ValueOfFloat32(float32(v)) - } else { - val = protoreflect.ValueOfFloat64(float64(v)) - } - m.Set(fd, val) - // Note that Has is true for -0. - if got, want := m.Has(fd), true; got != want { - t.Errorf("after setting %v to %v: Message.Has(%v) = %v, want %v", name, v, num, got, want) - } - if got, want := m.Get(fd), val; !valueEqual(got, want) { - t.Errorf("after setting %v: Message.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want)) - } - } -} - -// testOneof tests the behavior of fields in a oneof. -func testOneof(t testing.TB, m protoreflect.Message, od protoreflect.OneofDescriptor) { - for _, mutable := range []bool{false, true} { - for i := 0; i < od.Fields().Len(); i++ { - fda := od.Fields().Get(i) - if mutable { - // Set fields by requesting a mutable reference. - if !fda.IsMap() && !fda.IsList() && fda.Message() == nil { - continue - } - _ = m.Mutable(fda) - } else { - // Set fields explicitly. - m.Set(fda, newValue(m, fda, 1, nil)) - } - if !od.IsSynthetic() { - // Synthetic oneofs are used to represent optional fields in - // proto3. While they show up in protoreflect, WhichOneof does - // not work on these (only on non-synthetic, explicit oneofs). 
- if got, want := m.WhichOneof(od), fda; got != want { - t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want) - } - } - for j := 0; j < od.Fields().Len(); j++ { - fdb := od.Fields().Get(j) - if got, want := m.Has(fdb), i == j; got != want { - t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want) - } - } - } - } -} - -// testUnknown tests the behavior of unknown fields. -func testUnknown(t testing.TB, m protoreflect.Message) { - var b []byte - b = protowire.AppendTag(b, 1000, protowire.VarintType) - b = protowire.AppendVarint(b, 1001) - m.SetUnknown(protoreflect.RawFields(b)) - if got, want := []byte(m.GetUnknown()), b; !bytes.Equal(got, want) { - t.Errorf("after setting unknown fields:\nGetUnknown() = %v, want %v", got, want) - } -} - -func formatValue(v protoreflect.Value) string { - switch v := v.Interface().(type) { - case protoreflect.List: - var buf bytes.Buffer - buf.WriteString("list[") - for i := 0; i < v.Len(); i++ { - if i > 0 { - buf.WriteString(" ") - } - buf.WriteString(formatValue(v.Get(i))) - } - buf.WriteString("]") - return buf.String() - case protoreflect.Map: - var buf bytes.Buffer - buf.WriteString("map[") - var keys []protoreflect.MapKey - v.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { - keys = append(keys, k) - return true - }) - sort.Slice(keys, func(i, j int) bool { - return keys[i].String() < keys[j].String() - }) - for i, k := range keys { - if i > 0 { - buf.WriteString(" ") - } - buf.WriteString(formatValue(k.Value())) - buf.WriteString(":") - buf.WriteString(formatValue(v.Get(k))) - } - buf.WriteString("]") - return buf.String() - case protoreflect.Message: - b, err := prototext.Marshal(v.Interface()) - if err != nil { - return fmt.Sprintf("<%v>", err) - } - return fmt.Sprintf("%v{%s}", v.Descriptor().FullName(), b) - case string: - return fmt.Sprintf("%q", v) - default: - return fmt.Sprint(v) - } -} - -func 
valueEqual(a, b protoreflect.Value) bool { - ai, bi := a.Interface(), b.Interface() - switch ai.(type) { - case protoreflect.Message: - return proto.Equal( - a.Message().Interface(), - b.Message().Interface(), - ) - case protoreflect.List: - lista, listb := a.List(), b.List() - if lista.Len() != listb.Len() { - return false - } - for i := 0; i < lista.Len(); i++ { - if !valueEqual(lista.Get(i), listb.Get(i)) { - return false - } - } - return true - case protoreflect.Map: - mapa, mapb := a.Map(), b.Map() - if mapa.Len() != mapb.Len() { - return false - } - equal := true - mapa.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { - if !valueEqual(v, mapb.Get(k)) { - equal = false - return false - } - return true - }) - return equal - case []byte: - return bytes.Equal(a.Bytes(), b.Bytes()) - case float32: - // NaNs are equal, but must be the same NaN. - return math.Float32bits(ai.(float32)) == math.Float32bits(bi.(float32)) - case float64: - // NaNs are equal, but must be the same NaN. - return math.Float64bits(ai.(float64)) == math.Float64bits(bi.(float64)) - default: - return ai == bi - } -} - -// A seed is used to vary the content of a value. -// -// A seed of 0 is the zero value. Messages do not have a zero-value; a 0-seeded messages -// is unpopulated. -// -// A seed of minVal or maxVal is the least or greatest value of the value type. -type seed int - -const ( - minVal seed = -1 - maxVal seed = -2 -) - -// newSeed creates new seed values from a base, for example to create seeds for the -// elements in a list. If the input seed is minVal or maxVal, so is the output. -func newSeed(n seed, adjust ...int) seed { - switch n { - case minVal, maxVal: - return n - } - for _, a := range adjust { - n = 10*n + seed(a) - } - return n -} - -// newValue returns a new value assignable to a field. -// -// The stack parameter is used to avoid infinite recursion when populating circular -// data structures. 
-func newValue(m protoreflect.Message, fd protoreflect.FieldDescriptor, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { - switch { - case fd.IsList(): - if n == 0 { - return m.New().Mutable(fd) - } - list := m.NewField(fd).List() - list.Append(newListElement(fd, list, 0, stack)) - list.Append(newListElement(fd, list, minVal, stack)) - list.Append(newListElement(fd, list, maxVal, stack)) - list.Append(newListElement(fd, list, n, stack)) - return protoreflect.ValueOfList(list) - case fd.IsMap(): - if n == 0 { - return m.New().Mutable(fd) - } - mapv := m.NewField(fd).Map() - mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack)) - mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack)) - mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack)) - mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack)) - return protoreflect.ValueOfMap(mapv) - case fd.Message() != nil: - return populateMessage(m.NewField(fd).Message(), n, stack) - default: - return newScalarValue(fd, n) - } -} - -func newListElement(fd protoreflect.FieldDescriptor, list protoreflect.List, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { - if fd.Message() == nil { - return newScalarValue(fd, n) - } - return populateMessage(list.NewElement().Message(), n, stack) -} - -func newMapKey(fd protoreflect.FieldDescriptor, n seed) protoreflect.MapKey { - kd := fd.MapKey() - return newScalarValue(kd, n).MapKey() -} - -func newMapValue(fd protoreflect.FieldDescriptor, mapv protoreflect.Map, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { - vd := fd.MapValue() - if vd.Message() == nil { - return newScalarValue(vd, n) - } - return populateMessage(mapv.NewValue().Message(), n, stack) -} - -func newScalarValue(fd protoreflect.FieldDescriptor, n seed) protoreflect.Value { - switch fd.Kind() { - case protoreflect.BoolKind: - return protoreflect.ValueOfBool(n != 0) - case protoreflect.EnumKind: - vals := 
fd.Enum().Values() - var i int - switch n { - case minVal: - i = 0 - case maxVal: - i = vals.Len() - 1 - default: - i = int(n) % vals.Len() - } - return protoreflect.ValueOfEnum(vals.Get(i).Number()) - case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: - switch n { - case minVal: - return protoreflect.ValueOfInt32(math.MinInt32) - case maxVal: - return protoreflect.ValueOfInt32(math.MaxInt32) - default: - return protoreflect.ValueOfInt32(int32(n)) - } - case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: - switch n { - case minVal: - // Only use 0 for the zero value. - return protoreflect.ValueOfUint32(1) - case maxVal: - return protoreflect.ValueOfUint32(math.MaxInt32) - default: - return protoreflect.ValueOfUint32(uint32(n)) - } - case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: - switch n { - case minVal: - return protoreflect.ValueOfInt64(math.MinInt64) - case maxVal: - return protoreflect.ValueOfInt64(math.MaxInt64) - default: - return protoreflect.ValueOfInt64(int64(n)) - } - case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: - switch n { - case minVal: - // Only use 0 for the zero value. 
- return protoreflect.ValueOfUint64(1) - case maxVal: - return protoreflect.ValueOfUint64(math.MaxInt64) - default: - return protoreflect.ValueOfUint64(uint64(n)) - } - case protoreflect.FloatKind: - switch n { - case minVal: - return protoreflect.ValueOfFloat32(math.SmallestNonzeroFloat32) - case maxVal: - return protoreflect.ValueOfFloat32(math.MaxFloat32) - default: - return protoreflect.ValueOfFloat32(1.5 * float32(n)) - } - case protoreflect.DoubleKind: - switch n { - case minVal: - return protoreflect.ValueOfFloat64(math.SmallestNonzeroFloat64) - case maxVal: - return protoreflect.ValueOfFloat64(math.MaxFloat64) - default: - return protoreflect.ValueOfFloat64(1.5 * float64(n)) - } - case protoreflect.StringKind: - if n == 0 { - return protoreflect.ValueOfString("") - } - return protoreflect.ValueOfString(fmt.Sprintf("%d", n)) - case protoreflect.BytesKind: - if n == 0 { - return protoreflect.ValueOfBytes(nil) - } - return protoreflect.ValueOfBytes([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}) - } - panic("unhandled kind") -} - -func populateMessage(m protoreflect.Message, n seed, stack []protoreflect.MessageDescriptor) protoreflect.Value { - if n == 0 { - return protoreflect.ValueOfMessage(m) - } - md := m.Descriptor() - for _, x := range stack { - if md == x { - return protoreflect.ValueOfMessage(m) - } - } - stack = append(stack, md) - for i := 0; i < md.Fields().Len(); i++ { - fd := md.Fields().Get(i) - if fd.IsWeak() { - continue - } - m.Set(fd, newValue(m, fd, newSeed(n, i), stack)) - } - return protoreflect.ValueOfMessage(m) -} - -func panics(f func()) (didPanic bool) { - defer func() { - if err := recover(); err != nil { - didPanic = true - } - }() - f() - return false -} From b7e112d207c88c7c62abf44d7489cfc6526351e7 Mon Sep 17 00:00:00 2001 From: Matan Green Date: Mon, 3 Mar 2025 22:34:33 +0200 Subject: [PATCH 5/6] Addressed PR comment: moved a call to AnalyzeBinary closer to it's callsite for clarity --- .../diconfig/config_manager.go 
| 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pkg/dynamicinstrumentation/diconfig/config_manager.go b/pkg/dynamicinstrumentation/diconfig/config_manager.go index a078f45d852c1c..c4dbf6e056ac33 100644 --- a/pkg/dynamicinstrumentation/diconfig/config_manager.go +++ b/pkg/dynamicinstrumentation/diconfig/config_manager.go @@ -236,6 +236,12 @@ func (cm *RCConfigManager) readConfigs(r *ringbuf.Reader, procInfo *ditypes.Proc // Check hash to see if the configuration changed if configPath.Hash != probe.InstrumentationInfo.ConfigurationHash { + err := AnalyzeBinary(procInfo) + if err != nil { + log.Errorf("couldn't inspect binary: %v\n", err) + continue + } + probe.InstrumentationInfo.ConfigurationHash = configPath.Hash applyConfigUpdate(procInfo, probe) } @@ -245,14 +251,6 @@ func (cm *RCConfigManager) readConfigs(r *ringbuf.Reader, procInfo *ditypes.Proc func applyConfigUpdate(procInfo *ditypes.ProcessInfo, probe *ditypes.Probe) { log.Tracef("Applying config update: %v\n", probe) - if procInfo.TypeMap == nil { - err := AnalyzeBinary(procInfo) - if err != nil { - log.Errorf("couldn't inspect binary: %v\n", err) - return - } - } - generateCompileAttach: err := codegen.GenerateBPFParamsCode(procInfo, probe) if err != nil { From d49985519cedebafe57761d8a154e6c9986737fa Mon Sep 17 00:00:00 2001 From: grantseltzer Date: Wed, 5 Mar 2025 14:55:29 -0600 Subject: [PATCH 6/6] Fix mutex copy Signed-off-by: grantseltzer --- .../diconfig/mem_config_manager.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go index 381080eab517f3..9e163dab193b1a 100644 --- a/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go +++ b/pkg/dynamicinstrumentation/diconfig/mem_config_manager.go @@ -72,9 +72,17 @@ func (cm *ReaderConfigManager) update() error { for pid, proc := range cm.ConfigWriter.Processes { // If a 
config exists relevant to this proc if proc.ServiceName == serviceName { - procCopy := *proc - updatedState[pid] = &procCopy - updatedState[pid].ProbesByID = convert(serviceName, configsByID) + updatedState[pid] = &ditypes.ProcessInfo{ + PID: proc.PID, + ServiceName: proc.ServiceName, + RuntimeID: proc.RuntimeID, + BinaryPath: proc.BinaryPath, + TypeMap: proc.TypeMap, + ConfigurationUprobe: proc.ConfigurationUprobe, + InstrumentationUprobes: proc.InstrumentationUprobes, + InstrumentationObjects: proc.InstrumentationObjects, + ProbesByID: convert(serviceName, configsByID), + } } } }