Skip to content

Commit

Permalink
feat: add FFT Processor
Browse files Browse the repository at this point in the history
  • Loading branch information
shelta-zhao committed Jan 18, 2025
1 parent b9cb8c6 commit cceda6c
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# radar data
# File Ignore Addition
/data
*.jpg
*.png
*.bin
*.csv
*.json
*.mat
*.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
File renamed without changes.
Empty file removed module/adc_to_pcd.py
Empty file.
115 changes: 115 additions & 0 deletions module/fft_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Author : Shelta Zhao(赵小棠)
Email : xiaotang_zhao@outlook.com
Copyright (C) : NJU DisLab, 2025.
Description : Define FFT Processor, including Range FFT & Doppler FFT.
"""

import os
import yaml
import torch
from parser.param_process import get_radar_params
from parser.adc_load import get_regular_data


class FFTProcessor:
def __init__(self, rangeFFTObj, dopplerFFTObj, device='cpu'):
"""
Initialize the FFTProcessor with range and Doppler FFT configurations.
Parameters:
rangeFFTObj (dict): Configuration for range FFT.
dopplerFFTObj (dict): Configuration for Doppler FFT.
device (str): The device to perform the computation on ('cpu' or 'cuda').
"""

self.rangeFFTObj = rangeFFTObj
self.dopplerFFTObj = dopplerFFTObj
self.device = device

def range_fft(self, input):
"""
Perform Range FFT on the input data.
Parameters:
input (np.ndarray): The input data to be transformed.
Returns:
torch.Tensor: The range FFT result. Shape: (num_frames, range_fft_size, num_chirps, num_rx, num_tx)
"""

# Get the basic fft parmas
radar_type, fft_size = self.rangeFFTObj['radarPlatform'], self.rangeFFTObj['rangeFFTSize']
dc_on, win_on = self.rangeFFTObj['dcOffsetCompEnable'], self.rangeFFTObj['rangeWindowEnable']
scale_on, scale_factor = self.rangeFFTObj['FFTOutScaleOn'], self.rangeFFTObj['scaleFactorRange']

# Convert input to tensor
input = torch.tensor(input, dtype=torch.complex64).to(self.device)
# Generate window coefficient
win_coeff = torch.hann_window(input.shape[1] + 2, periodic=True).to(self.device)[1:-1]
# Apply DC offset compensation
input = input - torch.mean(input, dim=1, keepdim=True) if dc_on else input
# Apply range-domain windowing
input = input * win_coeff.view(-1, 1, 1, 1) if win_on else input
# Perform FFT for each TX/RX chain
fft_output = torch.fft.fft(input, n=fft_size, dim=1)
# Apply scale factor
fft_output = fft_output * scale_factor if scale_on else fft_output
# Phase compensation for IWR6843ISK-ODS
if radar_type == 'IWR6843ISK-ODS':
fft_output[:, :, :, 1:3, :] *= torch.exp(-1j * torch.pi)

# Return range fft result
return fft_output

def doppler_fft(self, input):
"""
Perform Doppler FFT on the input data (result of range FFT).
Parameters:
input (torch.Tensor): The input data to be transformed (result of range fft).
Returns:
torch.Tensor: The Doppler FFT result. Shape: (num_frames, range_fft_size, doppler_fft_size, num_rx, num_tx)
"""

# Get the basic fft parmas
fft_size, win_on = self.dopplerFFTObj['dopplerFFTSize'], self.dopplerFFTObj['dopplerWindowEnable']
scale_on, scale_factor = self.dopplerFFTObj['FFTOutScaleOn'], self.dopplerFFTObj['scaleFactorDoppler']

# Generate window coefficient
win_coeff = torch.hann_window(input.shape[1] + 2, periodic=True).to(self.device)[1:-1]
# Apply Doppler-domain windowing
input = input * win_coeff.view(-1, 1, 1, 1) if win_on else input
# Perform FFT for each TX/RX chain
fft_output = torch.fft.fftshift(torch.fft.fft(input, n=fft_size, dim=2), dim=2)
# Apply scale factor
fft_output = fft_output * scale_factor if scale_on else fft_output

# Return doppler fft result
return fft_output.cpu() if self.device == 'cuda' else fft_output


if __name__ == "__main__":

# Parse data config & Get radar params
with open("adc_list.yaml", "r") as file:
data = yaml.safe_load(file)
data_path = os.path.join("data/adc_data", f"{data['prefix']}/{data['index']}")
config_path = os.path.join("data/radar_config", data["config"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Get radar params
radar_params = get_radar_params(config_path, data['radar'], load=True)

# Get regular raw radar data
regular_data, timestamp = get_regular_data(data_path, radar_params['readObj'], '1', timestamp=True)

# Test Range FFT & Doppler FFT
fft_processor = FFTProcessor(radar_params['rangeFFTObj'], radar_params['dopplerFFTObj'], device)

range_fft_output = fft_processor.range_fft(regular_data)
print(range_fft_output.shape)

doppler_fft_out = fft_processor.doppler_fft(range_fft_output)
print(doppler_fft_out.shape)
14 changes: 7 additions & 7 deletions parser/adc_load.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Author : Shelta Zhao(赵小棠)
Affiliation : Nanjing University
Email : xiaotang_zhao@outlook.com
Description : Load & Get regular radar data.
Author : Shelta Zhao(赵小棠)
Email : xiaotang_zhao@outlook.com
Copyright (C) : NJU DisLab, 2025.
Description : Load & Get regular radar data.
"""

import os
Expand Down Expand Up @@ -281,10 +281,10 @@ def process_frame(frame_idx):
if __name__ == "__main__":

# Parse data config & Get readObj
with open("data2parse.yaml", "r") as file:
with open("adc_list.yaml", "r") as file:
data = yaml.safe_load(file)
data_path = os.path.join("datas/adcDatas", f"{data['prefix']}/{data['index']}")
config_path = os.path.join("datas/configs", data["config"])
data_path = os.path.join("data/adc_data", f"{data['prefix']}/{data['index']}")
config_path = os.path.join("data/radar_config", data["config"])
readObj = generate_params(config_path, data['radar'])['readObj']

# Test timestamp extraction
Expand Down
15 changes: 8 additions & 7 deletions parser/param_process.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Author : Shelta Zhao(赵小棠)
Affiliation : Nanjing University
Email : xiaotang_zhao@outlook.com
Description : Parses mmWave Studio config JSON files.
Author : Shelta Zhao(赵小棠)
Email : xiaotang_zhao@outlook.com
Copyright (C) : NJU DisLab, 2025.
Description : Parses mmWave Studio config JSON files.
"""

import os
Expand Down Expand Up @@ -616,13 +616,14 @@ def convert_numpy(obj):
if __name__ == "__main__":

# Parse radar config
with open("data2parse.yaml", "r") as file:
with open("adc_list.yaml", "r") as file:
data = yaml.safe_load(file)
config_path = os.path.join("datas/configs", data["config"])
config_path = os.path.join("data/radar_config", data["config"])

# Test generate params
radar_params = generate_params(config_path, data['radar'])
if not radar_params:
print("Invalid JSON files")
else:
print(radar_params)
print(yaml.dump(radar_params))

11 changes: 11 additions & 0 deletions pipline/adc_to_pcd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
Author : Shelta Zhao(赵小棠)
Email : xiaotang_zhao@outlook.com
Copyright (C) : NJU DisLab, 2025.
Description : Traditon pipline to generate Point Cloud Data (PCD) from raw radar data.
"""

import os
import yaml
import torch
from module.fft_process import FFTProcessor
File renamed without changes.

0 comments on commit cceda6c

Please sign in to comment.