From de72a3c93466204bfb7eabe98843e793d9097e37 Mon Sep 17 00:00:00 2001
From: nicolai256 <68881396+nicolai256@users.noreply.github.com>
Date: Mon, 17 Oct 2022 04:18:24 +0200
Subject: [PATCH] Add files via upload

---
 1 install guide.docx |  Bin 0 -> 13736 bytes
 deflicker.py         |  116 ++++++++++
 generate.py          |   26 ++-
 libdeflicker.py      |  155 +++++++++++++
 logger1.py           |   36 +++
 train.py             |   27 ++-
 train1.py            |  132 +++++++++++
 train2.py            |  132 +++++++++++
 trainers.py          |  527 ++++++++++++++++++++++---------------------
 trainers1.py         |  267 ++++++++++++++++++++++
 10 files changed, 1139 insertions(+), 279 deletions(-)
 create mode 100644 1 install guide.docx
 create mode 100644 deflicker.py
 create mode 100644 libdeflicker.py
 create mode 100644 logger1.py
 create mode 100644 train1.py
 create mode 100644 train2.py
 create mode 100644 trainers1.py

diff --git a/1 install guide.docx b/1 install guide.docx
new file mode 100644
index 0000000000000000000000000000000000000000..fbc805213577473370337a5e3a24aeb0cc82917d
GIT binary patch
literal 13736
[base85-encoded binary payload omitted]
diff --git a/deflicker.py b/deflicker.py
new file mode 100644
--- /dev/null
+++ b/deflicker.py
@@ -0,0 +1,116 @@
+"""
+deflicker.py
+------------
+Deflicker a sequence of images by smoothing their mean RGB values in time.
+
+Usage: ``python deflicker.py <dir> <w> [options]``. ``<dir>`` should specify a path to a folder that contains
+the image files that are to be deflickered. The image names must contain
+numbers somewhere, and the images will be included in the timeseries in ascending
+numerical order. ``<w>`` specifies the width (in images) of the square filter
+used to smooth the image values. Other options include
+  ``--plot <file>``:
+    do not output images with adjusted means; instead, save a plot
+    of the RGB timeseries before and after smoothing to a PNG image in
+    ``<file>``. If ``<file>`` already exists, it may be overwritten.
+  ``--outdir <dir>``:
+    output images with adjusted means in the directory specified by
+    ``<dir>``. If the directory is the same as the input directory, the
+    smoothing is done in-place and the input files are overwritten.
+
+.. moduleauthor:: Tristan Abbott
+"""
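+# Example invocations (added for clarity; the folder names and filter width
+# below are hypothetical, not part of the original script):
+#
+#   python deflicker.py ./frames 9 --outdir ./frames_smoothed
+#   python deflicker.py ./frames 9 --plot rgb_curves.png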
+
+from libdeflicker import meanRGB, squareFilter, relaxToMean, toIntColor
+import os
+import re
+import sys
+from PIL import Image
+from matplotlib import pyplot as plt
+import numpy as np
+
+if __name__ == "__main__":
+
+    # Process input arguments
+    if len(sys.argv) < 3:
+        print ('Usage: python deflicker.py <dir> <w> [options]')
+        exit(0)
+    loc = sys.argv[1]
+    w = int(sys.argv[2])
+    __plot = False
+    __outdir = False
+
+    for ii in range(3, len(sys.argv)):
+        a = sys.argv[ii]
+        if a == '--plot':
+            __plot = True
+            __file = sys.argv[ii+1]
+        elif a == '--outdir':
+            __outdir = True
+            __output = sys.argv[ii+1]
+
+    # Just stop if not told to do anything
+    if not (__plot or __outdir):
+        print ('Exiting without doing anything')
+        exit(0)
+
+    # Get list of image names in ascending numerical order; drop files
+    # whose names contain no digits
+    f = os.listdir(loc)
+    n = []
+    ii = 0
+    while ii < len(f):
+        match = re.search(r'\d+', f[ii])
+        if match is not None:
+            n.append(int(match.group(0)))
+            ii += 1
+        else:
+            f.pop(ii)
+    n = np.array(n)
+    i = np.argsort(n)
+    f = [f[ii] for ii in i]
+
+    # Load images and calculate smoothed RGB curves
+    print ('Calculating smoothed sequence')
+    n = len(f)
+    rgb = np.zeros((n, 3))
+    ii = 0
+    for ff in f:
+        img = np.asarray(Image.open('%s/%s' % (loc, ff))) / 255.
+        rgb[ii,:] = meanRGB(img)
+        ii += 1
+
+    # Filter series
+    rgbi = np.zeros(rgb.shape)
+    for ii in range(0,3):
+        rgbi[:,ii] = squareFilter(rgb[:,ii], w)
+
+    # Plot initial and filtered series
+    if __plot:
+        print ('Plotting smoothed and unsmoothed sequences in %s' % __file)
+        plt.subplot(1, 2, 1)
+        plt.plot(rgb[:,0], 'r', rgb[:,1], 'g', rgb[:,2], 'b')
+        plt.title('Unfiltered RGB sequence')
+        plt.subplot(1, 2, 2)
+        plt.plot(rgbi[:,0], 'r', rgbi[:,1], 'g', rgbi[:,2], 'b')
+        plt.title('Filtered RGB sequence (w = %d)' % w)
+        plt.savefig(__file)
+
+    # Process images sequentially
+    if __outdir:
+        print ('Processing images')
+        ii = 0
+        for ff in f:
+            img = np.asarray(Image.open('%s/%s' % (loc, ff))) / 255.
+            relaxToMean(img, rgbi[ii,:])
+            jpg = Image.fromarray(toIntColor(img))
+            jpg.save('%s/%s' % (__output, ff))
+            ii += 1
+
+    print ('Finished')
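Note on the frame ordering above: frames are sorted by the first number embedded
in each file name, not lexicographically, and files without digits are skipped.
A minimal standalone sketch of the same idea (hypothetical file names, not part
of the patch):

    import re

    names = ['frame10.png', 'frame2.png', 'notes.txt', 'frame1.png']
    # keep only names that contain digits, mirroring deflicker.py's loop
    frames = [f for f in names if re.search(r'\d+', f)]
    # sort by the first embedded number, so frame10 comes after frame2
    frames.sort(key=lambda f: int(re.search(r'\d+', f).group(0)))
    print(frames)  # ['frame1.png', 'frame2.png', 'frame10.png']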
diff --git a/generate.py b/generate.py
index 6762856..ce8e1e4 100644
--- a/generate.py
+++ b/generate.py
@@ -13,17 +13,25 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--checkpoint", help="checkpoint location", required=True)
-    parser.add_argument("--data_root", help="data root", required=True)
-    parser.add_argument("--dir_input", help="dir input", required=True)
+    #parser.add_argument("--data_root", help="data root", required=False)
+    #parser.add_argument("--dir_input", help="dir input", required=False)
     parser.add_argument("--dir_x1", help="dir extra 1", required=False)
     parser.add_argument("--dir_x2", help="dir extra 2", required=False)
     parser.add_argument("--dir_x3", help="dir extra 3", required=False)
     parser.add_argument("--outdir", help="output directory", required=True)
     parser.add_argument("--device", help="device", required=True)
+    parser.add_argument("--channels", help="use 1 if you did not run tools_all.py, 2 if you did", required=True)
+    parser.add_argument('--projectname', type=str, help='name of the project', required=True)
     args = parser.parse_args()
-
-    generator = (torch.load(args.checkpoint, map_location=lambda storage, loc: storage))
+
+    data_path = os.path.expanduser('~\\Documents\\visionsofchaos\\fewshot\\data')
+    data_root = data_path + "\\" + args.projectname + "_gen"
+    dir_input = "input_filtered"
+    checkpoint = data_path + "\\" + args.projectname + "_train" + "\\" + "logs_reference_P" + "\\" + args.checkpoint
+
+    generator = (torch.load(checkpoint, map_location=lambda storage, loc: storage))
     generator.eval()
+
     if not os.path.exists(args.outdir):
         os.mkdir(args.outdir)
@@ -35,10 +43,10 @@
     if device.lower() != "cpu":
         generator = generator.type(torch.half)

     transform = build_transform()
-    dataset = DatasetFullImages(args.data_root + "/" + args.dir_input, "ignore", "ignore", device,
-                                dir_x1=args.data_root + "/" + args.dir_x1 if args.dir_x1 is not None else None,
-                                dir_x2=args.data_root + "/" + args.dir_x2 if args.dir_x2 is not None else None,
-                                dir_x3=args.data_root + "/" + args.dir_x3 if args.dir_x3 is not None else None,
+    dataset = DatasetFullImages(data_root + "/" + dir_input, "ignore", "ignore", device,
+                                dir_x1=data_root + "/" + args.dir_x1 if args.dir_x1 is not None else None,
+                                dir_x2=data_root + "/" + args.dir_x2 if args.dir_x2 is not None else None,
+                                dir_x3=data_root + "/" + args.dir_x3 if args.dir_x3 is not None else None,
                                 dir_x4=None, dir_x5=None, dir_x6=None, dir_x7=None, dir_x8=None, dir_x9=None)
     imloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, drop_last=False)  # num_workers=4
@@ -56,7 +64,7 @@
         #image_space_in = to_image_space(batch['image'].cpu().data.numpy())
         #image_space = to_image_space(net_out.cpu().data.numpy())
-        image_space = ((net_out.clamp(-1, 1) + 1) * 127.5).permute((0, 2, 3, 1))
+        image_space = ((net_out.clamp(-1, 1) + 1) * 127.5).permute((0, int(args.channels), 3, 1))
         image_space = image_space.cpu().data.numpy().astype(np.uint8)

         for k in range(0, len(image_space)):
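Note on the permute change above: the original (0, 2, 3, 1) permute converts the
generator output from NCHW to NHWC for image export, and passing --channels as a
permute axis only reproduces that when --channels is 2, so other values deserve
care. A standalone sketch of the standard conversion, assuming a 3-channel
generator output (illustrative, not part of the patch):

    import torch

    net_out = torch.randn(1, 3, 64, 64)         # batch, channels, height, width
    img = (net_out.clamp(-1, 1) + 1) * 127.5    # map [-1, 1] onto [0, 255]
    img = img.permute(0, 2, 3, 1)               # NCHW -> NHWC
    print(img.shape)                            # torch.Size([1, 64, 64, 3])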
diff --git a/libdeflicker.py b/libdeflicker.py
new file mode 100644
index 0000000..77563d5
--- /dev/null
+++ b/libdeflicker.py
@@ -0,0 +1,155 @@
+"""
+libdeflicker.py
+---------------
+Library routines for image deflickering.
+
+.. moduleauthor:: Tristan Abbott
+"""
+
+import numpy as np
+from scipy import signal
+
+def squareFilter(sig, w):
+    """
+    squareFilter(sig, w)
+    --------------------
+    Smooth a signal with a square filter.
+
+    This function is just a wrapper for scipy.signal.convolve with a kernel
+    given by ``np.ones(w)/w``; the input is padded at both ends with its edge
+    values so the smoothed signal has the same length as the input.
+
+    Parameters:
+        sig: np.array
+            Unsmoothed signal
+        w: int
+            Width of the filter
+
+    Returns:
+        np.array
+            Smoothed signal
+    """
+    # Create filter
+    win = np.ones(w)
+    # Pad input
+    sigp = np.concatenate(([np.tile(sig[0], w//2), sig,
+        np.tile(sig[-1], w//2)]))
+    # Filter
+    return signal.convolve(sigp, win, mode =
+        'same')[w//2:-w//2+1] / np.sum(win)
+
+# Compute image-mean RGB values
+def meanRGB(img, ii = -1):
+    """
+    meanRGB(img, ii = -1)
+    ---------------------
+    Compute image-mean RGB values.
+
+    This function takes an np.array representation of an image (x and y in the
+    first two dimensions and RGB values along the third dimension) and computes
+    the image-average R, G, and B values.
+
+    Parameters:
+        img: np.array
+            Array image representation. The first two dimensions should
+            represent pixel positions, and each position in the third dimension
+            can represent a particular pixel attribute, e.g. an R, G, or B
+            value; an H, S, or V value, etc.
+        ii: int, optional
+            Specify a slice of the third dimension to average over. If a
+            particular slice is specified, the function returns a scalar;
+            otherwise, it returns an average over each slice in the third
+            dimension of the input image. ``ii`` must be between ``0``
+            (inclusive) and ``img.shape[2]`` (exclusive).
+
+    Returns:
+        np.array
+            Average over the specified slice, if ``ii`` is given, or a 1D array
+            of averages over the first two dimensions for each slice in the
+            third dimension.
+    """
+    if ii < 0:
+        return np.array([np.mean(img[:,:,i]) for i in range(0,img.shape[2])])
+    else:
+        return np.mean(img[:,:,ii])
+
+# Adjust pixel-by-pixel RGB values to converge to correct mean
+# by multiplying them by a uniform value.
+def relaxToMean(img, rgb):
+    """
+    relaxToMean(img, rgb)
+    ---------------------
+    Uniformly adjust pixel-by-pixel attributes so their mean becomes a
+    specified value.
+
+    The adjustment is done by multiplying pixel attributes by a scaling factor
+    that is unique to the attribute but uniform over all the pixels in the
+    image. This function assumes that each
+    attribute is described by a floating point number between 0 and 1,
+    inclusive, and it will stop individual pixels from moving outside this range
+    while others are being scaled.
+
+    Parameters:
+        img: np.array
+            Array image representation. The first two dimensions should
+            represent pixel positions, and each position in the third dimension
+            can represent a particular pixel attribute, e.g. an R, G, or B
+            value; an H, S, or V value, etc.
+        rgb: np.array
+            Desired image-mean values for each attribute included in ``img``.
+            The linear indices of the values in this array should map in order
+            to the attributes in the third dimension of ``img``.
+
+    Returns:
+        None. ``img`` is modified in place: each attribute is multiplied by a
+        factor (unique to the attribute but the same for that attribute in
+        every pixel in the image) such that the image mean of that attribute
+        is as specified in ``rgb``.
+    """
+    rgbi = meanRGB(img)
+
+    # Relax toward mean
+    for ii in range(0,3):
+
+        # Repeat until converged to mean
+        while not np.isclose(rgbi[ii], rgb[ii]):
+
+            # Compute ratio
+            r = rgb[ii] / rgbi[ii]
+            # Relax image
+            img[:,:,ii] = np.minimum(1., img[:,:,ii] * r)
+            # Update average
+            rgbi[ii] = meanRGB(img, ii)
+
+# Convert floating point colors to integer colors
+def toIntColor(img, t = np.uint8):
+    """
+    toIntColor(img, t = np.uint8)
+    -----------------------------
+    Convert floating-point attributes to other types.
+
+    This function takes an image with floating-point [0,1] representations of
+    attributes and returns a near-equivalent image with attributes represented
+    by a different type. It does so by scaling the floating point attributes by
+    the maximum value representable by the new type and then converting the
+    scaled floating point value to the new type (with rounding, if required).
+
+    Parameters:
+        img: np.array
+            Array image representation. The first two dimensions should
+            represent pixel positions, and each position in the third dimension
+            can represent a particular pixel attribute, e.g. an R, G, or B
+            value; an H, S, or V value, etc. The attributes must be represented
+            as [0,1] floating point values.
+        t: type, optional
+            Type used to represent attributes in the new image. By default, the
+            type is an unsigned 8 bit integer (``np.uint8``).
+    Returns:
+        np.array(dtype = t)
+            Representation of the attributes of ``img`` using the type
+            specified by ``t``.
+    """
+    scale = np.iinfo(t).max
+    return np.round(img * scale).astype(t)
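Taken together, the routines above form the whole deflickering pipeline:
measure per-frame channel means, smooth each channel's curve over time, then
relax every frame toward its smoothed target. A compressed sketch of how they
compose (synthetic stand-in frames and a hypothetical filter width, not part of
the patch):

    import numpy as np
    from libdeflicker import meanRGB, squareFilter, relaxToMean, toIntColor

    frames = [np.random.rand(4, 4, 3) for _ in range(20)]  # stand-in images
    rgb = np.stack([meanRGB(img) for img in frames])       # per-frame means
    smooth = np.stack([squareFilter(rgb[:, c], 5) for c in range(3)], axis=1)
    for img, target in zip(frames, smooth):
        relaxToMean(img, target)                # in place, returns None
    out = [toIntColor(img) for img in frames]   # back to uint8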
diff --git a/logger1.py b/logger1.py
new file mode 100644
index 0000000..29833b4
--- /dev/null
+++ b/logger1.py
@@ -0,0 +1,36 @@
+import tensorflow as tf
+import os
+import shutil
+
+
+class Logger(object):
+    def __init__(self, log_dir, suffix=None):
+        """Create a summary writer logging to log_dir."""
+        self.writer = tf.summary.create_file_writer(log_dir, filename_suffix=suffix)
+
+    def scalar_summary(self, tag, value, step):
+        """Log a scalar variable."""
+        # TF2 summary API: write the scalar inside the writer's default context
+        with self.writer.as_default():
+            tf.summary.scalar(tag, value, step=step)
+        self.writer.flush()
+
+
+class ModelLogger(object):
+    def __init__(self, log_dir, save_func):
+        self.log_dir = log_dir
+        self.save_func = save_func
+
+    def save(self, model, epoch, isGenerator):
+        if isGenerator:
+            new_path = os.path.join(self.log_dir, "model_%05d.pth" % epoch)
+        else:
+            new_path = os.path.join(self.log_dir, "disc_%05d.pth" % epoch)
+        self.save_func(model, new_path)
+
+    def copy_file(self, source):
+        shutil.copy(source, self.log_dir)
+
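The --resume flag added to train.py below reloads a generator that ModelLogger
pickled as a whole module via torch.save. The underlying torch round-trip is
just the following (standalone sketch with a stand-in module and hypothetical
file name, not part of the patch):

    import torch
    import torch.nn as nn

    generator = nn.Linear(8, 8)               # stand-in for the real generator
    torch.save(generator, 'model_00042.pth')  # pickle the whole module
    # resume later, forcing tensors onto CPU storage while unpickling:
    generator = torch.load('model_00042.pth',
                           map_location=lambda storage, loc: storage)
    generator.eval()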
diff --git a/train.py b/train.py
index dc62b7e..ac7176d 100644
--- a/train.py
+++ b/train.py
@@ -37,12 +37,18 @@ def worker_init_fn(worker_id):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', '-c', help='Yaml config with training parameters', required=True)
-    parser.add_argument('--log_folder', '-l', help='Log folder', required=True)
-    parser.add_argument('--data_root', '-r', help='Data root folder', required=True)
+    parser.add_argument('--log_folder', '-l', help='Log folder', default="logs_reference_P")
+    parser.add_argument('--data_root', '-r', help='Data root folder', required=False)
     parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True)
+    parser.add_argument('--resume', '-rs', type=str, help='name of a .pth checkpoint in logs_reference_P to resume from (without extension)', required=False)
+    parser.add_argument('--projectname', type=str, help='name of the project', required=True)
     args = parser.parse_args()

-    args_log_folder = args.data_root + "/" + args.log_folder
+
+    data_path = os.path.expanduser('~\\Documents\\visionsofchaos\\fewshot\\data')
+    data_root = data_path + "\\" + args.projectname + "_train"
+    args_log_folder = data_root + "/" + "logs_reference_P"

     with open(args.config, 'r') as f:
         job_description = yaml.load(f, Loader=yaml.FullLoader)
@@ -60,24 +66,25 @@ def worker_init_fn(worker_id):
         raise RuntimeError("Got unexpected parameter in training_dataset: " + str(training_dataset_parameters))

     d = dict(config['training_dataset'])
-    d['dir_pre'] = args.data_root + "/" + d['dir_pre']
-    d['dir_post'] = args.data_root + "/" + d['dir_post']
+    d['dir_pre'] = data_root + "/" + d['dir_pre']
+    d['dir_post'] = data_root + "/" + d['dir_post']
     d['device'] = config['device']
     if 'dir_mask' in d:
-        d['dir_mask'] = args.data_root + "/" + d['dir_mask']
+        d['dir_mask'] = data_root + "/" + d['dir_mask']

     # complete dir_x paths and set a correct number of channels
     channels = 3
     for dir_x_index in range(1, 10):
         dir_x_name = f"dir_x{dir_x_index}"
-        d[dir_x_name] = args.data_root + "/" + d[dir_x_name] if dir_x_name in d else None
+        d[dir_x_name] = data_root + "/" + d[dir_x_name] if dir_x_name in d else None
         channels = channels + 3 if d[dir_x_name] is not None else channels
     config['generator']['args']['input_channels'] = channels

     print(d)
-
+    resumedata = str(args.resume)
     generator = build_model(config['generator']['type'], config['generator']['args'], device)
-    #generator = (torch.load(args.data_root + "/model_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device)
+    if args.resume:
+        generator = (torch.load(data_root + "/logs_reference_P/" + resumedata + ".pth", map_location=lambda storage, loc: storage)).to(device)
     opt_generator = build_optimizer(config['opt_generator']['type'], generator, config['opt_generator']['args'])

     discriminator, opt_discriminator = None, None
@@ -127,6 +134,6 @@ def worker_init_fn(worker_id):

     args_config = args.config.replace('\\', '/')
     args_config = args_config[args_config.rfind('/') + 1:]
-    trainer.train(generator, discriminator, int(config['trainer']['epochs']), args.data_root, args_config, 0)
+    trainer.train(generator, discriminator, int(config['trainer']['epochs']), data_root, args_config, 0)
     print("Training finished", flush=True)
     sys.exit(0)
diff --git a/train1.py b/train1.py
new file mode 100644
index 0000000..dc62b7e
--- /dev/null
+++ b/train1.py
@@ -0,0 +1,132 @@
+import argparse
+import os
+import data
+import models as m
+import torch
+import torch.optim as optim
+import yaml
+from logger import Logger, ModelLogger
+from trainers import Trainer
+import sys
+import numpy as np
+
+
+def build_model(model_type, args, device):
+    model = getattr(m, model_type)(**args)
+    return model.to(device)
+
+
+def build_optimizer(opt_type, model, args):
+    args['params'] = model.parameters()
+    opt_class = getattr(optim, opt_type)
+    return opt_class(**args)
+
+
+def build_loggers(log_folder):
+    if not os.path.exists(log_folder):
+        os.makedirs(log_folder)
+    model_logger = ModelLogger(log_folder, torch.save)
+    scalar_logger = Logger(log_folder)
+    return scalar_logger, model_logger
+
+
+def worker_init_fn(worker_id):
+    np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', '-c', help='Yaml config with training parameters', required=True)
+   
parser.add_argument('--log_folder', '-l', help='Log folder', required=True) + parser.add_argument('--data_root', '-r', help='Data root folder', required=True) + parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True) + args = parser.parse_args() + + args_log_folder = args.data_root + "/" + args.log_folder + + with open(args.config, 'r') as f: + job_description = yaml.load(f, Loader=yaml.FullLoader) + + config = job_description['job'] + scalar_logger, model_logger = build_loggers(args_log_folder) + + model_logger.copy_file(args.config) + device = config.get('device') or 'cpu' + + # Check 'training_dataset' parameters + training_dataset_parameters = set(config['training_dataset'].keys()) - \ + {"type", "dir_pre", "dir_post", "dir_mask", "patch_size", "dir_x1", "dir_x2", "dir_x3", "dir_x4", "dir_x5", "dir_x6", "dir_x7", "dir_x8", "dir_x9", } + if len(training_dataset_parameters) > 0: + raise RuntimeError("Got unexpected parameter in training_dataset: " + str(training_dataset_parameters)) + + d = dict(config['training_dataset']) + d['dir_pre'] = args.data_root + "/" + d['dir_pre'] + d['dir_post'] = args.data_root + "/" + d['dir_post'] + d['device'] = config['device'] + if 'dir_mask' in d: + d['dir_mask'] = args.data_root + "/" + d['dir_mask'] + + # complete dir_x paths and set a correct number of channels + channels = 3 + for dir_x_index in range(1, 10): + dir_x_name = f"dir_x{dir_x_index}" + d[dir_x_name] = args.data_root + "/" + d[dir_x_name] if dir_x_name in d else None + channels = channels + 3 if d[dir_x_name] is not None else channels + config['generator']['args']['input_channels'] = channels + + print(d) + + generator = build_model(config['generator']['type'], config['generator']['args'], device) + #generator = (torch.load(args.data_root + "/model_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_generator = build_optimizer(config['opt_generator']['type'], generator, config['opt_generator']['args']) + + discriminator, opt_discriminator = None, None + if 'discriminator' in config: + discriminator = build_model(config['discriminator']['type'], config['discriminator']['args'], device) + #discriminator = (torch.load(args.data_root + "/disc_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_discriminator = build_optimizer(config['opt_discriminator']['type'], discriminator, config['opt_discriminator']['args']) + + if 'type' not in d: + raise RuntimeError("Type of training_dataset must be specified!") + + dataset_type = getattr(data, d.pop('type')) + training_dataset = dataset_type(**d) + + train_loader = torch.utils.data.DataLoader(training_dataset, config['trainer']['batch_size'], shuffle=False, + num_workers=config['num_workers'], drop_last=True)#, worker_init_fn=worker_init_fn) + + reconstruction_criterion = getattr(torch.nn, config['trainer']['reconstruction_criterion'])() + adversarial_criterion = getattr(torch.nn, config['trainer']['adversarial_criterion'])() + + perception_loss_model = None + perception_loss_weight = 1 + if 'perception_loss' in config: + if 'perception_model' in config['perception_loss']: + perception_loss_model = build_model(config['perception_loss']['perception_model']['type'], + config['perception_loss']['perception_model']['args'], + device) + else: + perception_loss_model = discriminator + + perception_loss_weight = config['perception_loss']['weight'] + + trainer = Trainer( + train_loader=train_loader, + data_for_dataloader=d, # data for later dataloader creation, if 
needed + opt_generator=opt_generator, opt_discriminator=opt_discriminator, + adversarial_criterion=adversarial_criterion, reconstruction_criterion=reconstruction_criterion, + reconstruction_weight=config['trainer']['reconstruction_weight'], + adversarial_weight=config['trainer']['adversarial_weight'], + log_interval=args.log_interval, + model_logger=model_logger, scalar_logger=scalar_logger, + perception_loss_model=perception_loss_model, + perception_loss_weight=perception_loss_weight, + use_image_loss=config['trainer']['use_image_loss'], + device=device + ) + + args_config = args.config.replace('\\', '/') + args_config = args_config[args_config.rfind('/') + 1:] + trainer.train(generator, discriminator, int(config['trainer']['epochs']), args.data_root, args_config, 0) + print("Training finished", flush=True) + sys.exit(0) diff --git a/train2.py b/train2.py new file mode 100644 index 0000000..dc62b7e --- /dev/null +++ b/train2.py @@ -0,0 +1,132 @@ +import argparse +import os +import data +import models as m +import torch +import torch.optim as optim +import yaml +from logger import Logger, ModelLogger +from trainers import Trainer +import sys +import numpy as np + + +def build_model(model_type, args, device): + model = getattr(m, model_type)(**args) + return model.to(device) + + +def build_optimizer(opt_type, model, args): + args['params'] = model.parameters() + opt_class = getattr(optim, opt_type) + return opt_class(**args) + + +def build_loggers(log_folder): + if not os.path.exists(log_folder): + os.makedirs(log_folder) + model_logger = ModelLogger(log_folder, torch.save) + scalar_logger = Logger(log_folder) + return scalar_logger, model_logger + + +def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0] + worker_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', help='Yaml config with training parameters', required=True) + parser.add_argument('--log_folder', '-l', help='Log folder', required=True) + parser.add_argument('--data_root', '-r', help='Data root folder', required=True) + parser.add_argument('--log_interval', '-i', type=int, help='Log interval', required=True) + args = parser.parse_args() + + args_log_folder = args.data_root + "/" + args.log_folder + + with open(args.config, 'r') as f: + job_description = yaml.load(f, Loader=yaml.FullLoader) + + config = job_description['job'] + scalar_logger, model_logger = build_loggers(args_log_folder) + + model_logger.copy_file(args.config) + device = config.get('device') or 'cpu' + + # Check 'training_dataset' parameters + training_dataset_parameters = set(config['training_dataset'].keys()) - \ + {"type", "dir_pre", "dir_post", "dir_mask", "patch_size", "dir_x1", "dir_x2", "dir_x3", "dir_x4", "dir_x5", "dir_x6", "dir_x7", "dir_x8", "dir_x9", } + if len(training_dataset_parameters) > 0: + raise RuntimeError("Got unexpected parameter in training_dataset: " + str(training_dataset_parameters)) + + d = dict(config['training_dataset']) + d['dir_pre'] = args.data_root + "/" + d['dir_pre'] + d['dir_post'] = args.data_root + "/" + d['dir_post'] + d['device'] = config['device'] + if 'dir_mask' in d: + d['dir_mask'] = args.data_root + "/" + d['dir_mask'] + + # complete dir_x paths and set a correct number of channels + channels = 3 + for dir_x_index in range(1, 10): + dir_x_name = f"dir_x{dir_x_index}" + d[dir_x_name] = args.data_root + "/" + d[dir_x_name] if dir_x_name in d else None + channels = channels + 3 if d[dir_x_name] is not None else channels + 
config['generator']['args']['input_channels'] = channels + + print(d) + + generator = build_model(config['generator']['type'], config['generator']['args'], device) + #generator = (torch.load(args.data_root + "/model_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_generator = build_optimizer(config['opt_generator']['type'], generator, config['opt_generator']['args']) + + discriminator, opt_discriminator = None, None + if 'discriminator' in config: + discriminator = build_model(config['discriminator']['type'], config['discriminator']['args'], device) + #discriminator = (torch.load(args.data_root + "/disc_00300_style2.pth", map_location=lambda storage, loc: storage)).to(device) + opt_discriminator = build_optimizer(config['opt_discriminator']['type'], discriminator, config['opt_discriminator']['args']) + + if 'type' not in d: + raise RuntimeError("Type of training_dataset must be specified!") + + dataset_type = getattr(data, d.pop('type')) + training_dataset = dataset_type(**d) + + train_loader = torch.utils.data.DataLoader(training_dataset, config['trainer']['batch_size'], shuffle=False, + num_workers=config['num_workers'], drop_last=True)#, worker_init_fn=worker_init_fn) + + reconstruction_criterion = getattr(torch.nn, config['trainer']['reconstruction_criterion'])() + adversarial_criterion = getattr(torch.nn, config['trainer']['adversarial_criterion'])() + + perception_loss_model = None + perception_loss_weight = 1 + if 'perception_loss' in config: + if 'perception_model' in config['perception_loss']: + perception_loss_model = build_model(config['perception_loss']['perception_model']['type'], + config['perception_loss']['perception_model']['args'], + device) + else: + perception_loss_model = discriminator + + perception_loss_weight = config['perception_loss']['weight'] + + trainer = Trainer( + train_loader=train_loader, + data_for_dataloader=d, # data for later dataloader creation, if needed + opt_generator=opt_generator, opt_discriminator=opt_discriminator, + adversarial_criterion=adversarial_criterion, reconstruction_criterion=reconstruction_criterion, + reconstruction_weight=config['trainer']['reconstruction_weight'], + adversarial_weight=config['trainer']['adversarial_weight'], + log_interval=args.log_interval, + model_logger=model_logger, scalar_logger=scalar_logger, + perception_loss_model=perception_loss_model, + perception_loss_weight=perception_loss_weight, + use_image_loss=config['trainer']['use_image_loss'], + device=device + ) + + args_config = args.config.replace('\\', '/') + args_config = args_config[args_config.rfind('/') + 1:] + trainer.train(generator, discriminator, int(config['trainer']['epochs']), args.data_root, args_config, 0) + print("Training finished", flush=True) + sys.exit(0) diff --git a/trainers.py b/trainers.py index 9be770b..319231d 100644 --- a/trainers.py +++ b/trainers.py @@ -1,260 +1,267 @@ -import time -import models -import numpy as np -import six -import torch -import torch.nn as nn -from torch.autograd import Variable -from PIL import Image -from custom_transforms import * -from data import DatasetFullImages -import os - - -class Trainer(object): - def __init__(self, - train_loader, data_for_dataloader, opt_discriminator, opt_generator, - reconstruction_criterion, adversarial_criterion, reconstruction_weight, - adversarial_weight, log_interval, scalar_logger, model_logger, - perception_loss_model, perception_loss_weight, use_image_loss, device - ): - - self.train_loader = train_loader - self.data_for_dataloader = 
data_for_dataloader - - self.opt_discriminator = opt_discriminator - self.opt_generator = opt_generator - - self.reconstruction_criterion = reconstruction_criterion - self.adversarial_criterion = adversarial_criterion - - self.reconstruction_weight = reconstruction_weight - self.adversarial_weight = adversarial_weight - - self.scalar_logger = scalar_logger - self.model_logger = model_logger - - self.training_log = {} - self.log_interval = log_interval - - self.perception_loss_weight = perception_loss_weight - self.perception_loss_model = perception_loss_model - - self.use_adversarial_loss = False - self.use_image_loss = use_image_loss - self.device = device - - self.dataset = None - self.imloader = None - - - def run_discriminator(self, discriminator, images): - return discriminator(images) - - def compute_discriminator_loss(self, generator, discriminator, batch): - generated = generator(batch['pre']) - fake = self.apply_mask(generated, batch, 'pre_mask') - fake_labels, _ = self.run_discriminator(discriminator, fake.detach()) - - true = self.apply_mask(batch['already'], batch, 'already_mask') - true_labels, _ = self.run_discriminator(discriminator, true) - - discriminator_loss = self.adversarial_criterion(fake_labels, self.zeros_like(fake_labels)) + \ - self.adversarial_criterion(true_labels, self.ones_like(true_labels)) - - return discriminator_loss - - def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask): - image_loss = 0 - perception_loss = 0 - adversarial_loss = 0 - - generated = generator(batch['pre']) - - if use_mask: - generated = generated * batch['mask'] - batch['post'] = batch['post'] * batch['mask'] - - if self.use_image_loss: - if generated[0][0].shape != batch['post'][0][0].shape: - if ((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) % 2) != 0: - raise RuntimeError("batch['post'][0][0].shape[0] - generated[0][0].shape[0] must be even number") - if generated[0][0].shape[0] != generated[0][0].shape[1] or batch['post'][0][0].shape[0] != batch['post'][0][0].shape[1]: - raise RuntimeError("And also it is expected to be exact square ... 
fix it if you want") - boundary_size = int((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) / 2) - cropped_batch_post = batch['post'][:, :, boundary_size: -1*boundary_size, boundary_size: -1*boundary_size] - image_loss = self.reconstruction_criterion(generated, cropped_batch_post) - else: - image_loss = self.reconstruction_criterion(generated, batch['post']) - - if self.perception_loss_model is not None: - _, fake_features = self.perception_loss_model(generated) - _, target_features = self.perception_loss_model(Variable(batch['post'], requires_grad=False)) - perception_loss = ((fake_features - target_features) ** 2).mean() - - - if self.use_adversarial_loss and use_gan: - fake = self.apply_mask(generated, batch, 'pre_mask') - fake_smiling_labels, _ = self.run_discriminator(discriminator, fake) - adversarial_loss = self.adversarial_criterion(fake_smiling_labels, self.ones_like(fake_smiling_labels)) - - return image_loss, perception_loss, adversarial_loss, generated - - - def train(self, generator, discriminator, epochs, data_root, config_yaml_name, starting_batch_num): - self.use_adversarial_loss = discriminator is not None - batch_num = starting_batch_num - save_num = 0 - - start = time.time() - for epoch in range(epochs): - np.random.seed() - for i, batch in enumerate(self.train_loader): - # just sets the models into training mode (enable BN and DO) - [m.train() for m in [generator, discriminator] if m is not None] - batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] - for k in batch.keys()} - - # train discriminator - if self.use_adversarial_loss: - self.opt_discriminator.zero_grad() - discriminator_loss = self.compute_discriminator_loss(generator, discriminator, batch) - discriminator_loss.backward() - self.opt_discriminator.step() - - # train generator - self.opt_generator.zero_grad() - - g_image_loss, g_perc_loss, g_adv_loss, _ = self.compute_generator_loss(generator, discriminator, batch, use_gan=True, use_mask=False) - - generator_loss = self.reconstruction_weight * g_image_loss + \ - self.perception_loss_weight * g_perc_loss + \ - self.adversarial_weight * g_adv_loss - - generator_loss.backward() - - self.opt_generator.step() - - # log losses - current_log = {key: value.item() for key, value in six.iteritems(locals()) if - 'loss' in key and isinstance(value, Variable)} - - self.add_log(current_log) - - batch_num += 1 - - if batch_num % 100 == 0: - print(f"Batch num: {batch_num}, totally elapsed {(time.time() - start)}", flush=True) - - #if batch_num % self.log_interval == 0 or batch_num == 1: - if batch_num % self.log_interval == 0 or batch_num == 1: # (time.time() - start) > 16: - eval_start = time.time() - generator.eval() - self.test_on_full_image(generator, batch_num, data_root, config_yaml_name) - self.flush_scalar_log(batch_num, time.time() - start) - self.model_logger.save(generator, save_num, True) - #self.model_logger.save(discriminator, save_num, False) - save_num += 1 - print(f"Eval of batch: {batch_num} took {(time.time() - eval_start)}", flush=True) - - #if batch_num > 5000: - # sys.exit(0) - - self.model_logger.save(generator, 99999) - - # Accumulates the losses - def add_log(self, log): - for k, v in log.items(): - if k in self.training_log: - self.training_log[k] += v - else: - self.training_log[k] = v - - # Divide the losses by log_interval and print'em - def flush_scalar_log(self, batch_num, took): - for key in self.training_log.keys(): - self.scalar_logger.scalar_summary(key, self.training_log[key] / 
self.log_interval, batch_num) - - log = "[%d]" % batch_num - for key in sorted(self.training_log.keys()): - log += " [%s] % 7.4f" % (key, self.training_log[key] / self.log_interval) - - log += ". Took {}".format(took) - print(log, flush=True) - self.training_log = {} - - # Test the intermediate model on data from _gen folder - def test_on_full_image(self, generator, batch_num, data_root, config_yaml_name): - config_yaml_name = config_yaml_name.replace("reference", "").replace(".yaml", "") - - data_root = data_root.replace("_train", "_gen") - if self.dataset is None: - self.dataset = DatasetFullImages(data_root + "/" + self.data_for_dataloader['dir_pre'].split("/")[-1], - "ignore", # data_root + "/" + "ebsynth", - "ignore", # data_root + "/" + "mask", - self.device, - dir_x1=data_root + "/" + self.data_for_dataloader['dir_x1'].split("/")[-1] if self.data_for_dataloader['dir_x1'] is not None else None, - dir_x2=data_root + "/" + self.data_for_dataloader['dir_x2'].split("/")[-1] if self.data_for_dataloader['dir_x2'] is not None else None, - dir_x3=data_root + "/" + self.data_for_dataloader['dir_x3'].split("/")[-1] if self.data_for_dataloader['dir_x3'] is not None else None, - dir_x4=data_root + "/" + self.data_for_dataloader['dir_x4'].split("/")[-1] if self.data_for_dataloader['dir_x4'] is not None else None, - dir_x5=data_root + "/" + self.data_for_dataloader['dir_x5'].split("/")[-1] if self.data_for_dataloader['dir_x5'] is not None else None, - dir_x6=data_root + "/" + self.data_for_dataloader['dir_x6'].split("/")[-1] if self.data_for_dataloader['dir_x6'] is not None else None, - dir_x7=data_root + "/" + self.data_for_dataloader['dir_x7'].split("/")[-1] if self.data_for_dataloader['dir_x7'] is not None else None, - dir_x8=data_root + "/" + self.data_for_dataloader['dir_x8'].split("/")[-1] if self.data_for_dataloader['dir_x8'] is not None else None, - dir_x9=data_root + "/" + self.data_for_dataloader['dir_x9'].split("/")[-1] if self.data_for_dataloader['dir_x9'] is not None else None) - self.imloader = torch.utils.data.DataLoader(self.dataset, 1, shuffle=False, num_workers=1, drop_last=False) # num_workers=4 - - with torch.no_grad(): - log = "### \n" - log = log + "[%d]" % batch_num + " " - generator_loss_on_ebsynth = 0 - for i, batch in enumerate(self.imloader): - batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] - for k in batch.keys()} - g_image_loss, g_perc_loss, g_adv_loss, e_cls_loss, e_smiling_loss, gan_output =\ - 0, 0, 0, 0, 0, generator(batch['pre']) - - generator_loss = self.reconstruction_weight * g_image_loss + \ - self.perception_loss_weight * g_perc_loss + \ - self.adversarial_weight * g_adv_loss - - if True or batch['file_name'][0] != "111.png": # do not accumulate loss in train frame - generator_loss_on_ebsynth = generator_loss_on_ebsynth + generator_loss - - if True or batch['file_name'][0] in ["111.png", "101.png", "106.png", "116.png", "121.png"]: - #log = log + batch['file_name'][0] - #log = log + ": %7.4f" % generator_loss + ", " - - image_space = to_image_space(gan_output.cpu().data.numpy()) - - gt_test_ganoutput_path = data_root + "/" + "res_" + config_yaml_name - if not os.path.exists(gt_test_ganoutput_path): - os.mkdir(gt_test_ganoutput_path) - gt_test_ganoutput_path_batch_num = gt_test_ganoutput_path + "/" + str("%07d" % batch_num) - if not os.path.exists(gt_test_ganoutput_path_batch_num): - os.mkdir(gt_test_ganoutput_path_batch_num) - for k in range(0, len(image_space)): - im = image_space[k].transpose(1, 2, 0) - 
Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path_batch_num, batch['file_name'][k])) - if i == 0: - Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path, str("%07d" % batch_num) + ".png")) - - log = log + " totalLossOnEbsynth: %7.4f" % (generator_loss_on_ebsynth/(len(self.imloader))) - print(log, flush=True) - - - def apply_mask(self, x, batch, mask_key): - if mask_key in batch: - mask = Variable(batch[mask_key].expand(x.size()), requires_grad=False) - return x * (mask / 2 + 0.5) - return x - - def ones_like(self, x): - return torch.ones_like(x).to(self.device) - - def zeros_like(self, x): - return torch.zeros_like(x).to(self.device) - - @staticmethod - def to_image_space(x): - return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8) +import time +import models +import numpy as np +import six +import torch +import torch.nn as nn +from torch.autograd import Variable +from PIL import Image +from custom_transforms import * +from data import DatasetFullImages +import os +import gc + +import tensorflow as tf +config = tf.compat.v1.ConfigProto() +config.gpu_options.allow_growth = True +sess = tf.compat.v1.Session(config=config) + +torch.backends.cudnn.benchmark = False + +class Trainer(object): + def __init__(self, + train_loader, data_for_dataloader, opt_discriminator, opt_generator, + reconstruction_criterion, adversarial_criterion, reconstruction_weight, + adversarial_weight, log_interval, scalar_logger, model_logger, + perception_loss_model, perception_loss_weight, use_image_loss, device + ): + + self.train_loader = train_loader + self.data_for_dataloader = data_for_dataloader + + self.opt_discriminator = opt_discriminator + self.opt_generator = opt_generator + + self.reconstruction_criterion = reconstruction_criterion + self.adversarial_criterion = adversarial_criterion + + self.reconstruction_weight = reconstruction_weight + self.adversarial_weight = adversarial_weight + + self.scalar_logger = scalar_logger + self.model_logger = model_logger + + self.training_log = {} + self.log_interval = log_interval + + self.perception_loss_weight = perception_loss_weight + self.perception_loss_model = perception_loss_model + + self.use_adversarial_loss = False + self.use_image_loss = use_image_loss + self.device = device + + self.dataset = None + self.imloader = None + + + def run_discriminator(self, discriminator, images): + return discriminator(images) + + def compute_discriminator_loss(self, generator, discriminator, batch): + generated = generator(batch['pre']) + fake = self.apply_mask(generated, batch, 'pre_mask') + fake_labels, _ = self.run_discriminator(discriminator, fake.detach()) + + true = self.apply_mask(batch['already'], batch, 'already_mask') + true_labels, _ = self.run_discriminator(discriminator, true) + + discriminator_loss = self.adversarial_criterion(fake_labels, self.zeros_like(fake_labels)) + \ + self.adversarial_criterion(true_labels, self.ones_like(true_labels)) + + return discriminator_loss + + def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask): + image_loss = 0 + perception_loss = 0 + adversarial_loss = 0 + + generated = generator(batch['pre']) + + if use_mask: + generated = generated * batch['mask'] + batch['post'] = batch['post'] * batch['mask'] + + if self.use_image_loss: + if generated[0][0].shape != batch['post'][0][0].shape: + if ((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) % 2) != 0: + raise RuntimeError("batch['post'][0][0].shape[0] - generated[0][0].shape[0] must be even number") + if 
generated[0][0].shape[0] != generated[0][0].shape[1] or batch['post'][0][0].shape[0] != batch['post'][0][0].shape[1]: + raise RuntimeError("And also it is expected to be exact square ... fix it if you want") + boundary_size = int((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) / 2) + cropped_batch_post = batch['post'][:, :, boundary_size: -1*boundary_size, boundary_size: -1*boundary_size] + image_loss = self.reconstruction_criterion(generated, cropped_batch_post) + else: + image_loss = self.reconstruction_criterion(generated, batch['post']) + + if self.perception_loss_model is not None: + _, fake_features = self.perception_loss_model(generated) + _, target_features = self.perception_loss_model(Variable(batch['post'], requires_grad=False)) + perception_loss = ((fake_features - target_features) ** 2).mean() + + + if self.use_adversarial_loss and use_gan: + fake = self.apply_mask(generated, batch, 'pre_mask') + fake_smiling_labels, _ = self.run_discriminator(discriminator, fake) + adversarial_loss = self.adversarial_criterion(fake_smiling_labels, self.ones_like(fake_smiling_labels)) + + return image_loss, perception_loss, adversarial_loss, generated + + + def train(self, generator, discriminator, epochs, data_root, config_yaml_name, starting_batch_num, step): + self.use_adversarial_loss = discriminator is not None + batch_num = starting_batch_num + save_num = 0 + + start = time.time() + for epoch in range(epochs): + np.random.seed() + for i, batch in enumerate(self.train_loader): + # just sets the models into training mode (enable BN and DO) + [m.train() for m in [generator, discriminator] if m is not None] + batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] + for k in batch.keys()} + + # train discriminator + if self.use_adversarial_loss: + self.opt_discriminator.zero_grad() + discriminator_loss = self.compute_discriminator_loss(generator, discriminator, batch) + discriminator_loss.backward() + self.opt_discriminator.step() + + # train generator + self.opt_generator.zero_grad() + + g_image_loss, g_perc_loss, g_adv_loss, _ = self.compute_generator_loss(generator, discriminator, batch, use_gan=True, use_mask=False) + + generator_loss = self.reconstruction_weight * g_image_loss + \ + self.perception_loss_weight * g_perc_loss + \ + self.adversarial_weight * g_adv_loss + + generator_loss.backward() + + self.opt_generator.step() + + # log losses + current_log = {key: value.item() for key, value in six.iteritems(locals()) if + 'loss' in key and isinstance(value, Variable)} + + self.add_log(current_log) + + batch_num += 1 + + if batch_num % 100 == 0: + print(f"Batch num: {batch_num}, totally elapsed {(time.time() - start)}", flush=True) + + #if batch_num % self.log_interval == 0 or batch_num == 1: + if batch_num % self.log_interval == 0 or batch_num == 1: # (time.time() - start) > 16: + eval_start = time.time() + generator.eval() + self.test_on_full_image(generator, batch_num, data_root, config_yaml_name) + self.flush_scalar_log(batch_num, time.time() - start, step=step) + self.model_logger.save(generator, save_num, True) + #self.model_logger.save(discriminator, save_num, False) + save_num += 1 + print(f"Eval of batch: {batch_num} took {(time.time() - eval_start)}", flush=True) + + #if batch_num > 5000: + # sys.exit(0) + + self.model_logger.save(generator, 99999) + + # Accumulates the losses + def add_log(self, log): + for k, v in log.items(): + if k in self.training_log: + self.training_log[k] += v + else: + self.training_log[k] = v + + # Divide the 
losses by log_interval and print'em + def flush_scalar_log(self, batch_num, took, step): + for key in self.training_log.keys(): + self.scalar_logger.scalar_summary(key, self.training_log[key] / self.log_interval, batch_num, step=step) + + log = "[%d]" % batch_num + for key in sorted(self.training_log.keys()): + log += " [%s] % 7.4f" % (key, self.training_log[key] / self.log_interval) + + log += ". Took {}".format(took) + print(log, flush=True) + self.training_log = {} + + # Test the intermediate model on data from _gen folder + def test_on_full_image(self, generator, batch_num, data_root, config_yaml_name): + config_yaml_name = config_yaml_name.replace("reference", "").replace(".yaml", "") + + data_root = data_root.replace("_train", "_gen") + if self.dataset is None: + self.dataset = DatasetFullImages(data_root + "/" + self.data_for_dataloader['dir_pre'].split("/")[-1], + "ignore", # data_root + "/" + "ebsynth", + "ignore", # data_root + "/" + "mask", + self.device, + dir_x1=data_root + "/" + self.data_for_dataloader['dir_x1'].split("/")[-1] if self.data_for_dataloader['dir_x1'] is not None else None, + dir_x2=data_root + "/" + self.data_for_dataloader['dir_x2'].split("/")[-1] if self.data_for_dataloader['dir_x2'] is not None else None, + dir_x3=data_root + "/" + self.data_for_dataloader['dir_x3'].split("/")[-1] if self.data_for_dataloader['dir_x3'] is not None else None, + dir_x4=data_root + "/" + self.data_for_dataloader['dir_x4'].split("/")[-1] if self.data_for_dataloader['dir_x4'] is not None else None, + dir_x5=data_root + "/" + self.data_for_dataloader['dir_x5'].split("/")[-1] if self.data_for_dataloader['dir_x5'] is not None else None, + dir_x6=data_root + "/" + self.data_for_dataloader['dir_x6'].split("/")[-1] if self.data_for_dataloader['dir_x6'] is not None else None, + dir_x7=data_root + "/" + self.data_for_dataloader['dir_x7'].split("/")[-1] if self.data_for_dataloader['dir_x7'] is not None else None, + dir_x8=data_root + "/" + self.data_for_dataloader['dir_x8'].split("/")[-1] if self.data_for_dataloader['dir_x8'] is not None else None, + dir_x9=data_root + "/" + self.data_for_dataloader['dir_x9'].split("/")[-1] if self.data_for_dataloader['dir_x9'] is not None else None) + self.imloader = torch.utils.data.DataLoader(self.dataset, 1, shuffle=False, num_workers=1, drop_last=False) # num_workers=4 + + with torch.no_grad(): + log = "### \n" + log = log + "[%d]" % batch_num + " " + generator_loss_on_ebsynth = 0 + for i, batch in enumerate(self.imloader): + batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k] + for k in batch.keys()} + g_image_loss, g_perc_loss, g_adv_loss, e_cls_loss, e_smiling_loss, gan_output =\ + 0, 0, 0, 0, 0, generator(batch['pre']) + + generator_loss = self.reconstruction_weight * g_image_loss + \ + self.perception_loss_weight * g_perc_loss + \ + self.adversarial_weight * g_adv_loss + + if True or batch['file_name'][0] != "111.png": # do not accumulate loss in train frame + generator_loss_on_ebsynth = generator_loss_on_ebsynth + generator_loss + + if True or batch['file_name'][0] in ["111.png", "101.png", "106.png", "116.png", "121.png"]: + #log = log + batch['file_name'][0] + #log = log + ": %7.4f" % generator_loss + ", " + + image_space = to_image_space(gan_output.cpu().data.numpy()) + + gt_test_ganoutput_path = data_root + "/" + "res_" + config_yaml_name + if not os.path.exists(gt_test_ganoutput_path): + os.mkdir(gt_test_ganoutput_path) + gt_test_ganoutput_path_batch_num = gt_test_ganoutput_path + "/" + str("%07d" % 
batch_num) + if not os.path.exists(gt_test_ganoutput_path_batch_num): + os.mkdir(gt_test_ganoutput_path_batch_num) + for k in range(0, len(image_space)): + im = image_space[k].transpose(1, 2, 0) + Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path_batch_num, batch['file_name'][k])) + if i == 0: + Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path, str("%07d" % batch_num) + ".png")) + + log = log + " totalLossOnEbsynth: %7.4f" % (generator_loss_on_ebsynth/(len(self.imloader))) + print(log, flush=True) + + + def apply_mask(self, x, batch, mask_key): + if mask_key in batch: + mask = Variable(batch[mask_key].expand(x.size()), requires_grad=False) + return x * (mask / 2 + 0.5) + return x + + def ones_like(self, x): + return torch.ones_like(x).to(self.device) + + def zeros_like(self, x): + return torch.zeros_like(x).to(self.device) + + @staticmethod + def to_image_space(x): + return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8) diff --git a/trainers1.py b/trainers1.py new file mode 100644 index 0000000..b99fd1f --- /dev/null +++ b/trainers1.py @@ -0,0 +1,267 @@ +import time +import models +import numpy as np +import six +import torch +import torch.nn as nn +from torch.autograd import Variable +from PIL import Image +from custom_transforms import * +from data import DatasetFullImages +import os +import gc + +import tensorflow as tf +config = tf.compat.v1.ConfigProto() +config.gpu_options.allow_growth = True +sess = tf.Session(config=config) + +torch.backends.cudnn.benchmark = False + +class Trainer(object): + def __init__(self, + train_loader, data_for_dataloader, opt_discriminator, opt_generator, + reconstruction_criterion, adversarial_criterion, reconstruction_weight, + adversarial_weight, log_interval, scalar_logger, model_logger, + perception_loss_model, perception_loss_weight, use_image_loss, device + ): + + self.train_loader = train_loader + self.data_for_dataloader = data_for_dataloader + + self.opt_discriminator = opt_discriminator + self.opt_generator = opt_generator + + self.reconstruction_criterion = reconstruction_criterion + self.adversarial_criterion = adversarial_criterion + + self.reconstruction_weight = reconstruction_weight + self.adversarial_weight = adversarial_weight + + self.scalar_logger = scalar_logger + self.model_logger = model_logger + + self.training_log = {} + self.log_interval = log_interval + + self.perception_loss_weight = perception_loss_weight + self.perception_loss_model = perception_loss_model + + self.use_adversarial_loss = False + self.use_image_loss = use_image_loss + self.device = device + + self.dataset = None + self.imloader = None + + + def run_discriminator(self, discriminator, images): + return discriminator(images) + + def compute_discriminator_loss(self, generator, discriminator, batch): + generated = generator(batch['pre']) + fake = self.apply_mask(generated, batch, 'pre_mask') + fake_labels, _ = self.run_discriminator(discriminator, fake.detach()) + + true = self.apply_mask(batch['already'], batch, 'already_mask') + true_labels, _ = self.run_discriminator(discriminator, true) + + discriminator_loss = self.adversarial_criterion(fake_labels, self.zeros_like(fake_labels)) + \ + self.adversarial_criterion(true_labels, self.ones_like(true_labels)) + + return discriminator_loss + + def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask): + image_loss = 0 + perception_loss = 0 + adversarial_loss = 0 + + generated = generator(batch['pre']) + + if use_mask: + generated = generated * batch['mask'] + 
+    def compute_generator_loss(self, generator, discriminator, batch, use_gan, use_mask):
+        image_loss = 0
+        perception_loss = 0
+        adversarial_loss = 0
+
+        generated = generator(batch['pre'])
+
+        if use_mask:
+            generated = generated * batch['mask']
+            batch['post'] = batch['post'] * batch['mask']
+
+        if self.use_image_loss:
+            if generated[0][0].shape != batch['post'][0][0].shape:
+                # The generator output can be smaller than the target (e.g. with valid
+                # convolutions); center-crop the target to match before the loss.
+                if ((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) % 2) != 0:
+                    raise RuntimeError("batch['post'][0][0].shape[0] - generated[0][0].shape[0] must be an even number")
+                if generated[0][0].shape[0] != generated[0][0].shape[1] or batch['post'][0][0].shape[0] != batch['post'][0][0].shape[1]:
+                    raise RuntimeError("generated and batch['post'] are expected to be exact squares")
+                boundary_size = int((batch['post'][0][0].shape[0] - generated[0][0].shape[0]) / 2)
+                cropped_batch_post = batch['post'][:, :, boundary_size: -1 * boundary_size, boundary_size: -1 * boundary_size]
+                image_loss = self.reconstruction_criterion(generated, cropped_batch_post)
+            else:
+                image_loss = self.reconstruction_criterion(generated, batch['post'])
+
+        if self.perception_loss_model is not None:
+            # Perceptual loss: MSE between feature maps of generated and target images.
+            _, fake_features = self.perception_loss_model(generated)
+            _, target_features = self.perception_loss_model(Variable(batch['post'], requires_grad=False))
+            perception_loss = ((fake_features - target_features) ** 2).mean()
+
+        if self.use_adversarial_loss and use_gan:
+            # Generator adversarial term: push the discriminator's output towards "real" (1).
+            fake = self.apply_mask(generated, batch, 'pre_mask')
+            fake_smiling_labels, _ = self.run_discriminator(discriminator, fake)
+            adversarial_loss = self.adversarial_criterion(fake_smiling_labels, self.ones_like(fake_smiling_labels))
+
+        return image_loss, perception_loss, adversarial_loss, generated
+
+    def train(self, generator, discriminator, epochs, data_root, config_yaml_name, starting_batch_num):
+        self.use_adversarial_loss = discriminator is not None
+        batch_num = starting_batch_num
+        save_num = 0
+
+        start = time.time()
+        for epoch in range(epochs):
+            np.random.seed()  # reseed from OS entropy so augmentations do not repeat per epoch
+            for i, batch in enumerate(self.train_loader):
+                # put the models into training mode (enables BN updates and dropout)
+                for m in (generator, discriminator):
+                    if m is not None:
+                        m.train()
+                batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k]
+                         for k in batch.keys()}
+
+                # train discriminator
+                if self.use_adversarial_loss:
+                    self.opt_discriminator.zero_grad()
+                    discriminator_loss = self.compute_discriminator_loss(generator, discriminator, batch)
+                    discriminator_loss.backward()
+                    self.opt_discriminator.step()
+
+                # train generator
+                self.opt_generator.zero_grad()
+
+                g_image_loss, g_perc_loss, g_adv_loss, _ = self.compute_generator_loss(generator, discriminator, batch, use_gan=True, use_mask=False)
+
+                generator_loss = self.reconstruction_weight * g_image_loss + \
+                                 self.perception_loss_weight * g_perc_loss + \
+                                 self.adversarial_weight * g_adv_loss
+
+                generator_loss.backward()
+
+                self.opt_generator.step()
+
+                # log losses: pick up every local variable whose name contains "loss"
+                # and which is a tensor (e.g. discriminator_loss, generator_loss)
+                current_log = {key: value.item() for key, value in six.iteritems(locals()) if
+                               'loss' in key and isinstance(value, Variable)}
+
+                self.add_log(current_log)
+
+                batch_num += 1
+
+                if batch_num % 100 == 0:
+                    print(f"Batch num: {batch_num}, totally elapsed {(time.time() - start)}", flush=True)
+
+                if batch_num % self.log_interval == 0 or batch_num == 1:  # (time.time() - start) > 16:
+                    eval_start = time.time()
+                    generator.eval()
+                    self.test_on_full_image(generator, batch_num, data_root, config_yaml_name)
+                    self.flush_scalar_log(batch_num, time.time() - start)
+                    self.model_logger.save(generator, save_num, True)
+                    #self.model_logger.save(discriminator, save_num, False)
+                    save_num += 1
+                    print(f"Eval of batch: {batch_num} took {(time.time() - eval_start)}", flush=True)
+
+                #if batch_num > 5000:
+                #    sys.exit(0)
+
+        self.model_logger.save(generator, 99999)
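+    # Hypothetical wiring (illustrative only; the real setup lives in train.py and
+    # the YAML config, and these names are placeholders):
+    #   trainer = Trainer(train_loader, data_for_dataloader, opt_d, opt_g,
+    #                     nn.L1Loss(), nn.BCEWithLogitsLoss(), reconstruction_weight=1.0,
+    #                     adversarial_weight=0.1, log_interval=1000, scalar_logger=sl,
+    #                     model_logger=ml, perception_loss_model=vgg, perception_loss_weight=0.05,
+    #                     use_image_loss=True, device="cuda")
+    #   trainer.train(generator, discriminator, epochs=10, data_root=root,
+    #                 config_yaml_name="reference.yaml", starting_batch_num=0)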
+    # Accumulates the losses
+    def add_log(self, log):
+        for k, v in log.items():
+            if k in self.training_log:
+                self.training_log[k] += v
+            else:
+                self.training_log[k] = v
+
+    # Divide the losses by log_interval and print them
+    def flush_scalar_log(self, batch_num, took):
+        for key in self.training_log.keys():
+            self.scalar_logger.scalar_summary(key, self.training_log[key] / self.log_interval, batch_num)
+
+        log = "[%d]" % batch_num
+        for key in sorted(self.training_log.keys()):
+            log += " [%s] % 7.4f" % (key, self.training_log[key] / self.log_interval)
+
+        log += ". Took {}".format(took)
+        print(log, flush=True)
+        self.training_log = {}
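+    # The printed values are running means over the last logging window (sums divided
+    # by log_interval), not instantaneous per-batch losses. The preview images written
+    # by test_on_full_image below land in <data_root>/res_<config>/<batch_num>/, with a
+    # copy of the first frame saved as <data_root>/res_<config>/<batch_num>.png.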
+    # Test the intermediate model on data from the _gen folder
+    def test_on_full_image(self, generator, batch_num, data_root, config_yaml_name):
+        config_yaml_name = config_yaml_name.replace("reference", "").replace(".yaml", "")
+
+        data_root = data_root.replace("_train", "_gen")
+        if self.dataset is None:
+            # Remap every configured auxiliary input dir (dir_x1..dir_x9) from the
+            # _train folder into the matching _gen folder, keeping None entries None.
+            dir_x_kwargs = {
+                key: (data_root + "/" + self.data_for_dataloader[key].split("/")[-1]
+                      if self.data_for_dataloader[key] is not None else None)
+                for key in ["dir_x%d" % n for n in range(1, 10)]
+            }
+            self.dataset = DatasetFullImages(data_root + "/" + self.data_for_dataloader['dir_pre'].split("/")[-1],
+                                             "ignore",  # data_root + "/" + "ebsynth"
+                                             "ignore",  # data_root + "/" + "mask"
+                                             self.device,
+                                             **dir_x_kwargs)
+            self.imloader = torch.utils.data.DataLoader(self.dataset, 1, shuffle=False, num_workers=1, drop_last=False)  # num_workers=4
+
+        with torch.no_grad():
+            log = "### \n"
+            log = log + "[%d]" % batch_num + " "
+            generator_loss_on_ebsynth = 0
+            for i, batch in enumerate(self.imloader):
+                batch = {k: batch[k].to(self.device) if isinstance(batch[k], torch.Tensor) else batch[k]
+                         for k in batch.keys()}
+                # The loss terms are zeroed here; only the generator forward pass runs.
+                g_image_loss, g_perc_loss, g_adv_loss, e_cls_loss, e_smiling_loss, gan_output =\
+                    0, 0, 0, 0, 0, generator(batch['pre'])
+
+                generator_loss = self.reconstruction_weight * g_image_loss + \
+                                 self.perception_loss_weight * g_perc_loss + \
+                                 self.adversarial_weight * g_adv_loss
+
+                # NOTE: the "True or" short-circuits disable both file-name filters,
+                # so the loss is accumulated and an image is saved for every frame.
+                if True or batch['file_name'][0] != "111.png":  # do not accumulate loss in the train frame
+                    generator_loss_on_ebsynth = generator_loss_on_ebsynth + generator_loss
+
+                if True or batch['file_name'][0] in ["111.png", "101.png", "106.png", "116.png", "121.png"]:
+                    #log = log + batch['file_name'][0]
+                    #log = log + ": %7.4f" % generator_loss + ", "
+
+                    # to_image_space presumably resolves via the star import from
+                    # custom_transforms; a static copy also exists at the end of this class.
+                    image_space = to_image_space(gan_output.cpu().data.numpy())
+
+                    gt_test_ganoutput_path = data_root + "/" + "res_" + config_yaml_name
+                    if not os.path.exists(gt_test_ganoutput_path):
+                        os.mkdir(gt_test_ganoutput_path)
+                    gt_test_ganoutput_path_batch_num = gt_test_ganoutput_path + "/" + str("%07d" % batch_num)
+                    if not os.path.exists(gt_test_ganoutput_path_batch_num):
+                        os.mkdir(gt_test_ganoutput_path_batch_num)
+                    for k in range(len(image_space)):
+                        im = image_space[k].transpose(1, 2, 0)
+                        Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path_batch_num, batch['file_name'][k]))
+                        if i == 0:
+                            Image.fromarray(im).save(os.path.join(gt_test_ganoutput_path, str("%07d" % batch_num) + ".png"))
+
+            log = log + " totalLossOnEbsynth: %7.4f" % (generator_loss_on_ebsynth / len(self.imloader))
+            print(log, flush=True)
+
+    def apply_mask(self, x, batch, mask_key):
+        if mask_key in batch:
+            # Masks are stored in [-1, 1]; remap to [0, 1] before multiplying.
+            mask = Variable(batch[mask_key].expand(x.size()), requires_grad=False)
+            return x * (mask / 2 + 0.5)
+        return x
+
+    def ones_like(self, x):
+        return torch.ones_like(x).to(self.device)
+
+    def zeros_like(self, x):
+        return torch.zeros_like(x).to(self.device)
+
+    @staticmethod
+    def to_image_space(x):
+        return ((np.clip(x, -1, 1) + 1) / 2 * 255).astype(np.uint8)
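+    # Worked example of the mapping above (a sanity check, not part of the pipeline):
+    #   to_image_space(np.array([[-1.0, 0.0, 1.0]]))
+    #   -> array([[  0, 127, 255]], dtype=uint8)
+    # i.e. the generator's [-1, 1] output is clipped, shifted to [0, 2], halved to
+    # [0, 1] and scaled to [0, 255].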