๋ณธ ๋ํ์ ์ฃผ์ ๋ ๋น ๋ฒ์ง์ผ๋ก ์ ํ๋ ์นด๋ฉ๋ผ ์ด๋ฏธ์ง ํ์ง์ ํฅ์์ํค๋ AI ๋ชจ๋ธ ๊ฐ๋ฐ์ด์์ต๋๋ค. ์ฃผ์ด์ง ์ด๋ฏธ์ง๋ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ๋น๋ฒ์ง, ๋ธ๋ฌ ํ์ ๋ฑ์ ํฌํจํ์ฌ ๋ณตํฉ์ ์ธ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ํ์๊ฐ ์์์ต๋๋ค.
- ํ์ต ๋ฐ์ดํฐ: 272์ฅ์ 2448ร3264, 350์ฅ์ 1224ร1632 ์ด๋ฏธ์ง๋ก ๊ตฌ์ฑ๋ 622์ฅ์ ๋น๋ฒ์ง ์ด๋ฏธ์ง
- ํ ์คํธ ๋ฐ์ดํฐ: 20์ฅ์ 2448ร3264 ๋น๋ฒ์ง ์ด๋ฏธ์ง
- PSNR(Peak Signal-to-noise ratio): ์ ํธ๊ฐ ๊ฐ์ง ์ ์๋ ์ต๋ ์ ๋ ฅ์ ๋ํ ์ก์์ ์ ๋ ฅ์ ๋ํ๋ธ ๊ฒ์ผ๋ก, ์์ ๋๋ ๋์์ ์์ค ์์ถ์์ ํ์ง ์์ค ์ ๋ณด๋ฅผ ํ๊ฐํ ๋ ์ฌ์ฉ๋ฉ๋๋ค.
# clone repository
$ git clone https://github.com/TeamBCP5/image-reconstruction.git
# install necessary tools
$ pip install -r requirements.txt
# Download: https://dacon.io/competitions/official/235746/data
[camera_dataset]/
โโโ train_input_img/ # ํ์ต ๋ฐ์ดํฐ ์
๋ ฅ ์ด๋ฏธ์ง
โโโ train_label_img/
โโโ hinet_dataset/ # postprocessing ๋ฐ์ดํฐ์
๋๋ ํ ๋ฆฌ NOTE. ํ์ต ๊ณผ์ ์ค ๊ตฌ์ถ๋๋ ๋๋ ํ ๋ฆฌ์
๋๋ค.
โ โโโ train_input_img/
โ โโโ train_label_img/
โโโ test_input_img/
[code]
โโโ camera_dataset/ # ๋ฐ์ดํฐ์
๋๋ ํ ๋ฆฌ
โโโ configs/ # ๋ชจ๋ธ config ํ์ผ ๋๋ ํ ๋ฆฌ
โโโ data/ # data ์ฒ๋ฆฌ ๊ด๋ จ ๋ชจ๋ ๋๋ ํ ๋ฆฌ
โโโ networks/ # ๋ชจ๋ธ ์ํคํ
์ฒ ๊ด๋ จ ๋ชจ๋ ๋๋ ํ ๋ฆฌ
โโโ train_modules/ # ๋ชจ๋ธ ํ์ต ๊ด๋ จ ๋ชจ๋ ๋๋ ํ ๋ฆฌ
โโโ utils/ # ์ ํธ๋ฆฌํฐ ๊ด๋ จ ๋ชจ๋ ๋๋ ํ ๋ฆฌ
โโโ README.md
โโโ requirements.txt
โโโ demo_augmentations.py # Augmentation ํ
์คํธ๋ฅผ ์ํ ์คํฌ๋ฆฝํธ ํ์ผ
โโโ demo_preprocessing.py # Preprocessing ํ
์คํธ๋ฅผ ์ํ ์คํฌ๋ฆฝํธ ํ์ผ
โโโ train.py
โโโ inference.py
์นด๋ฉ๋ผ ์ด๋ฏธ์ง ํ์ง ๊ฐ์ ๊ณผ์ ์ ๋๋ต ๋ค์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ต๋๋ค. Sliding Window ๊ธฐ๋ฐ์ Pix2Pix ๋ชจ๋ธ์ ํตํด 1์ฐจ์ ์ผ๋ก ๋น๋ฒ์ง์ ์ ๊ฑฐํ ๋ค, HINet ๋ชจ๋ธ์ ํตํด ๊ฒฉ์ ๋ฌด๋ฌ ๋ฑ ์์๋ ํ์ง์ ๋ณด์ํฉ๋๋ค. ํนํ, ํ์ต ๋จ๊ณ์์ Pix2Pix Generator๋ Discriminator์ ํจ๊ป ํ์ต๋ฉ๋๋ค.
๋ชจ๋ธ ํ์ต๊ณผ ์ถ๋ก ์ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ชจ๋ธ๋ณ Configuration ํ์ผ์ ๋ฐํ์ผ๋ก ์งํ๋ฉ๋๋ค. ๊ฐ Configuration ํ์ผ์๋ ๋ชจ๋ธ ๊ตฌ์กฐ์ ํ์ต ๋ฐ์ดํฐ์ ๊ฒฝ๋ก ๋ฑ ํ์ต๊ณผ ์ถ๋ก ์ ์ํ ์ค์ ๊ฐ์ด ๊ธฐ๋ก๋์ด ์์ต๋๋ค. ์ํํ ํ์ต/์ถ๋ก ์ ์ํด์๋ ๋ฐ์ดํฐ์ ๊ฒฝ๋ก ๋ฑ ์ค์ ๊ฐ์ ํ๊ฒฝ์ ๋ง๊ฒ ์ค์ ํด์ฃผ์ ์ผ ํฉ๋๋ค. Configuration ํ์ผ ๋ช ์ธ๋ ์ด๊ณณ์์ ํ์ธํ์ค ์ ์์ต๋๋ค.
์ต์ข ๊ฒฐ๊ณผ๋ฌผ ์ ์ถ์ ํ์ฉ๋ ๋ชจ๋ธ์ ๋ค์์ 3๋จ๊ณ์ ๊ฑธ์น ํ์ต์ ํตํด ์ ์๋์์ต๋๋ค.
- Sliding Window ๋ฐฉ๋ฒ์ ๋ฐํ์ผ๋ก ์ด๋ฏธ์ง ํ์ง์ ํฅ์์ํค๋ ๋ฉ์ธ ๋ชจ๋ธ(Pix2Pix)์ ํ์ตํฉ๋๋ค.
- Input. ๋ํ์์ ์ฃผ์ด์ง ํ์ต ๋ฐ์ดํฐ์ input ์ด๋ฏธ์ง
- Label. ๋ํ์์ ์ฃผ์ด์ง ํ์ต ๋ฐ์ดํฐ์ label ์ด๋ฏธ์ง
- ํ์ฒ๋ฆฌ ๋ชจ๋ธ(HINet)์ ์ฃผ์ด์ง ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ์ฌ 1์ฐจ์ ์ผ๋ก ํ์ตํฉ๋๋ค.
- Input. ๋ํ์์ ์ฃผ์ด์ง ํ์ต ๋ฐ์ดํฐ์ input ์ด๋ฏธ์ง
- Label. ๋ํ์์ ์ฃผ์ด์ง ํ์ต ๋ฐ์ดํฐ์ label ์ด๋ฏธ์ง
- II์์ ํ์ตํ ํ์ฒ๋ฆฌ ๋ชจ๋ธ(HINet)์ ๋ถ๋ฌ์ ํ์ต์ ์งํํฉ๋๋ค.
- Input. ๋ํ์์ ์ฃผ์ด์ง ํ์ต ๋ฐ์ดํฐ์ input ์ด๋ฏธ์ง์ ๋ํ I์์ ํ์ตํ ๋ฉ์ธ ๋ชจ๋ธ(Pix2Pix)์ ์ถ๋ก ๊ฒฐ๊ณผ
- Label. ๋ํ์์ ์ฃผ์ด์ง ํ์ต ๋ฐ์ดํฐ์ label ์ด๋ฏธ์ง
์ ํ์ต ๋จ๊ณ๋ฅผ ๋ชจ๋ ํฌํจํ ํ์ต์ ์ํํฉ๋๋ค.
$ python train.py --train_type 'all'
๋จ๊ณ I์ ํด๋น๋๋ Pix2Pix ๋ชจ๋ธ ํ์ต์ ์ํํฉ๋๋ค.
$ python train.py --train_type 'pix2pix'
๋จ๊ณ II์ ํด๋น๋๋ HINet ๋ชจ๋ธ ํ์ต์ ์ํํฉ๋๋ค.
$ python train.py --train_type 'hinet'
train_type
: ํ์ต ๋ฐฉ์ ์ค์
'all'
: ์ ์ธ ๋จ๊ณ์ ๊ฑธ์น ํ์ต์ ์งํํฉ๋๋ค. ์ต์ข ๊ฒฐ๊ณผ๋ฌผ ์ฌํ์๋ ์ด ์ค์ ๊ฐ์ ์ฌ์ฉ๋ฉ๋๋ค.'pix2pix'
: Pix2Pix ๋ชจ๋ธ์ ๊ฐ๋ณ ํ์ต์ ์ํํฉ๋๋ค.'hinet'
: HINet ๋ชจ๋ธ์ ๊ฐ๋ณ ํ์ต์ ์ํํฉ๋๋ค. '๋จ๊ณ II. ํ์ฒ๋ฆฌ ๋ชจ๋ธ(HINet) 1์ฐจ ํ์ต'์ ๊ธฐ์ค์ผ๋ก ํ์ต์ด ์งํ๋ฉ๋๋ค.
config_pix2pix
: Pix2Pix ๋ชจ๋ธ configuration ํ์ผ ๊ฒฝ๋ก
config_hinet_phase1
: HINet ๋ชจ๋ธ(phase1) configuration ํ์ผ ๊ฒฝ๋ก
config_hinet_phase2
: HINet ๋ชจ๋ธ(phase2) configuration ํ์ผ ๊ฒฝ๋ก
๋ฉ์ธ ๋ชจ๋ธ(Pix2Pix)๊ณผ ํ์ฒ๋ฆฌ ๋ชจ๋ธ(HINet)์ ๋ถ๋ฌ์ ์ถ๋ก ์ ์ํํฉ๋๋ค. ์ถ๋ก ์ ๋ค์์ ๋ ๋จ๊ณ๋ฅผ ๊ฑฐ์ณ ์งํ๋ฉ๋๋ค.
$ python inference.py --checkpoint_main "./best_models/pix2pix.pth" --checkpoint_post "./best_models/hinet.pth" --image_dir "./camera_dataset/test_input_img"
- Input. ๋ํ์์ ์ฃผ์ด์ง ํ ์คํธ ๋ฐ์ดํฐ์ input ์ด๋ฏธ์ง
- Input. ๋จ๊ณ I์์ ๋ฉ์ธ ๋ชจ๋ธ์ ์ถ๋ก ๊ฒฐ๊ณผ
- ํด๋น ๋จ๊ณ์์์ ๊ฒฐ๊ณผ๋ฌผ์ด ์ต์ข ์ถ๋ก ๊ฒฐ๊ณผ๋ฌผ๋ก ์ ์ฅ๋ฉ๋๋ค.
config_main
: Main ๋ชจ๋ธ(Pix2Pix) config ํ์ผ ๊ฒฝ๋ก
config_post
: Postprocessing ๋ชจ๋ธ(HINet) config ํ์ผ ๊ฒฝ๋ก
checkpoint_main
: ํ์ตํ main ๋ชจ๋ธ(Pix2Pix)์ pth ํ์ผ ๊ฒฝ๋ก
checkpoint_post
: ํ์ตํ postprocessing ๋ชจ๋ธ(HINet)์ pth ํ์ผ ๊ฒฝ๋ก
image_dir
: ์ถ๋ก ์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
patch_size
: ์ถ๋ก ์ ์ฌ์ฉ๋ ์ด๋ฏธ์ง patch์ ํฌ๊ธฐ
stride
: ์ถ๋ก ์ ์ฌ์ฉ๋ stride์ ํฌ๊ธฐ
batch_size
: ์ถ๋ก ์ ์ฌ์ฉ๋ batch์ ํฌ๊ธฐ
output_dir
: ์ถ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก. ํด๋น ๋๋ ํ ๋ฆฌ ๋ด ์์ถํ์ผ ํํ๋ก ๊ฒฐ๊ณผ๋ฌผ์ด ์ ์ฅ๋ฉ๋๋ค.
๋ชจ๋ธ ํ์ต์ ํ์ฉํ data augmentation์ ์์ ๊ฒฐ๊ณผ๋ฌผ์ ์์ฑํฉ๋๋ค.
$ python demo_augmentation.py --data_dir "./camera_dataset/" --num_samples 10 --save_dir './sample_augmentation/'
[SAVE_DIR]
โโโ original/ # ์๋ณธ ์ด๋ฏธ์ง
โโโ hinet/ # HINet์ ์ํ data augmentation ๊ฒฐ๊ณผ๋ฌผ
โโโ pix2pix/ # pix2pix๋ฅผ ์ํ data augmentation ๊ฒฐ๊ณผ๋ฌผ
data_dir
: input ๋ฐ์ดํฐ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
num_samples
: ์์ฑํ ์ํ ์
save_dir
: Augmentation ์ ์ฉ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
๋ชจ๋ธ ํ์ต์ ํ์ฉํ data preprocessing์ ์์ ๊ฒฐ๊ณผ๋ฌผ์ ์์ฑํฉ๋๋ค.
$ python demo_preprocessing.py --data_dir "./camera_dataset/" --num_samples 10 --save_dir './sample_preprocessing/'
[SAVE_DIR]
โโโ original/ # ์๋ณธ ์ด๋ฏธ์ง
โโโ hinet/ # HINet์ ์ํ data preprocessing ๊ฒฐ๊ณผ๋ฌผ
โโโ pix2pix/ # pix2pix๋ฅผ ์ํ data preprocessing ๊ฒฐ๊ณผ๋ฌผ
data_dir
: input ๋ฐ์ดํฐ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
num_samples
: ์์ฑํ ์ํ ์
save_dir
: Augmentation ์ ์ฉ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
stride
: Sliding Window ์ ์ฌ์ฉํ stride
patch_size
: Sliding Window ์ ์ฌ์ฉํ patch ์ฌ์ด์ฆ