
strange detection result depend on bbox_pred #186

Open
daf11865 opened this issue May 13, 2016 · 19 comments


@daf11865

daf11865 commented May 13, 2016

Hi everyone, I've run into a very strange problem: either I'm misunderstanding something or there is a bug.
I trained faster_rcnn_end2end with ZF on PASCAL VOC, using the original script's command.
I also trained faster_rcnn_end2end with VGG16 on ImageNet (only 5 classes).

But when I apply these two models to some images (take ZF as the example), the predicted bboxes are very bad
if the name of the 'bbox_pred' layer in test.prototxt is 'bbox_pred' (the same name as in train.prototxt). The result looks like this:
[image: test]

But if I change the name of 'bbox_pred' in test.prototxt to any different name, like 'bbox_pred_xxx', the result is acceptable:
[image: test]

This means that with 'bbox_pred', the model uses the trained weights in this layer, which makes the predicted bboxes bad.
But with 'bbox_pred_xxx', the model uses the randomly initialized weights in this layer, which strangely makes the result good.
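
A minimal sketch of why the layer name decides which weights are used, assuming Caffe copies parameters from the .caffemodel by layer name and leaves unmatched layers at their initial values. The dictionaries and tiny shapes are hypothetical stand-ins for the real blobs, and using zeros as the initial value of the renamed layer is an assumption based on the test.prototxt declaring no filler for it.

```python
import numpy as np

def load_by_name(net_init, trained):
    """Return the parameters a net ends up with after name-based loading.

    net_init: dict mapping layer name -> initial weights from the prototxt.
    trained:  dict mapping layer name -> weights stored in the model file.
    """
    loaded = {}
    for name, init in net_init.items():
        if name in trained:
            loaded[name] = trained[name].copy()  # name matches: trained values are used
        else:
            loaded[name] = init.copy()           # no match: initial values are kept
    return loaded

# Hypothetical trained model with two output layers.
trained = {"cls_score": np.full((2, 3), 0.5), "bbox_pred": np.full((8, 3), 0.7)}

same_name = load_by_name(
    {"cls_score": np.zeros((2, 3)), "bbox_pred": np.zeros((8, 3))}, trained)
renamed = load_by_name(
    {"cls_score": np.zeros((2, 3)), "bbox_pred_xxx": np.zeros((8, 3))}, trained)

print(same_name["bbox_pred"][0])    # trained regressor weights
print(renamed["bbox_pred_xxx"][0])  # untouched initial weights
```

If the renamed layer really starts near zero, its output deltas are near zero too, so the final boxes stay close to the RPN proposals. That would be consistent with the renamed variant looking "acceptable".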

It seems like the bbox regression isn't trained well, but how come? I used the same procedure provided by rbgirshick.
Here are my train.prototxt and test.prototxt, plus my training command, which is the same as in the original script, unmodified.

cmd:
time ./tools/train_net.py --gpu 0 --solver models/pascal_voc/ZF/faster_rcnn_end2end/solver.prototxt --weights data/imagenet_models/ZF.v2.caffemodel --iters 70000 --imdb voc_2007_trainval --cfg experiments/cfgs/faster_rcnn_end2end.yml

train.prototxt(only list fasterrcnn part):
name: "ZF"
layer {
name: 'input-data'
type: 'Python'
top: 'data'
top: 'im_info'
top: 'gt_boxes'
python_param {
module: 'roi_data_layer.layer'
layer: 'RoIDataLayer'
param_str: "'num_classes': 21"
}
}

========= conv1-conv5 ============

.
.
.

========= RPN ============

layer {
name: "rpn_conv/3x3"
type: "Convolution"
bottom: "conv5"
top: "rpn/output"
param { lr_mult: 1.0 }
param { lr_mult: 2.0 }
convolution_param {
num_output: 256
kernel_size: 3 pad: 1 stride: 1
weight_filler { type: "gaussian" std: 0.01 }
bias_filler { type: "constant" value: 0 }
}
}
layer {
name: "rpn_relu/3x3"
type: "ReLU"
bottom: "rpn/output"
top: "rpn/output"
}
layer {
name: "rpn_cls_score"
type: "Convolution"
bottom: "rpn/output"
top: "rpn_cls_score"
param { lr_mult: 1.0 }
param { lr_mult: 2.0 }
convolution_param {
num_output: 18 # 2(bg/fg) * 9(anchors)
kernel_size: 1 pad: 0 stride: 1
weight_filler { type: "gaussian" std: 0.01 }
bias_filler { type: "constant" value: 0 }
}
}
layer {
name: "rpn_bbox_pred"
type: "Convolution"
bottom: "rpn/output"
top: "rpn_bbox_pred"
param { lr_mult: 1.0 }
param { lr_mult: 2.0 }
convolution_param {
num_output: 36 # 4 * 9(anchors)
kernel_size: 1 pad: 0 stride: 1
weight_filler { type: "gaussian" std: 0.01 }
bias_filler { type: "constant" value: 0 }
}
}
layer {
bottom: "rpn_cls_score"
top: "rpn_cls_score_reshape"
name: "rpn_cls_score_reshape"
type: "Reshape"
reshape_param { shape { dim: 0 dim: 2 dim: -1 dim: 0 } }
}
layer {
name: 'rpn-data'
type: 'Python'
bottom: 'rpn_cls_score'
bottom: 'gt_boxes'
bottom: 'im_info'
bottom: 'data'
top: 'rpn_labels'
top: 'rpn_bbox_targets'
top: 'rpn_bbox_inside_weights'
top: 'rpn_bbox_outside_weights'
python_param {
module: 'rpn.anchor_target_layer'
layer: 'AnchorTargetLayer'
param_str: "'feat_stride': 16"
}
}
layer {
name: "rpn_loss_cls"
type: "SoftmaxWithLoss"
bottom: "rpn_cls_score_reshape"
bottom: "rpn_labels"
propagate_down: 1
propagate_down: 0
top: "rpn_cls_loss"
loss_weight: 1
loss_param {
ignore_label: -1
normalize: true
}
}
layer {
name: "rpn_loss_bbox"
type: "SmoothL1Loss"
bottom: "rpn_bbox_pred"
bottom: "rpn_bbox_targets"
bottom: 'rpn_bbox_inside_weights'
bottom: 'rpn_bbox_outside_weights'
top: "rpn_loss_bbox"
loss_weight: 1
smooth_l1_loss_param { sigma: 3.0 }
}

========= RoI Proposal ============

layer {
name: "rpn_cls_prob"
type: "Softmax"
bottom: "rpn_cls_score_reshape"
top: "rpn_cls_prob"
}
layer {
name: 'rpn_cls_prob_reshape'
type: 'Reshape'
bottom: 'rpn_cls_prob'
top: 'rpn_cls_prob_reshape'
reshape_param { shape { dim: 0 dim: 18 dim: -1 dim: 0 } }
}
layer {
name: 'proposal'
type: 'Python'
bottom: 'rpn_cls_prob_reshape'
bottom: 'rpn_bbox_pred'
bottom: 'im_info'
top: 'rpn_rois'
python_param {
module: 'rpn.proposal_layer'
layer: 'ProposalLayer'
param_str: "'feat_stride': 16"
}
}
layer {
name: 'roi-data'
type: 'Python'
bottom: 'rpn_rois'
bottom: 'gt_boxes'
top: 'rois'
top: 'labels'
top: 'bbox_targets'
top: 'bbox_inside_weights'
top: 'bbox_outside_weights'
python_param {
module: 'rpn.proposal_target_layer'
layer: 'ProposalTargetLayer'
param_str: "'num_classes': 21"
}
}

========= RCNN ============

layer {
name: "roi_pool_conv5"
type: "ROIPooling"
bottom: "conv5"
bottom: "rois"
top: "roi_pool_conv5"
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
name: "fc6"
type: "InnerProduct"
bottom: "roi_pool_conv5"
top: "fc6"
param { lr_mult: 1.0 }
param { lr_mult: 2.0 }
inner_product_param {
num_output: 4096
}
}
layer {
name: "relu6"
type: "ReLU"
bottom: "fc6"
top: "fc6"
}
layer {
name: "drop6"
type: "Dropout"
bottom: "fc6"
top: "fc6"
dropout_param {
dropout_ratio: 0.5
scale_train: false
}
}
layer {
name: "fc7"
type: "InnerProduct"
bottom: "fc6"
top: "fc7"
param { lr_mult: 1.0 }
param { lr_mult: 2.0 }
inner_product_param {
num_output: 4096
}
}
layer {
name: "relu7"
type: "ReLU"
bottom: "fc7"
top: "fc7"
}
layer {
name: "drop7"
type: "Dropout"
bottom: "fc7"
top: "fc7"
dropout_param {
dropout_ratio: 0.5
scale_train: false
}
}
layer {
name: "cls_score"
type: "InnerProduct"
bottom: "fc7"
top: "cls_score"
param { lr_mult: 1.0 }
param { lr_mult: 2.0 }
inner_product_param {
num_output: 21
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "bbox_pred"
type: "InnerProduct"
bottom: "fc7"
top: "bbox_pred"
param { lr_mult: 1.0 }
param { lr_mult: 2.0 }
inner_product_param {
num_output: 84
weight_filler {
type: "gaussian"
std: 0.001
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "loss_cls"
type: "SoftmaxWithLoss"
bottom: "cls_score"
bottom: "labels"
propagate_down: 1
propagate_down: 0
top: "cls_loss"
loss_weight: 1
loss_param {
ignore_label: -1
normalize: true
}
}
layer {
name: "loss_bbox"
type: "SmoothL1Loss"
bottom: "bbox_pred"
bottom: "bbox_targets"
bottom: 'bbox_inside_weights'
bottom: 'bbox_outside_weights'
top: "bbox_loss"
loss_weight: 1
}

test.prototxt:(only list fasterrcnn part)
name: "ZF"

input: "data"
input_shape {
dim: 1
dim: 3
dim: 224
dim: 224
}

input: "im_info"
input_shape {
dim: 1
dim: 3
}

========= conv1-conv5 ============

.
.
.

========= RPN ============

layer {
name: "rpn_conv/3x3"
type: "Convolution"
bottom: "conv5"
top: "rpn/output"
convolution_param {
num_output: 256
kernel_size: 3 pad: 1 stride: 1
weight_filler { type: "gaussian" std: 0.01 }
bias_filler { type: "constant" value: 0 }
}
}
layer {
name: "rpn_relu/3x3"
type: "ReLU"
bottom: "rpn/output"
top: "rpn/output"
}
layer {
name: "rpn_cls_score"
type: "Convolution"
bottom: "rpn/output"
top: "rpn_cls_score"
convolution_param {
num_output: 18 # 2(bg/fg) * 9(anchors)
kernel_size: 1 pad: 0 stride: 1
weight_filler { type: "gaussian" std: 0.01 }
bias_filler { type: "constant" value: 0 }
}
}
layer {
name: "rpn_bbox_pred"
type: "Convolution"
bottom: "rpn/output"
top: "rpn_bbox_pred"
convolution_param {
num_output: 36 # 4 * 9(anchors)
kernel_size: 1 pad: 0 stride: 1
weight_filler { type: "gaussian" std: 0.01 }
bias_filler { type: "constant" value: 0 }
}
}
layer {
bottom: "rpn_cls_score"
top: "rpn_cls_score_reshape"
name: "rpn_cls_score_reshape"
type: "Reshape"
reshape_param { shape { dim: 0 dim: 2 dim: -1 dim: 0 } }
}

========= RoI Proposal ============

layer {
name: "rpn_cls_prob"
type: "Softmax"
bottom: "rpn_cls_score_reshape"
top: "rpn_cls_prob"
}
layer {
name: 'rpn_cls_prob_reshape'
type: 'Reshape'
bottom: 'rpn_cls_prob'
top: 'rpn_cls_prob_reshape'
reshape_param { shape { dim: 0 dim: 18 dim: -1 dim: 0 } }
}
layer {
name: 'proposal'
type: 'Python'
bottom: 'rpn_cls_prob_reshape'
bottom: 'rpn_bbox_pred'
bottom: 'im_info'
top: 'rois'
python_param {
module: 'rpn.proposal_layer'
layer: 'ProposalLayer'
param_str: "'feat_stride': 16"
}
}

========= RCNN ============

layer {
name: "roi_pool_conv5"
type: "ROIPooling"
bottom: "conv5"
bottom: "rois"
top: "roi_pool_conv5"
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
name: "fc6"
type: "InnerProduct"
bottom: "roi_pool_conv5"
top: "fc6"
inner_product_param {
num_output: 4096
}
}
layer {
name: "relu6"
type: "ReLU"
bottom: "fc6"
top: "fc6"
}
layer {
name: "drop6"
type: "Dropout"
bottom: "fc6"
top: "fc6"
dropout_param {
dropout_ratio: 0.5
scale_train: false
}
}
layer {
name: "fc7"
type: "InnerProduct"
bottom: "fc6"
top: "fc7"
inner_product_param {
num_output: 4096
}
}
layer {
name: "relu7"
type: "ReLU"
bottom: "fc7"
top: "fc7"
}
layer {
name: "drop7"
type: "Dropout"
bottom: "fc7"
top: "fc7"
dropout_param {
dropout_ratio: 0.5
scale_train: false
}
}
layer {
name: "cls_score"
type: "InnerProduct"
bottom: "fc7"
top: "cls_score"
inner_product_param {
num_output: 21
}
}
layer {
name: "bbox_pred"  # <-- the layer whose name I changed, as described above
type: "InnerProduct"
bottom: "fc7"
top: "bbox_pred"
inner_product_param {
num_output: 84
}
}
layer {
name: "cls_prob"
type: "Softmax"
bottom: "cls_score"
top: "cls_prob"
loss_param {
ignore_label: -1
normalize: true
}
}

@daf11865
Author

daf11865 commented May 14, 2016

Can anyone help?
I think it's because the bbox regressor isn't trained well.
How should I choose a proper number of iterations for it?

@ericromanenghi

Could you show how you do the testing?

It's a strange error, I have trained a lot of models with this net, and I have never got this error.

@daf11865
Author

daf11865 commented May 21, 2016

@ericromanenghi here is my test.py. Could you help me figure it out? I also trained ZF with alt-opt yesterday, and it works with regression. However, the end2end version doesn't work. Is it possible that 70000 iterations isn't enough for the bbox_pred layer to be trained well?

import numpy as np

# Local path setup (presumably adds lib/ and the caffe Python bindings to sys.path)
import _init_paths_no_Anaconda
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import scipy.io as sio
import caffe
import os
import sys
import cv2
import argparse

from utils.blob import im_list_to_blob
from fast_rcnn.bbox_transform import clip_boxes, bbox_transform_inv

import matplotlib as mpl

CLASSES = ('background',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')

def _get_image_blob(im):
    """Convert a mean-subtracted image into a network input blob."""
    im_shape = im.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    processed_ims = []
    im_scale_factors = []

    for target_size in cfg.TEST.SCALES:
        im_scale = float(target_size) / float(im_size_min)
        # Prevent the biggest axis from being more than MAX_SIZE
        if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
            im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
        # Resize from the original image each time, so scales do not compound
        resized = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
                             interpolation=cv2.INTER_LINEAR)
        im_scale_factors.append(im_scale)
        processed_ims.append(resized)

    # Create a blob to hold the input images
    blob = im_list_to_blob(processed_ims)

    return blob, np.array(im_scale_factors)

def _get_blobs(im, rois):
    """Convert an image and RoIs within that image into network inputs."""
    blobs = {'data' : None, 'rois' : None}
    blobs['data'], im_scale_factors = _get_image_blob(im)
    if not cfg.TEST.HAS_RPN:
        # _get_rois_blob is not defined in this script; this branch is only
        # reached when HAS_RPN is False, which is not the case in this script
        blobs['rois'] = _get_rois_blob(rois, im_scale_factors)
    return blobs, im_scale_factors

if __name__ == '__main__':

    CONF_THRESH = 0.75
    NMS_THRESH = 0.3

    caffe.set_mode_gpu()
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    #prototxt = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/models/coco/VGG16/faster_rcnn_end2end/test.prototxt'
    #caffemodel = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/data/faster_rcnn_models/coco_vgg16_faster_rcnn_final.caffemodel'

    #prototxt = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/models/pascal_voc/VGG16/faster_rcnn_alt_opt/faster_rcnn_test.pt'
    #caffemodel = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/data/faster_rcnn_models/VGG16_faster_rcnn_final.caffemodel'

    #prototxt = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/models/pascal_voc/ZF/faster_rcnn_end2end/test.prototxt'
    #caffemodel = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/daf_caffemodel/zf_faster_rcnn_iter_lab_140000.caffemodel' #fail

    prototxt = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/models/pascal_voc/ZF/faster_rcnn_alt_opt/faster_rcnn_test.pt'
    caffemodel = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/output/faster_rcnn_alt_opt/voc_2007_trainval/ZF_faster_rcnn_final.caffemodel'

    #prototxt = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/models/imagenet/VGG16/faster_rcnn_end2end/5cls_stage/test_5cls.prototxt'
    #caffemodel = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/daf_caffemodel/vgg16_faster_rcnn_5cls_stage1_iter_30000.caffemodel'#fail

    #prototxt = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/models/labim/VGG16/faster_rcnn_end2end/test~BN.prototxt'
    #caffemodel = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/daf_caffemodel/labim_vgg16_faster_rcnn_~BN_iter_750.caffemodel'

    #prototxt = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/models/pascal_voc/ZF/faster_rcnn_end2end/test.prototxt'
    #caffemodel = '/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/daf_caffemodel/zf_faster_rcnn_iter_70000.caffemodel' #fail

    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    frame = cv2.imread('/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/data/ILSVRC2015/Data/DET/train/ILSVRC2013_train/n00007846/n00007846_32419.JPEG')
    #frame = cv2.imread('/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/data/labim/Data/train/12874478_1115857121770145_1206356860_o.jpg')
    #frame = cv2.imread('/home/daf/CNN/py-faster-rcnn-master/py-faster-rcnn/data/demo/004545.jpg')

    ##process here
    # Subtract the pixel means (this also promotes the uint8 image to float)
    im = frame - cfg.PIXEL_MEANS

    boxes = None
    blobs, im_scales = _get_blobs(im, boxes)
    im_blob = blobs['data']
    blobs['im_info'] = np.array([[im_blob.shape[2], im_blob.shape[3], im_scales[0]]], dtype=np.float32)
    net.blobs['data'].reshape(*(blobs['data'].shape))
    net.blobs['im_info'].reshape(*(blobs['im_info'].shape))

    forward_kwargs = {'data': blobs['data'].astype(np.float32, copy=False)}
    forward_kwargs['im_info'] = blobs['im_info'].astype(np.float32, copy=False)

    blobs_out = net.forward(**forward_kwargs)

    # RPN proposals, mapped back to the original image scale
    rois = net.blobs['rois'].data.copy()
    boxes = rois[:, 1:5] / im_scales[0]

    scores = blobs_out['cls_prob']

    # Apply the per-class regression deltas to the proposals
    box_deltas = blobs_out['bbox_pred']
    pred_boxes = bbox_transform_inv(boxes, box_deltas)
    pred_boxes = clip_boxes(pred_boxes, im.shape)

    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = pred_boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        # Keep confident detections only, then suppress overlapping ones
        cfd_keep = np.where(cls_scores >= CONF_THRESH)
        cls_scores = cls_scores[cfd_keep[0]]
        cls_boxes = cls_boxes[cfd_keep[0], :]
        dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
        nms_keep = nms(dets, NMS_THRESH)
        dets = dets[nms_keep, :]

        for i in range(len(dets)):
            bb = dets[i][0:4]
            # OpenCV drawing calls expect integer pixel coordinates
            x1, y1, x2, y2 = [int(v) for v in bb]
            cv2.rectangle(frame, (x1, y1), (x2, y2), (cls_ind*1, 255-cls_ind*3, cls_ind*3), 3)
            cv2.putText(frame, cls, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2)
            cv2.putText(frame, 'prob:'+str(dets[i][4]), (x1, y1+10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,0), 1)
            print('id: {} obj: {} prob: {} bbox: {} {} {} {}'.format(
                cls_ind, cls, dets[i][4], bb[0], bb[1], bb[2], bb[3]))

    # Display the resulting frame
    cv2.namedWindow('detection', cv2.WINDOW_NORMAL)
    cv2.imshow('detection', frame)
    cv2.waitKey(0)

@daf11865
Author

@ericromanenghi
I think the problem is in the Fast R-CNN part of end2end training: bbox_pred seems to learn nothing. If I set cfg.TEST.BBOX_REG = False, the net doesn't use the regressor, and the result based only on the RPN boxes is okay (though not as good as with a successful regressor).

On the other hand, the result of alt-opt training is good; the regressor learns well there.

Any idea?
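
To illustrate the observation above that boxes are usable without the regressor but land beside the object with it, here is a minimal sketch of delta decoding. It is re-implemented from the standard Fast R-CNN box parameterization rather than copied from the repository, uses a single class (4 deltas per box) for brevity, and the example numbers are made up.

```python
import numpy as np

def decode_boxes(boxes, deltas):
    """Apply (dx, dy, dw, dh) deltas to [x1, y1, x2, y2] boxes."""
    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    # Shifts are relative to box size; scales are applied in log space.
    pred_ctr_x = deltas[:, 0] * widths + ctr_x
    pred_ctr_y = deltas[:, 1] * heights + ctr_y
    pred_w = np.exp(deltas[:, 2]) * widths
    pred_h = np.exp(deltas[:, 3]) * heights

    return np.stack([pred_ctr_x - 0.5 * pred_w, pred_ctr_y - 0.5 * pred_h,
                     pred_ctr_x + 0.5 * pred_w, pred_ctr_y + 0.5 * pred_h], axis=1)

proposal = np.array([[10.0, 20.0, 50.0, 80.0]])

# Zero deltas: the decoded box is the proposal itself (up to the +1 convention).
no_regression = decode_boxes(proposal, np.zeros((1, 4)))

# A badly scaled dx of 1.5 pushes the box 1.5 widths to the right.
bad_regression = decode_boxes(proposal, np.array([[1.5, 0.0, 0.0, 0.0]]))

print(no_regression)
print(bad_regression)
```

Because the shifts are multiplied by the box width and height, deltas that are off by a constant factor move the whole box by a multiple of its own size. That would produce boxes that sit next to the object instead of on it.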

@majestix89

@daf11865
Somehow I have the same issue here. I took the ImageNet pre-trained model and fine-tuned the VGG16 network with faster_rcnn_end2end on my own dataset. The predicted bounding boxes look almost the same as yours. Most of them are close to the actual object but don't contain the object at all (with a probability greater than 0.9).
As you described, changing TEST.BBOX_REG to False gives me much better results as well.

Did you find a solution?

@daf11865
Author

@majestix89 No
I've just rebuilt py-faster-rcnn and am training a new model again. I hope the issue was caused by some code I changed inappropriately, so I'm restarting everything from scratch. Hopefully it will succeed this time.
I'll let you know if it does.

@daf11865
Author

@majestix89
I rebuilt all of py-faster-rcnn and trained VGG16 on only 5 ImageNet classes, but I still get the same result you described...
So frustrating...

my cmd:
time ./tools/train_net.py --gpu 0 --solver models/imagenet/VGG16/faster_rcnn_end2end/5cls/solver.prototxt --weights data/imagnet_models/VGG16.v2.caffemodel --imdb imagenet_2015_train --iters 50000 --cfg experiments/cfgs/faster_rcnn_end2end.yml --set TRAIN.SCALES "[400, 500, 600, 700]"

@majestix89

@daf11865
Ok thx for letting me know. I didn't achieve new results either.
You said the results were alright when using the alternating training method, right?
I have some code running using this method at the moment and will see the results on Monday.
(I'm training on two classes, background and another one)

Another question: What do you use the TRAIN.SCALES for? Is it necessary to set those scales for proper training?

@daf11865
Author

@majestix89
Yes, when I use the alternating training method, the result seems good.
I think yours will be alright, too.

In my opinion, TRAIN.SCALES is used to get a more scale-invariant result.
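
A small sketch of what multi-scale training amounts to, assuming one target size is drawn at random per image and the image is resized so its shorter side matches it, subject to a cap on the longer side (the same rule as in the test script above). The function names and the cap of 1000 are hypothetical.

```python
import random

def image_scale(height, width, target_size, max_size):
    """Scale factor that brings the shorter side to target_size,
    reduced if the longer side would exceed max_size."""
    short_side = min(height, width)
    long_side = max(height, width)
    scale = float(target_size) / short_side
    if round(scale * long_side) > max_size:
        scale = float(max_size) / long_side
    return scale

def sample_training_scale(height, width, scales, max_size, rng=random):
    """Pick one target size at random and return it with its scale factor."""
    target = rng.choice(scales)
    return target, image_scale(height, width, target, max_size)

scales = [400, 500, 600, 700]
rng = random.Random(0)
for _ in range(3):
    print(sample_training_scale(375, 500, scales, max_size=1000, rng=rng))
```

With several target sizes the same object is seen at different resolutions across iterations, which is the scale-invariance effect described above. It is not required for training to work.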

@daf11865
Author

@majestix89
I think I found the problem, which is beyond my understanding.
Are you using the snapshot approach from rbgirshick/fast-rcnn#35?
If I use the snapshot model that is written under the .../py_faster_rcnn directory, detection fails.
But using the model from .../ py_faster_rcnn/output/.../xxx.caffemodel makes it succeed. Very strange, right?
So I guess that when resuming training, the caffemodel under .../ py_faster_rcnn/output should be used instead of the one under .../py_faster_rcnn.

@majestix89

@daf11865
Short update from my side.
Finally I managed to get proper results with the end2end training method. I had been training VGG16 and didn't get proper results after about 150000 iterations.
I have now switched to the ZF network instead. The results were much better on my training set. (BBOX regression did its job.)
I'm glad to have good results, but when I compare the four losses of ZF and VGG16, they are almost the same.
So what I don't get at the moment is the connection between the losses and the final results. Intuitively I would have guessed they are related, but apparently they aren't.
Both of them vary around 0.15.

@ksaluja15

ksaluja15 commented Jun 23, 2016

I experienced the same error. Turning off the __C.TEST.BBOX_REG variable in lib/fast_rcnn/config.py corrected the issue. No clue why the bounding-box regressor is not learning anything.

@frankmanbb

I have the same problem, similar to #395

@tony5614

tony5614 commented Feb 24, 2017

I have a similar problem,
and it comes along with this warning during training: "RuntimeWarning: invalid value encountered in log"

It turns out that a bounding box in my training data starts from 0:

      <bndbox>
           <xmin>0</xmin>
           <ymin>42</ymin>
           <xmax>244</xmax>
           <ymax>236</ymax>
      </bndbox>
 </object>

That's because when pascal_voc.py reads the XML annotation, it subtracts 1 from the bounding-box coordinates by default, so a coordinate of 0 becomes negative, which causes the warning "RuntimeWarning: invalid value encountered in log".
[image]
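
A sketch of how a coordinate of 0 can end up as a NaN regression target, using the annotation above. It assumes the shifted boxes are stored in an unsigned 16-bit array (which I believe pascal_voc.py does, but treat that as an assumption). The anchor width is a made-up value.

```python
import numpy as np

# xmin, ymin, xmax, ymax from the annotation, shifted by -1 as described.
shifted = np.array([0, 42, 244, 236], dtype=np.int64) - 1   # [-1, 41, 243, 235]

# Assumption: stored as uint16, so the -1 wraps around to 65535.
stored = shifted.astype(np.uint16)
print(stored)

x1, x2 = float(stored[0]), float(stored[2])
gt_width = x2 - x1 + 1.0           # 243 - 65535 + 1, a large negative width
print(gt_width)

# The width target is a log ratio; a negative ratio gives NaN plus the warning.
anchor_width = 128.0               # hypothetical proposal width
dw_target = np.log(np.array([gt_width / anchor_width]))
print(dw_target)

def starts_at_zero(xmin, ymin):
    """Flag annotations that would go negative after the -1 shift."""
    return xmin < 1 or ymin < 1

print(starts_at_zero(0, 42))
```

A NaN target contaminates the regression loss, so checking annotations with something like `starts_at_zero` before training is cheaper than debugging the trained model afterwards.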

@youye115

youye115 commented Mar 2, 2017

Is this setting in your fast_rcnn/config.py turned off?

# Test using bounding-box regressors
__C.TEST.BBOX_REG = False

It should be turned off at test time.

@boyflytobeman

Replace this code in demo.py:
#caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
#                         NETS[args.demo_net][1])

with
caffemodel = os.path.join('output', 'faster_rcnn_end2end', 'voc_2007_trainval',
                          NETS[args.demo_net][1])
Then the detection will be all right, but I do not know why.

@mostafa-saad

Hello guys,

I also suffered from the __C.TEST.BBOX_REG = False story. To get things working correctly without losing the box regression, I followed the advice in this link: https://huangying-zhan.github.io/2016/09/22/detection-faster-rcnn.html

Specifically:
1- You shouldn't use the old bbox_pred, as it is already tuned for an old dataset.
2- But renaming it will also cause this problem, as it seems faster-rcnn depends on this exact layer name.
3- To solve this: A) rename it, run training for 0 iterations, rename it back to the original name, and tune this new version. Better to read the whole section of the linked page: "3.2. Prepare network and pre-trained model"

There is something I did not try, as I have no time for it: what if we directly use the trained model without the renaming story? In my opinion, which I did not verify, if your # of labels != 21, then things may go OK. Otherwise, the tuned parameters for the 21 labels will be used as they are, which will give wrong results. Probably many people did not notice this problem because they did not try the renaming step.
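
The dependence on the number of labels mentioned above can be made concrete with the output-layer shapes. This is only a sketch: the helper names are hypothetical, the 4096 input size is taken from the fc7 layer in the prototxt earlier in the thread, and the class count includes background.

```python
def head_shapes(num_classes, feature_dim=4096):
    """Weight shapes of the two output layers for a given class count."""
    return {
        "cls_score": (num_classes, feature_dim),      # one score per class
        "bbox_pred": (4 * num_classes, feature_dim),  # four deltas per class
    }

def reusable_layers(pretrained_shapes, num_classes, feature_dim=4096):
    """Names of output layers whose pretrained shape matches the new net."""
    wanted = head_shapes(num_classes, feature_dim)
    return sorted(name for name, shape in wanted.items()
                  if pretrained_shapes.get(name) == shape)

voc_detector = head_shapes(21)   # 21 classes -> bbox_pred has 84 outputs
print(voc_detector["bbox_pred"])
print(reusable_layers(voc_detector, 21))  # same class count: shapes match
print(reusable_layers(voc_detector, 6))   # different class count: nothing matches
print(reusable_layers({}, 21))            # model without these layers
```

With 21 classes the old weights fit and would be picked up as they are. With any other count the shapes disagree, so the old layer cannot be reused under the same name.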

@ds2268

ds2268 commented Nov 7, 2017

2- But, Renaming it too will cause this problem as seems faster-rcnn depends on this exact layer name (@mostafa-saad)

It's because of the snapshotting code at the end of this comment, which denormalizes the learned bbox offsets and looks up the layer by the name 'bbox_pred'.

I think the problem is either:

  1. you renamed the bbox_pred layer to something else during training
  2. some incorrect resuming/fine-tuning from pre-existing models trained with faster-rcnn (not ImageNet)

If you use the default 'bbox_pred' layer with a changed or even the same number of classes, it should work.

@mostafa-saad : Otherwise, the tuned parameters for the 21 labels will be used as they are, which will give wrong results.

Answer: fine-tuning by default starts from ImageNet models, which don't contain a 'bbox_pred' layer, so it is not affected. Fine-tuning from models trained with faster-rcnn won't work, as the blob dimensions won't agree (if the number of classes differs).

def snapshot(self):
    """Take a snapshot of the network after unnormalizing the learned
    bounding-box regression weights. This enables easy use at test-time.
    """
    net = self.solver.net

    scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                         cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                         net.params.has_key('bbox_pred'))

    if scale_bbox_params:
        # save original values
        orig_0 = net.params['bbox_pred'][0].data.copy()
        orig_1 = net.params['bbox_pred'][1].data.copy()

        # scale and shift with bbox reg unnormalization; then save snapshot
        net.params['bbox_pred'][0].data[...] = \
                (net.params['bbox_pred'][0].data *
                 self.bbox_stds[:, np.newaxis])
        net.params['bbox_pred'][1].data[...] = \
                (net.params['bbox_pred'][1].data *
                 self.bbox_stds + self.bbox_means)

    infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
             if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
    filename = (self.solver_param.snapshot_prefix + infix +
                '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
    filename = os.path.join(self.output_dir, filename)

    net.save(str(filename))
    print 'Wrote snapshot to: {:s}'.format(filename)

    if scale_bbox_params:
        # restore net to original state
        net.params['bbox_pred'][0].data[...] = orig_0
        net.params['bbox_pred'][1].data[...] = orig_1
    return filename
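
A quick numerical check of what the scale-and-shift in `snapshot` achieves. The sizes are tiny and hypothetical, and the standard deviations (0.1, 0.1, 0.2, 0.2) with zero means are assumed values for illustration, not read from any config here.

```python
import numpy as np

rng = np.random.RandomState(0)
num_classes, feat_dim = 3, 5   # hypothetical tiny sizes

# One (dx, dy, dw, dh) group per class, so the statistics are tiled.
stds = np.tile(np.array([0.1, 0.1, 0.2, 0.2]), num_classes)
means = np.tile(np.array([0.0, 0.0, 0.0, 0.0]), num_classes)

W = rng.randn(4 * num_classes, feat_dim)   # trained weights (normalized targets)
b = rng.randn(4 * num_classes)             # trained biases
x = rng.randn(feat_dim)                    # one fc7 feature vector

# What the net predicts during training, and what test code needs.
normalized_out = W.dot(x) + b
expected = normalized_out * stds + means

# Same transformation as in snapshot(): fold stds and means into the layer.
W_snap = W * stds[:, np.newaxis]
b_snap = b * stds + means
snapshot_out = W_snap.dot(x) + b_snap

print(np.allclose(snapshot_out, expected))
print(np.allclose(normalized_out, expected))
```

The folded parameters give the unnormalized deltas directly, while the raw training-time parameters do not. A model file saved without this step (or with the layer renamed so that the `'bbox_pred'` lookup fails) would therefore produce deltas on the wrong scale at test time.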

@zoufangyu1987

@mostafa-saad
Thank you for helping me a lot.
It worked this way; I followed the advice in this link: https://huangying-zhan.github.io/2016/09/22/detection-faster-rcnn.html
