Source code for the paper "Speaking the Same Language: Matching Machine to Human Captions by Adversarial Training"
The starting point for adversarial caption generator training is "train_adversarial_caption_gen_v2.py". To train regular captioning models with the maximum-likelihood (ML) loss, start with "driver_theano.py".
The code is built with Python, NumPy and Theano. It is a large codebase implementing the captioning frameworks used in the following papers:
- "Speaking the Same Language: Matching Machine to Human Captions by Adversarial Training" (https://arxiv.org/abs/1703.10476)
- "Paying Attention to Descriptions Generated by Image Captioning Models" (https://arxiv.org/abs/1704.07434)
- "Exploiting scene context for image captioning" (https://dl.acm.org/citation.cfm?id=2983571)
- "Frame- and segment-level features and candidate pool evaluation for video caption generation" (https://arxiv.org/abs/1608.04959)
- "Video captioning with recurrent networks based on frame- and video-level features and visual content classification" (https://arxiv.org/abs/1512.02949)
-
Make sure you have Theano installed and working. As a quick check, "import theano" should run without errors in a Python shell.
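A slightly more informative check, sketched below, also reports which device and float type Theano picked up from THEANO_FLAGS. `theano.config.device` and `theano.config.floatX` are standard Theano configuration attributes; the helper name `check_theano` is hypothetical and not part of this repository.

```python
def check_theano():
    """Return True if Theano imports cleanly, printing its active config.

    `check_theano` is an illustrative helper, not part of this codebase.
    """
    try:
        import theano
    except Exception as exc:  # ImportError, or errors raised while Theano sets up its backend
        print("Theano is not usable:", exc)
        return False
    # device and floatX reflect THEANO_FLAGS, e.g. device=cuda0,floatX=float32
    print("theano", theano.__version__,
          "device:", theano.config.device,
          "floatX:", theano.config.floatX)
    return True


if __name__ == "__main__":
    check_theano()
```

Running it with the same THEANO_FLAGS as the training commands (for example `THEANO_FLAGS='device=cuda0,floatX=float32' python check_theano.py`, file name hypothetical) confirms that the GPU and float32 settings are actually picked up.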
-
The code expects the data files to be in the "data/<dataset_name>" directory. It needs a .json file containing all the training/validation/test samples, and .npy/.mat/.bin feature files containing the CNN features for each sample. The actual images are only needed to visualise results; they are not needed during training.
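The sketch below checks that a dataset directory has the expected layout before training. It assumes the neuraltalk-style dataset.json format (a top-level `images` list whose entries carry a `split` field and a `sentences` list) and a .npy feature file with one row or column per entry of `images`, in the same order. Both the schema and the feature orientation are assumptions carried over from the original neuraltalk code, so check them against this repository's loaders; only .npy features are handled, and `summarize_dataset` is a hypothetical helper.

```python
import json
import os
from collections import Counter

import numpy as np


def summarize_dataset(data_root, dataset_name, feature_file, data_file="dataset.json"):
    """Count samples per split and check the feature file covers every image.

    Assumes a neuraltalk-style JSON: {"images": [{"split": ..., "sentences": [...]}, ...]}.
    """
    base = os.path.join(data_root, dataset_name)  # e.g. data/coco
    with open(os.path.join(base, data_file)) as f:
        images = json.load(f)["images"]

    splits = Counter(img["split"] for img in images)
    n_captions = sum(len(img["sentences"]) for img in images)

    # mmap_mode="r" reads only the header, so large feature files are not loaded into RAM
    feats = np.load(os.path.join(base, feature_file), mmap_mode="r")
    # Accept either (N, D) or neuraltalk's (D, N) orientation (assumption)
    if len(images) not in feats.shape:
        raise ValueError("feature shape %s does not match %d images"
                         % (feats.shape, len(images)))
    return {"splits": dict(splits), "captions": n_captions, "feat_shape": feats.shape}
```

For example, `summarize_dataset("data", "coco", "fasterRcnn_clasDetFEat80.npy")` would report the split sizes and caption count for the COCO files used in the commands below, provided they follow the assumed format.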
-
The data and some pre-trained models can be downloaded from the links below. They do not include the ResNet image features used by the pre-trained models, which were extracted as in (https://github.com/akirafukui/vqa-mcb/tree/master/preprocess); since those feature files are large, I have not uploaded them here. You can use any extracted features for this purpose.
Data: dataset.json, labels
Features: Faster-RCNN detection binary features
Baseline beamsearch(5): val set , test set
Our GAN model (beam-5): val set , test set
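The feature file names in the commands below (resnet150_2048-mean.npy versus resnet150_2048-14-14.npzl, together with `--poolmethod "none mean"`) suggest that the "-mean" features are 14x14 ResNet feature maps averaged over spatial positions. If you extract your own features, that pooling step can be sketched as follows. The (N, 2048, 14, 14) input layout and (N, 2048) output layout are assumptions, so transpose to whatever orientation the repository's loaders expect; `mean_pool_features` is a hypothetical helper.

```python
import numpy as np


def mean_pool_features(feature_maps):
    """Average CNN feature maps over spatial positions.

    feature_maps: array of shape (N, C, H, W), e.g. (N, 2048, 14, 14) for a
    ResNet res5c layer (assumed layout). Returns a float32 array of shape (N, C).
    """
    feature_maps = np.asarray(feature_maps)
    if feature_maps.ndim != 4:
        raise ValueError("expected (N, C, H, W), got shape %s" % (feature_maps.shape,))
    # Mean over the H and W axes leaves one C-dimensional vector per image
    return feature_maps.mean(axis=(2, 3)).astype(np.float32)
```

The pooled array can then be written with `np.save`, keeping the rows in the same order as the images in dataset.json.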
- Training the adversarial model
THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda3' python train_adversarial_caption_gen_v2.py --maxlen 20 -o cvCoco/advDummy --fappend r-dep3-frc80-resnet-1samp-pretrainBOTH-miniBatchDiscr-GUMBHard0p5-smooth3-noMle-featMatchPsentEmbMatch-randInp50d --batch_size 10 --eval_period 0.5 --max_epochs 50 --feature_file fasterRcnn_clasDetFEat80.npy --eval_feature aux_inp --aux_inp_file resnet150_2048-mean.npy -ld 1e-5 -cb 50 --word_encoding_size 512 --sent_encoding_size 400 --solver rmsprop --train_evaluator_only 0 --use_gumbel_mse 1 -lg 1e-6 --eval_model lstm_eval --eval_init trainedModels/advers/evaluators/advmodel_checkpoint_coco_wks-12-46_r-reg-res150mean-5samp-lstmevalonly_318_94.22_EVOnly.p --disk_feature 0 --metrics_to_track meteor cider len lcldiv_1 lcldiv_2 --gumbel_temp_init 0.5 --use_gumbel_hard 1 --hidden_depth 3 --en_residual_conn 1 --n_gen_samples 5 --merge_dim 50 --softmax_smooth_factor 3.0 --use_mle_train 0 --rev_eval 1 --gen_input_noise 1 --gen_feature_matching 1 --continue_training trainedModels/coco/mpi/model_checkpoint_coco_wks-12-46_r-dep3-frc80-resnet150mean_per9.32.p
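The `--use_gumbel_hard 1` and `--gumbel_temp_init 0.5` flags refer to the Gumbel-softmax relaxation, a standard way to make sampling discrete words differentiable. The NumPy sketch below shows only the forward computation of a Gumbel-softmax sample with temperature `tau`; it is a generic illustration of the technique, not this repository's Theano implementation, and `gumbel_softmax_sample` is a hypothetical name. In the straight-through ("hard") variant the forward pass emits the one-hot argmax while gradients are taken through the soft sample, which needs an autodiff framework such as Theano.

```python
import numpy as np


def gumbel_softmax_sample(logits, tau=0.5, hard=False, rng=None):
    """Draw a relaxed one-hot sample from softmax(logits) via the Gumbel trick.

    logits: array of shape (batch, vocab). tau: temperature (lower = closer to one-hot).
    hard=True returns the one-hot argmax of the relaxed sample, i.e. the forward
    value of the straight-through estimator.
    """
    logits = np.asarray(logits, dtype=np.float64)
    rng = np.random.default_rng() if rng is None else rng
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1); the lower bound avoids log(0)
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    y = (logits + g) / tau
    y = y - y.max(axis=-1, keepdims=True)  # numerically stable softmax
    soft = np.exp(y) / np.exp(y).sum(axis=-1, keepdims=True)
    if not hard:
        return soft
    one_hot = np.zeros_like(soft)
    one_hot[np.arange(soft.shape[0]), soft.argmax(axis=-1)] = 1.0
    return one_hot
```

Because the argmax of `logits + g` is an exact sample from `softmax(logits)`, the hard samples follow the model's word distribution for any temperature; `tau` only controls how peaked the soft sample used for gradients is.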
- Generating captions
THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda1' python predict_on_images.py cvCoco/advDummy/advmodel_checkpoint_coco_wks-12-46_r-dep3-frc80-resnet-5samp-pretrainBOTH-miniBatchDiscr-GUMBHard0p5-smooth3-noMle-featMatch-randInp50d_55999_15.20_genacc.p --aux_inp_file data/coco/resnet150_2048-mean.npy -f data/coco/fasterRcnn_clasDetFEat80.npy -i imgLists/imgListCOCO_MiniTestSet_ranzato.txt --fname_append ranzatotest_MLE-20Wrd-Smth3-randInpFeatMatch-ResnetMean-56k-beamsearch5 --softmax_smooth_factor 3.0 --labels data/coco/labels.txt --greedy 0 --computelogprob 1 --dobeamsearch 1 -b 5
An example image list file is available here: https://drive.google.com/open?id=0B76QzqVJdOJ5NUtEMkx4ZzNKRWM
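The `--greedy 0 --dobeamsearch 1 -b 5` options decode with beam search of width 5 instead of greedy decoding. The generic sketch below keeps the `beam_size` highest log-probability partial captions at each step. `step_fn` is a hypothetical stand-in for one step of the trained LSTM, and details such as length normalisation or how this repository treats end-of-sentence tokens are not taken from the codebase.

```python
def beam_search(step_fn, start_token, end_token, beam_size=5, max_len=20):
    """Return (tokens, log_prob) of the best sequence found under step_fn.

    step_fn(prefix) -> dict mapping each next token to its log-probability.
    Hypotheses ending in end_token are set aside as finished; the search stops
    when no unfinished hypotheses remain or max_len steps have been taken.
    """
    beams = [([start_token], 0.0)]  # (tokens, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            for tok, logp in step_fn(tokens).items():
                candidates.append((tokens + [tok], score + logp))
        # Keep only the beam_size best expansions across all current beams
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == end_token:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # fall back to unfinished hypotheses if max_len was hit
    return max(finished, key=lambda c: c[1])
```

With `beam_size=1` this reduces to greedy decoding; a wider beam can recover captions whose first word is not the single most likely one.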
- Pre-training the caption generator
THEANO_FLAGS='mode=FAST_RUN,floatX=float32,device=cuda0' python driver_theano.py -d coco -l 1e-4 --maxlen 20 --decay_rate 0.999 --grad_clip 10.0 --image_encoding_size 512 --word_encoding_size 512 --hidden_size 512 -o cvCoco/salLclExpts --fappend r-dep3-frc80-resnet150mean --worker_status_output_directory statusCoco/c1 --write_checkpoint_ppl_threshold 14 --regc 2.66e-07 --batch_size 256 --eval_period 0.5 --max_epochs 60 --eval_batch_size 256 --aux_inp_file resnet150_2048-14-14.npzl --feature_file fasterRcnn_clasDetFEat80.npy --data_file dataset.json --sample_by_len 1 --lr_decay_st_epoch 1 --lr_decay 0.99 --disk_feature 2 --hidden_depth 3 --en_residual_conn 1 --poolmethod "none mean"
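Pre-training uses RMSProp-style optimisation (`-l 1e-4`, `--decay_rate 0.999`) with gradient clipping (`--grad_clip 10.0`) and L2 regularisation (`--regc 2.66e-07`). Since the code is based on neuraltalk, the sketch below follows the update rule of neuraltalk's solver (L2 term added to the gradient, element-wise clipping, then a running average of squared gradients). It is an assumption that this repository's Theano solver matches it exactly, and `rmsprop_step` is a hypothetical name.

```python
import numpy as np


def rmsprop_step(param, grad, cache, lr=1e-4, decay_rate=0.999,
                 grad_clip=10.0, regc=2.66e-07, eps=1e-8):
    """One RMSProp update in the style of the neuraltalk solver (assumed here).

    Returns the updated parameter and squared-gradient cache as new arrays.
    """
    grad = grad + regc * param  # gradient of the L2 penalty 0.5 * regc * ||param||^2
    if grad_clip > 0:
        grad = np.clip(grad, -grad_clip, grad_clip)  # element-wise clipping
    # Exponential moving average of squared gradients
    cache = decay_rate * cache + (1.0 - decay_rate) * grad ** 2
    param = param - lr * grad / np.sqrt(cache + eps)
    return param, cache
```

With `decay_rate` close to 1 the cache starts near zero, so the first few updates are much larger than `lr` times the gradient; the clipping bound keeps an occasional exploding gradient from dominating that cache.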
Some of the code and structure is based on the original neuraltalk code released by Andrej Karpathy at https://github.com/karpathy/neuraltalk