argmaxnet

An experiment in training a fully connected residual net to learn the argmax function, i.e., to map a vector of scores to the index of its largest element.
This seems surprisingly difficult for such a simple problem.


Experiment:

  • Python 3.7
  • PyTorch 0.5.0a0+f3ac619
  • Stacked residual net built from identical residual modules, followed by a linear reduction to a single output (see the sketch after this list)
  • Each residual module has a wide hidden layer, with equal input and output dimensions
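
The layer sizes aren't spelled out here, but the Block(10,60) signature and the parameter count of 3821 printed in the training log below are consistent with three residual blocks of shape 10 → 60 → 10 followed by a Linear(10, 1) head. A minimal PyTorch sketch under those assumptions (the activation choice and all names are guesses, not the repository's actual code):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """One residual module: wide hidden layer, input dim == output dim."""
    def __init__(self, dim=10, hidden=60):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(),               # activation choice is an assumption
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return x + self.net(x)       # identity skip connection

class ArgMaxNet(nn.Module):
    """StackDepth identical residual blocks, then a linear reduction to one output."""
    def __init__(self, dim=10, hidden=60, stack_depth=3):
        super().__init__()
        self.blocks = nn.Sequential(
            *[ResidualBlock(dim, hidden) for _ in range(stack_depth)]
        )
        self.head = nn.Linear(dim, 1)

    def forward(self, x):
        return self.head(self.blocks(x)).squeeze(-1)

model = ArgMaxNet()
print(sum(p.numel() for p in model.parameters()))  # 3821, matching the log below
```

Each block contributes (10·60 + 60) + (60·10 + 10) = 1270 parameters and the head adds 11, so 3·1270 + 11 = 3821, which matches the printed parameter count.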

Training:
➜  argmaxnet git:(master) ✗ python3 train.py
ArgMaxNet : StackDepth=3, Block(10,60)
            opt: SGD(lr=0.0001 reduced by 0.98 each episode, momentum=0.9, weight_decay=0)
Parameter count = 3821
[    1] Loss: 0.841379   Err: 0.210000  elapsed:  0.009590 hours, lr=0.00010000
[    2] Loss: 1.009749   Err: 0.158800  elapsed:  0.019381 hours, lr=0.00009800
[    3] Loss: 0.673255   Err: 0.143200  elapsed:  0.029359 hours, lr=0.00009604
[    4] Loss: 0.400579   Err: 0.085800  elapsed:  0.039462 hours, lr=0.00009412
[    5] Loss: 0.522911   Err: 0.093000  elapsed:  0.049320 hours, lr=0.00009224
[    6] Loss: 0.390506   Err: 0.077000  elapsed:  0.059192 hours, lr=0.00009039
[    7] Loss: 0.352900   Err: 0.068200  elapsed:  0.069052 hours, lr=0.00008858
[    8] Loss: 0.293024   Err: 0.069600  elapsed:  0.078911 hours, lr=0.00008681
[    9] Loss: 0.308593   Err: 0.073800  elapsed:  0.088776 hours, lr=0.00008508
[   10] Loss: 0.485971   Err: 0.069800  elapsed:  0.098606 hours, lr=0.00008337
[   11] Loss: 0.503779   Err: 0.065200  elapsed:  0.108481 hours, lr=0.00008171
[   12] Loss: 0.277030   Err: 0.060000  elapsed:  0.118357 hours, lr=0.00008007
[   13] Loss: 0.223894   Err: 0.055800  elapsed:  0.128042 hours, lr=0.00007847
[   14] Loss: 0.266146   Err: 0.059600  elapsed:  0.137845 hours, lr=0.00007690
[   15] Loss: 0.316010   Err: 0.071800  elapsed:  0.147688 hours, lr=0.00007536
[   16] Loss: 0.285165   Err: 0.050400  elapsed:  0.157558 hours, lr=0.00007386
[   17] Loss: 0.392230   Err: 0.070600  elapsed:  0.167250 hours, lr=0.00007238
[   18] Loss: 0.246672   Err: 0.051400  elapsed:  0.176823 hours, lr=0.00007093
[   19] Loss: 0.209529   Err: 0.047200  elapsed:  0.186649 hours, lr=0.00006951
[   20] Loss: 0.183141   Err: 0.038200  elapsed:  0.196483 hours, lr=0.00006812
[   21] Loss: 0.212652   Err: 0.046800  elapsed:  0.206334 hours, lr=0.00006676
[   22] Loss: 0.276476   Err: 0.050800  elapsed:  0.216203 hours, lr=0.00006543
[   23] Loss: 0.289849   Err: 0.050800  elapsed:  0.225948 hours, lr=0.00006412
[   24] Loss: 0.275776   Err: 0.051600  elapsed:  0.235879 hours, lr=0.00006283
[   25] Loss: 0.200512   Err: 0.044600  elapsed:  0.245600 hours, lr=0.00006158
[   26] Loss: 0.194553   Err: 0.041800  elapsed:  0.255548 hours, lr=0.00006035
[   27] Loss: 0.169269   Err: 0.045200  elapsed:  0.265598 hours, lr=0.00005914
[   28] Loss: 0.273164   Err: 0.045200  elapsed:  0.275583 hours, lr=0.00005796
[   29] Loss: 0.272778   Err: 0.127400  elapsed:  0.285362 hours, lr=0.00005680
[   30] Loss: 0.174689   Err: 0.050200  elapsed:  0.295156 hours, lr=0.00005566
[   31] Loss: 0.285349   Err: 0.055200  elapsed:  0.304912 hours, lr=0.00005455
[   32] Loss: 0.300190   Err: 0.041400  elapsed:  0.314661 hours, lr=0.00005346
[   33] Loss: 0.202204   Err: 0.045000  elapsed:  0.324503 hours, lr=0.00005239
[   34] Loss: 0.198700   Err: 0.040200  elapsed:  0.334367 hours, lr=0.00005134
[   35] Loss: 0.202726   Err: 0.038800  elapsed:  0.344198 hours, lr=0.00005031
[   36] Loss: 0.227932   Err: 0.045600  elapsed:  0.354049 hours, lr=0.00004931
[   37] Loss: 0.168154   Err: 0.041200  elapsed:  0.363848 hours, lr=0.00004832
[   38] Loss: 0.244279   Err: 0.043800  elapsed:  0.373667 hours, lr=0.00004735
[   39] Loss: 0.192497   Err: 0.039800  elapsed:  0.383468 hours, lr=0.00004641
[   40] Loss: 0.175515   Err: 0.040000  elapsed:  0.393374 hours, lr=0.00004548
[   41] Loss: 0.132442   Err: 0.036400  elapsed:  0.403231 hours, lr=0.00004457
[   42] Loss: 0.131879   Err: 0.032400  elapsed:  0.413111 hours, lr=0.00004368
[   43] Loss: 0.239974   Err: 0.036400  elapsed:  0.422981 hours, lr=0.00004281
[   44] Loss: 0.146880   Err: 0.034800  elapsed:  0.432853 hours, lr=0.00004195
[   45] Loss: 0.116312   Err: 0.033400  elapsed:  0.442684 hours, lr=0.00004111
[   46] Loss: 0.214501   Err: 0.035400  elapsed:  0.452548 hours, lr=0.00004029
[   47] Loss: 0.277160   Err: 0.033400  elapsed:  0.462417 hours, lr=0.00003948
[   48] Loss: 0.178411   Err: 0.034000  elapsed:  0.472273 hours, lr=0.00003869
[   49] Loss: 0.246785   Err: 0.039400  elapsed:  0.482137 hours, lr=0.00003792
[   50] Loss: 0.231782   Err: 0.038800  elapsed:  0.492130 hours, lr=0.00003716
[   51] Loss: 0.157200   Err: 0.036600  elapsed:  0.502017 hours, lr=0.00003642
[   52] Loss: 0.259224   Err: 0.035000  elapsed:  0.511886 hours, lr=0.00003569
[   53] Loss: 0.262706   Err: 0.042200  elapsed:  0.521773 hours, lr=0.00003497
[   54] Loss: 0.169639   Err: 0.029600  elapsed:  0.531676 hours, lr=0.00003428
[   55] Loss: 0.280070   Err: 0.039000  elapsed:  0.541571 hours, lr=0.00003359
[   56] Loss: 0.151866   Err: 0.029200  elapsed:  0.551473 hours, lr=0.00003292
[   57] Loss: 0.201496   Err: 0.038000  elapsed:  0.561332 hours, lr=0.00003226
[   58] Loss: 0.125637   Err: 0.031600  elapsed:  0.571196 hours, lr=0.00003161
[   59] Loss: 0.218838   Err: 0.032000  elapsed:  0.581075 hours, lr=0.00003098
[   60] Loss: 0.119034   Err: 0.028800  elapsed:  0.590925 hours, lr=0.00003036
[   61] Loss: 0.087370   Err: 0.027000  elapsed:  0.600750 hours, lr=0.00002976
[   62] Loss: 0.105404   Err: 0.029600  elapsed:  0.610550 hours, lr=0.00002916
[   63] Loss: 0.202022   Err: 0.026000  elapsed:  0.620337 hours, lr=0.00002858
[   64] Loss: 0.114809   Err: 0.028200  elapsed:  0.630182 hours, lr=0.00002801
[   65] Loss: 0.156818   Err: 0.026400  elapsed:  0.639974 hours, lr=0.00002745
[   66] Loss: 0.117886   Err: 0.027400  elapsed:  0.649795 hours, lr=0.00002690
[   67] Loss: 0.148011   Err: 0.027200  elapsed:  0.659691 hours, lr=0.00002636
[   68] Loss: 0.185504   Err: 0.032400  elapsed:  0.669347 hours, lr=0.00002583
[   69] Loss: 0.158207   Err: 0.026200  elapsed:  0.678770 hours, lr=0.00002531
[   70] Loss: 0.211314   Err: 0.036000  elapsed:  0.688219 hours, lr=0.00002481
[   71] Loss: 0.123811   Err: 0.029000  elapsed:  0.697775 hours, lr=0.00002431
[   72] Loss: 0.121323   Err: 0.024200  elapsed:  0.707256 hours, lr=0.00002383
[   73] Loss: 0.092815   Err: 0.025200  elapsed:  0.716782 hours, lr=0.00002335
[   74] Loss: 0.077281   Err: 0.025600  elapsed:  0.726131 hours, lr=0.00002288
[   75] Loss: 0.127059   Err: 0.026400  elapsed:  0.735463 hours, lr=0.00002242
[   76] Loss: 0.135494   Err: 0.023200  elapsed:  0.744721 hours, lr=0.00002198
[   77] Loss: 0.095448   Err: 0.023400  elapsed:  0.753993 hours, lr=0.00002154
[   78] Loss: 0.137425   Err: 0.027600  elapsed:  0.763382 hours, lr=0.00002111
[   79] Loss: 0.105115   Err: 0.023200  elapsed:  0.772932 hours, lr=0.00002068
[   80] Loss: 0.135414   Err: 0.024000  elapsed:  0.782717 hours, lr=0.00002027
[   81] Loss: 0.146738   Err: 0.027800  elapsed:  0.792537 hours, lr=0.00001986
[   82] Loss: 0.109915   Err: 0.023200  elapsed:  0.802295 hours, lr=0.00001947
[   83] Loss: 0.100822   Err: 0.023200  elapsed:  0.812065 hours, lr=0.00001908
[   84] Loss: 0.129659   Err: 0.024000  elapsed:  0.821850 hours, lr=0.00001870
[   85] Loss: 0.117406   Err: 0.026800  elapsed:  0.831664 hours, lr=0.00001832
[   86] Loss: 0.097301   Err: 0.024600  elapsed:  0.841496 hours, lr=0.00001796
[   87] Loss: 0.088146   Err: 0.020000  elapsed:  0.851277 hours, lr=0.00001760
[   88] Loss: 0.108722   Err: 0.020400  elapsed:  0.861119 hours, lr=0.00001725
[   89] Loss: 0.105350   Err: 0.024800  elapsed:  0.870881 hours, lr=0.00001690
[   90] Loss: 0.140298   Err: 0.023400  elapsed:  0.880625 hours, lr=0.00001656
[   91] Loss: 0.081200   Err: 0.021400  elapsed:  0.890413 hours, lr=0.00001623
[   92] Loss: 0.114371   Err: 0.021400  elapsed:  0.900247 hours, lr=0.00001591
[   93] Loss: 0.206781   Err: 0.024200  elapsed:  0.910042 hours, lr=0.00001559
[   94] Loss: 0.171771   Err: 0.022200  elapsed:  0.919865 hours, lr=0.00001528
[   95] Loss: 0.120925   Err: 0.022200  elapsed:  0.929604 hours, lr=0.00001497
[   96] Loss: 0.087923   Err: 0.022400  elapsed:  0.939338 hours, lr=0.00001467
[   97] Loss: 0.081036   Err: 0.023000  elapsed:  0.949043 hours, lr=0.00001438
[   98] Loss: 0.105851   Err: 0.024000  elapsed:  0.958831 hours, lr=0.00001409
[   99] Loss: 0.116960   Err: 0.026200  elapsed:  0.968549 hours, lr=0.00001381
[  100] Loss: 0.065967   Err: 0.017200  elapsed:  0.978295 hours, lr=0.00001353
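
For context, here is a hedged sketch of a training loop consistent with the header above (SGD at lr=1e-4 decayed by 0.98 per episode, momentum 0.9, no weight decay). The loss function, the rounding-based error metric, the input distribution, and the batch counts are all assumptions, not taken from train.py:

```python
import torch
import torch.nn.functional as F

dim, batch_size, batches_per_episode = 10, 500, 1000  # batch counts are assumptions
model = ArgMaxNet()                                   # from the sketch above
opt = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=0)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.98)  # lr *= 0.98 per episode

for episode in range(1, 101):
    for _ in range(batches_per_episode):
        x = torch.rand(batch_size, dim)        # random scores; distribution is assumed
        target = x.argmax(dim=1).float()       # index of the max, as a regression target
        pred = model(x)
        loss = F.mse_loss(pred, target)        # loss choice is an assumption
        opt.zero_grad()
        loss.backward()
        opt.step()
    sched.step()                               # decay the learning rate once per episode
    err = (pred.round() != target).float().mean()  # error: rounded prediction misses index
    print(f"[{episode:5d}] Loss: {loss.item():.6f}   Err: {err.item():.6f}   "
          f"lr={opt.param_groups[0]['lr']:.8f}")
```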

Test:

➜  argmaxnet git:(master) ✗ python3 test.py
Testing 100 batches of 500 random scores ...
Loss: 0.103573   Err: 0.018920  elapsed: 0.001009 hours  perInstance: 0.072660 ms
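
And a matching evaluation sketch under the same assumptions (uniform random scores, MSE loss, rounding-based error), mirroring the 100 batches of 500 reported above:

```python
import time
import torch
import torch.nn.functional as F

model.eval()
n_batches, batch_size, dim = 100, 500, 10
total_loss = total_err = 0.0
start = time.time()
with torch.no_grad():
    for _ in range(n_batches):
        x = torch.rand(batch_size, dim)
        target = x.argmax(dim=1).float()
        pred = model(x)
        total_loss += F.mse_loss(pred, target).item()
        total_err += (pred.round() != target).float().mean().item()
elapsed = time.time() - start
print(f"Loss: {total_loss / n_batches:.6f}   Err: {total_err / n_batches:.6f}   "
      f"perInstance: {elapsed * 1000 / (n_batches * batch_size):.6f} ms")
```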
