Skip to content

Commit

Permalink
bug: fix peak detection
Browse files Browse the repository at this point in the history
  • Loading branch information
shelta-zhao committed Jan 23, 2025
1 parent f29c985 commit 6fa0ad9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 32 deletions.
9 changes: 6 additions & 3 deletions module/doa_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,16 @@ def DOA_beamformingFFT(self, bin_val):
else:
# Azimuth and elevation angle estimation
for ind in peakLoc_azi:
spec_ele = torch.abs(doa_fft_result[ind, :])
spec_ele = torch.abs(doa_fft_result[ind - 1, :])
_, peakLoc_ele = peak_detect(spec_ele, self.gamma, self.sidelobeLevel_dB[1])
for peak in peakLoc_ele:
print(ind, peak)
# Calculate angles
print(self.wx_vec[ind], self.wz_vec[peak])
azi_est = torch.arcsin(self.wx_vec[ind] / (2 * torch.pi * self.d)) * -1 * 180 / torch.pi
ele_est = torch.arcsin(self.wz_vec[peak] / (2 * torch.pi * self.d)) * 180 / torch.pi
if (self.angles_DOA_azi[0] <= azi_est <= self.angles_DOA_azi[1] and self.angles_DOA_ele[0] <= ele_est <= self.angles_DOA_ele[1]):
print(azi_est, ele_est)
DOAObj_est.append([azi_est, ele_est, ind, peak])
obj_cnt += 1

Expand All @@ -170,15 +173,15 @@ def DOA_beamformingFFT(self, bin_val):
radar_params = get_radar_params(config_path, data['radar'], load=True)

# Get regular raw radar data
regular_data, timestamp = get_regular_data(data_path, radar_params['readObj'], '1', timestamp=True)
regular_data, timestamp = get_regular_data(data_path, radar_params['readObj'], 'all', timestamp=True)

# Perform Range & Doppler FFT
fft_processor = FFTProcessor(radar_params['rangeFFTObj'], radar_params['dopplerFFTObj'], device)
fft_output = fft_processor.run(regular_data)

# Perform CFAR-CASO detection
cfar_processor = CFARProcessor(radar_params['detectObj'], device)
detection_results = cfar_processor.run(fft_output[0,:256,:,:,:], 0)
detection_results = cfar_processor.run(fft_output[25,:256,:,:,:], 0)

# Test DOA Estimation
doa_processor = DOAProcessor(radar_params['DOAObj'], device)
Expand Down
60 changes: 33 additions & 27 deletions utility/tool_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,41 +41,47 @@ def peak_detect(input, gamma, sidelobeLevel_dB):
- peakLoc: Tensor containing the locations (indices) of the detected peaks
"""

input = input.flatten() # Ensure input is a 1D tensor
N, peaks = input.shape[0], []
maxVal, maxLoc, absMaxValue, minVal = torch.tensor(0.0), 0, torch.tensor(0.0), torch.tensor(float('inf'))


minVal, maxVal, maxLoc, maxLoc_r, numMax, = torch.tensor(float('inf'), dtype=torch.float64), torch.tensor(0.0), 0, 0, 0,
absMaxValue, locateMax, initStage, maxData = torch.tensor(0.0), False, True, []

N = input.shape[0]
for i in range(N):
currentVal = input[i]
i_loc = i % N
currentVal = input[i_loc]

# Track the absolute maximum value
# Update maximum and minimum values
absMaxValue = torch.max(absMaxValue, currentVal)

# Track the current max value and its location
if currentVal > maxVal:
maxVal = currentVal
maxLoc = i

# Track the minimum value
maxLoc = i_loc
maxLoc_r = i

minVal = torch.min(minVal, currentVal)

if currentVal > minVal * gamma:
if currentVal < maxVal / gamma:
peaks.append((maxLoc, maxVal, i - maxLoc)) # Store peak info
minVal = currentVal # Update minimum value for next peak detection
maxVal = currentVal # Reset maxVal for next peak search

# Filter out peaks below sidelobe level
valid_peaks = []
if locateMax:
if currentVal < (maxVal / gamma): # Peak found
maxData.append([maxLoc, maxVal, i - maxLoc_r, maxLoc_r])
numMax += 1
minVal = currentVal
locateMax = False
else:
if currentVal > minVal * gamma: # Valley found
locateMax = True
maxVal = currentVal
if initStage:
extendLoc = i
initStage = False

# Filter peaks based on sidelobe threshold
absMaxValue_db = absMaxValue * (10 ** (-sidelobeLevel_dB / 10))
maxData = [data for data in maxData if data[1] >= absMaxValue_db]

for peak in peaks:
if peak[1] >= absMaxValue_db: # Only consider peaks above sidelobe threshold
valid_peaks.append(peak)

# Extract peak values and locations
peakVal = torch.tensor([peak[1] for peak in valid_peaks], dtype=torch.float32)
peakLoc = torch.tensor([peak[0] for peak in valid_peaks], dtype=torch.long)

return peakVal, peakLoc
# Convert results to torch tensors
if len(maxData) > 0:
maxData = torch.tensor(maxData)
peakVal = maxData[:, 1]
peakLoc = (maxData[:, 0] % N) + 1

return peakVal.to(torch.float64), peakLoc.to(torch.long)
8 changes: 6 additions & 2 deletions utility/visualizer_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ def PCD_display(point_cloud_data):
x = point_cloud_data[:, 2]
y = point_cloud_data[:, 3]
z = point_cloud_data[:, 4]
signal_power = 10 * np.log(10, point_cloud_data[:, 9])
velocity = point_cloud_data[:, 6]
signal_power = 10 * np.log10(point_cloud_data[:, 9])

# Create 3D scatter plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot with color based on frame indices
scatter = ax.scatter(x, y, z, c=signal_power, cmap='viridis', s=20, alpha=0.8)
scatter = ax.scatter(x, y, z, c=velocity, cmap='viridis', s=20, alpha=0.8)

# Add a color bar
cbar = plt.colorbar(scatter, ax=ax, pad=0.1, shrink=0.8)
Expand All @@ -41,6 +42,9 @@ def PCD_display(point_cloud_data):
ax.set_xlabel('X (meters)')
ax.set_ylabel('Y (meters)')
ax.set_zlabel('Z (meters)')
# ax.set_xlim((-2, 2))
# ax.set_ylim((0, 3))
# ax.set_zlim((0, 3))

# Set title and show
ax.set_title('3D Point Cloud')
Expand Down

0 comments on commit 6fa0ad9

Please sign in to comment.