├── LICENSE.txt ├── README.md ├── al_chinese.txt ├── assets ├── LEMMA-pipeline.png ├── Supplementary_material_for_LEMMA.pdf ├── qualitative-comparison.png └── quantitative-comparison.png ├── config └── super_resolution.yaml ├── dataset ├── __init__.py ├── charset_36.txt ├── confuse.pkl └── dataset.py ├── interfaces ├── base.py └── super_resolution.py ├── loss └── text_focus_loss.py ├── main.py ├── model ├── ABINet │ ├── __pycache__ │ │ ├── abinet.cpython-38.pyc │ │ ├── attention.cpython-38.pyc │ │ ├── backbone.cpython-38.pyc │ │ ├── resnet.cpython-38.pyc │ │ └── transformer.cpython-38.pyc │ ├── abinet.py │ ├── attention.py │ ├── backbone.py │ ├── resnet.py │ └── transformer.py ├── MATRN │ ├── __pycache__ │ │ ├── matrn.cpython-38.pyc │ │ └── sematic_visual_backbone.cpython-38.pyc │ ├── matrn.py │ └── sematic_visual_backbone.py ├── MPNCOV │ ├── __init__.py │ └── python │ │ ├── MPNCOV.py │ │ └── __init__.py ├── Position_aware_module.py ├── __init__.py ├── attention.py ├── backbone.py ├── crnn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── crnn.cpython-36.pyc │ │ ├── crnn.cpython-37.pyc │ │ ├── crnn.cpython-38.pyc │ │ ├── model.cpython-37.pyc │ │ └── model.cpython-38.pyc │ ├── crnn.py │ ├── model.py │ └── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── feature_extraction.cpython-37.pyc │ │ ├── feature_extraction.cpython-38.pyc │ │ ├── prediction.cpython-37.pyc │ │ ├── prediction.cpython-38.pyc │ │ ├── sequence_modeling.cpython-37.pyc │ │ ├── sequence_modeling.cpython-38.pyc │ │ ├── transformation.cpython-37.pyc │ │ └── transformation.cpython-38.pyc │ │ ├── feature_extraction.py │ │ ├── prediction.py │ │ ├── sequence_modeling.py │ │ └── transformation.py ├── lemma.py ├── moran │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── asrn_res.cpython-36.pyc │ │ ├── asrn_res.cpython-37.pyc │ │ ├── asrn_res.cpython-38.pyc │ │ ├── fracPickup.cpython-36.pyc │ │ ├── fracPickup.cpython-37.pyc │ │ ├── fracPickup.cpython-38.pyc │ │ ├── moran.cpython-36.pyc │ │ ├── moran.cpython-37.pyc │ │ ├── moran.cpython-38.pyc │ │ ├── morn.cpython-36.pyc │ │ ├── morn.cpython-37.pyc │ │ └── morn.cpython-38.pyc │ ├── asrn_res.py │ ├── fracPickup.py │ ├── moran.py │ └── morn.py ├── parseq │ ├── __pycache__ │ │ ├── modules.cpython-38.pyc │ │ ├── parseq.cpython-38.pyc │ │ └── parseq_tokenizer.cpython-38.pyc │ ├── modules.py │ ├── parseq.py │ └── parseq_tokenizer.py ├── recognizer │ ├── SwimTransformer.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── attention_recognition_head.cpython-36.pyc │ │ ├── attention_recognition_head.cpython-37.pyc │ │ ├── attention_recognition_head.cpython-38.pyc │ │ ├── recognizer_builder.cpython-36.pyc │ │ ├── recognizer_builder.cpython-37.pyc │ │ ├── recognizer_builder.cpython-38.pyc │ │ ├── resnet_aster.cpython-36.pyc │ │ ├── resnet_aster.cpython-37.pyc │ │ ├── resnet_aster.cpython-38.pyc │ │ ├── sequenceCrossEntropyLoss.cpython-36.pyc │ │ ├── sequenceCrossEntropyLoss.cpython-37.pyc │ │ ├── sequenceCrossEntropyLoss.cpython-38.pyc │ │ ├── stn_head.cpython-36.pyc │ │ ├── stn_head.cpython-37.pyc │ │ ├── stn_head.cpython-38.pyc │ │ ├── tps_spatial_transformer.cpython-36.pyc │ │ ├── tps_spatial_transformer.cpython-37.pyc │ │ └── 
tps_spatial_transformer.cpython-38.pyc │ ├── attention_recognition_head.py │ ├── recognizer_builder.py │ ├── resnet_aster.py │ ├── sequenceCrossEntropyLoss.py │ ├── stn_head.py │ └── tps_spatial_transformer.py ├── resnet.py ├── stn_head.py ├── tps_spatial_transformer.py └── transformer.py ├── setup.py └── utils ├── __init__.py ├── labelmaps.py ├── metrics.py ├── ssim_psnr.py ├── util.py ├── utils_crnn.py ├── utils_deblur.py ├── utils_image.py ├── utils_moran.py └── utils_sisr.py /README.md: -------------------------------------------------------------------------------- 1 | # LEMMA 2 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/pdf/2307.09749.pdf) 3 | 4 | The official PyTorch implementation of the paper "Towards Robust Scene Text Image Super-resolution via Explicit Location Enhancement" (IJCAI 2023). 5 | 6 | Authors: *Hang Guo, Tao Dai, Guanghao Meng, and Shu-Tao Xia* 7 | 8 | This work proposes the Location Enhanced Multi-ModAl network (LEMMA), which addresses the challenges posed by complex backgrounds in scene text images through explicit positional enhancement. The architecture of LEMMA is shown below. 9 | 10 | ![LEMMA-pipeline](https://github.com/csguoh/LEMMA/blob/main/assets/LEMMA-pipeline.png) 11 | 12 | 13 | ## Pre-trained Model 14 | 15 | Since the original code was somewhat disorganized, we have re-organized it and retrained LEMMA. The performance of the re-trained model is as follows (better than that reported in the paper). 16 | 17 | | Text Recognizer | Easy | Medium | Hard | avgAcc | 18 | | :-------------: | :----: | :----: | :----: | :----: | 19 | | CRNN | 64.98% | 59.89% | 43.48% | 56.73% | 20 | | MORAN | 76.90% | 64.28% | 46.84% | 63.60% | 21 | | ASTER | 81.53% | 67.40% | 48.85% | 66.93% | 22 | 23 | You can download this model from this [link](https://drive.google.com/file/d/1iuVc0fh5rQAT2Ep5KgyV1GnsXPkdI4V0/view?usp=share_link); it contains the parameters of both the super-resolution branch and the guidance generation branch. 24 | 25 | The training log is also available at this [link](https://drive.google.com/file/d/1xtNnJ3gUXO1FSaebn2_rCKlcOfNBZR6o/view?usp=share_link). 26 | 27 | 28 | 29 | ## Prepare Datasets 30 | 31 | In this work, we use the STISR dataset TextZoom and four STR benchmarks, i.e., ICDAR2015, CUTE80, SVT and SVTP, for model comparison. All datasets are in `lmdb` format. You can download them from this [link](https://drive.google.com/drive/folders/1uqr8WIEM2xRs-K6I9KxtOdjcSoDWqJNJ?usp=share_link) we have prepared. Please remember to set your own dataset paths in `./config/super_resolution.yaml`, i.e., the `train_data_dir` and `val_data_dir` parameters. 32 | 33 | 34 | 35 | ## Text Recognizers 36 | 37 | Following previous STISR works, we use [CRNN](https://github.com/meijieru/crnn.pytorch), [MORAN](https://github.com/Canjie-Luo/MORAN_v2) and [ASTER](https://github.com/ayumiymk/aster.pytorch) as the downstream text recognizers. 38 | 39 | Moreover, the code also supports several newer text recognizers, such as [ABINet](https://github.com/FangShancheng/ABINet), [MATRN](https://github.com/byeonghu-na/MATRN) and [PARSeq](https://github.com/baudm/parseq). A detailed comparison using these three recognizers is given in the provided supplementary material, and you can also test LEMMA with them by modifying the command-line argument (e.g., `--test_model='ABINet'`). Please download the pre-trained text recognition models from the corresponding repositories linked above.
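For example, to evaluate LEMMA with one of these newer recognizers, combine the `--test` and `--test_model` arguments defined in `main.py` (ABINet shown here; MATRN and PARSeq are selected in the same way):

```
python main.py --test --test_model='ABINet'
```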
40 | 41 | You also need to set the text recognizer model paths in the `./config/super_resolution.yaml` file. Moreover, we employ the text-focus loss proposed by STT during training. Since this loss relies on a pre-trained Transformer-based text recognizer, please download the recognition model [here](https://drive.google.com/file/d/1HRpzveBbnJPQn3-k_y2Y1YY4PcraWOFP/view?usp=drive_link) and set the corresponding checkpoint path as well. 42 | 43 | 44 | 45 | ## How to Run? 46 | 47 | Default hyper-parameters are already set in `config/super_resolution.yaml` and `main.py`, so you can run training and testing directly once the dataset and pre-trained model paths have been modified. 48 | 49 | ### Training 50 | 51 | ``` 52 | python main.py 53 | ``` 54 | 55 | ### Testing 56 | 57 | ``` 58 | python main.py --test 59 | ``` 60 | 61 | **NOTE:** You can also customize other hyper-parameters in `config/super_resolution.yaml` and `main.py`, such as `ngpu` (see the configuration-loading sketch at the end of this README). 62 | 63 | 64 | 65 | ## Main Results 66 | 67 | ### Quantitative Comparison 68 | 69 | ![quantitative-comparison](https://github.com/csguoh/LEMMA/blob/main/assets/quantitative-comparison.png) 70 | 71 | 72 | 73 | ### Qualitative Comparison 74 | 75 | ![qualitative-comparison](https://github.com/csguoh/LEMMA/blob/main/assets/qualitative-comparison.png) 76 | 77 | 78 | 79 | ## Citation 80 | 81 | If you find our work helpful, please consider citing us. 82 | 83 | ``` 84 | @inproceedings{ijcai2023p87, 85 | title = {Towards Robust Scene Text Image Super-resolution via Explicit Location Enhancement}, 86 | author = {Guo, Hang and Dai, Tao and Meng, Guanghao and Xia, Shu-Tao}, 87 | booktitle = {Proceedings of the Thirty-Second International Joint Conference on 88 | Artificial Intelligence, {IJCAI-23}}, 89 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 90 | pages = {782--790}, 91 | year = {2023}, 92 | month = {8}, 93 | note = {Main Track}, 94 | doi = {10.24963/ijcai.2023/87}, 95 | url = {https://doi.org/10.24963/ijcai.2023/87}, 96 | } 97 | ``` 98 | 99 | 100 | 101 | ## Acknowledgement 102 | 103 | The code of this work is based on [TBSRN](https://github.com/FudanVI/FudanOCR/tree/main/scene-text-telescope), [TATT](https://github.com/mjq11302010044/TATT), and [C3-STISR](https://github.com/JingyeChen/C3-STISR). Thanks for their contributions.
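For reference, the following is a minimal sketch of how `main.py` loads `config/super_resolution.yaml` and wraps it in an `EasyDict` (mirroring the logic in `main.py`); the dataset paths below are placeholders that you should replace with your own:

```
import os
import yaml
from easydict import EasyDict

# Mirror the config loading in main.py: parse the YAML file and wrap it so
# that entries can be accessed as attributes (e.g., config.TRAIN.batch_size).
config_path = os.path.join('config', 'super_resolution.yaml')
with open(config_path, 'rb') as f:
    config = EasyDict(yaml.load(f, Loader=yaml.Loader))

# Point the training/validation data at your own TextZoom copies
# ('/path/to/...' are placeholder paths).
config.TRAIN.train_data_dir = ['/path/to/TextZoom/train1', '/path/to/TextZoom/train2']
config.VAL.val_data_dir = ['/path/to/TextZoom/test/easy']

print(config.TRAIN.batch_size)  # 64 in the released config
```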
104 | -------------------------------------------------------------------------------- /al_chinese.txt: -------------------------------------------------------------------------------- 1 | !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~一丁七万丈三上下不与丑专且世丘丙业丛东丝両丢两严丨个丫中丰串临丶丸丹为主丽举乃久么义之乌乎乏乐乒乓乔乖乘乙九也习乡书买乱乳乾亂了予争事二于亏云互五井亚些亜亞亡交亦产亨亩享京亭亮亲亵人亿什仁仅仇今介从仑仓仔仕他付仙仝仞仟代令以仪们仰仲件价任份仿企伊伍伏伐休众优伙会伞伟传伤伦伪伯伱伴伸伺似伽但位低住佐佑体何余佛作你佩佬佰佳使來侈例侍供依侠侣侦侧侨侬侯侵便係促俄俊俏俗保信俩俪俫修俯俱俺倆倍倒候倚借倡倩倪值倾假偉偌偏偕做停健偶偷偽偿傅傘備傢储催傲傳傷傻像僑僧僮僵價億儒儥優儱儿元兄充兆先光克免兑兒兔党兜入內全兩八公六兰共关兴兵其具典兹养兼兽冀内冈冉册再冏冒冕写军农冠冤冬冯冰冲决况冶冷冻冽净凄准凈凉凌凍减凖凛凝几凡凤処凭凯凰凱凳凶凸凹出击函凿刀刃分切刊刑划列刘则刚创初删判刨利别刮到制刷券刹刺刻剁剂剃削前剑剔剖剛剥剧剩剪副割創劑力办功加务劣动助努劫励劲劳劵势勁勃勇勋勐勒動務勝勢勤勺勾勿匀包匍匐化北匙匠匹区医區十千升午卉半卍华协卓单卖南博卜占卡卢卤卧卫印危即却卵卷卸卿厂厅历厉压厌厕厘厚原厢厦厨厲去县参又叉及友双反发叔取受变叛叟叠口古句另只叫召叭叮可台叱史右叶号司叻吃各合吉吊同名后吐向吓吕吖吗君吟否吧吨含听启吲吴吵吸吹吻吾呀呂呆呈告呐呕呗员呛呢呦周味呵呼命咀咆和咏咒咔咕咖咚咤咧咨咪咬咭咯咱咳咻咽咿品哄哆哇哈响哎哑哒哗哚哟員哥哦哨哩哪哭哮哲哺哼唇唐唑唛唢唤售唯唱唾啃啄商啊啖啞啡啤啥啦啪啫啰啶啸啾喂善喇喉喔喘喜喝喧喫喬單喰喱喵喷喽嗑嗓嗔嗖嗣嗨嗪嗮嗽嘉嘘嘟嘧嘴嘻嘿噜噢器噪噬嚣嚴嚼囊囍四回因囡团囤园困围固国图圆圈國園圖團土圣在圭地圳场圾址均坊坏坐块坚坛坠坡坤坦坨垂垃型垛垢垫埃埋城域埠培基堂堅堆堡堰報場堵塊塑塔塗塘塞填境墅墙增墨壁壊壕壞士壬壮声壳壶壹壽处备复夏夕外多夜够夢大天太夫央夯失夲头夷夹夺夾奂奇奈奉奋奏契奔奕奖套奢奥女奴奶她好如妃妆妇妈妊妍妖妙妝妞妥妩妮妳妹妻姆始姐姑姒姓委姗姚姜姨姬姿威娃娅娆娇娘娜娟娠娣娥娱娴娶婆婉婊婕婚婦婧婭婴婷媄媒媚媛媳媽嫁嫩嬉嬰嬷子孔孕字存孙孚孜孝孟孢季孤学孩孵學宁它宅宇守安宋完宏宗官宙定宛宜宝实宠审客宣室宪宫害宴家宸容宽宾宿寂寄密寇富寒寓寕寝寞察寡寢實寧寨寫寳寶寸对寺寻导対寿封専射将將專尊對導小尐少尔尖尘尚尝尤尬就尷尸尹尺尼尽尾尿局屁层居屆屉届屋屌屎屏屑展属層履山屹岁岂岑岗岚岛岩岭岳岸峡峦峰島峻崇崋崎崔崖崩崬崽嵌嵩巅巍川州巡巢巣工左巧巨巩巫差己已巴巾币市布帅帆师希帐帕帖帘帚帛帜帝帥带師席帮帯帶常帽幂幅幕幚干平年并幸幹幻幼幽广庄庆床序库应底店庙庚府庞废度座庭康庸廉廊廓廠廣延廷建廿开异弃弄式弓引弗弘弛弟张弥弦弧弯弱張強弹强彈归当录形彤彦彩彪彬彭彰影彻彼往征径待很律徐徒徕得從御復循微德徹徽心忄必忆忌忍忏忐忑志忘忙応忠忧快念忽怀态怎怒怕怖思怡急性怪总恊恋恍恐恒恢恤恨恩恭息恰恶恼悄悅悉悍悔悟悠患悦您悬悲悵情惆惊惑惜惟惠惧惫惬惯想惹愁愈愉意愛感愣愿慈慍慎慕慢慧慵憋憨憬懂懋懒懿戈戊戌戏成我戒或战截戰戳戴户房所扁扇手才扎扑扒打扔托扣执扩扫扬扭扮扯扰扳扶批找承技抄把抑抓投抖抗折抚抛抠抢护报披抱抵抹押抽担拆拉拌拍拎拐拒拓拔拖招拜拟拢拥拨择括拭拯拱拳拷拼拽拾拿持挂指按挎挑挖挚挞挠挡挣挤挥挪挫振挺捂捆捉捌捍捏捕捞损捡换捧据捶捷授掉掌排掘掟探接控推掸掺揉描提插揚換握揭援揽搁搅損搏搐搓搜搞搬搭携搽摄摆摇摈摊摔摘摩摸撃撇撑撒撕撞撬播撸撻撼擇操擎擒據擦攀攝支收攸改攻放政故效敌敏救教敛敢散敦敬数敲整敷數文斋斌斐斑斗料斛斜斤斧断斯新方施旁旅旋族旗无既日旦旧旨早旭时旺昆昇昊昌明易昔昕星映昡春昨昭是昱昵显時晃晋晒晓晔晕晖晗晚晞晟晨普景晰晳晴晶智晾暂暇暑暖暗暮暴曉曙曜曝曦曬曰曲更書曹曼曾替最會月有朋服朔朕朗望朝期朦木未末本术朱朴朵机杀杂权杆杉李杏材村杜杞束杠条来杨杭杯杰東松板极构析枕林枚果枝枞枣枪枫枭枯枰架枸枼柄柏某柑柒染柔柚柜柠查柯柱柳柴柿栅标栈栎栏树栓栖校株样核根格栽桁桂桃框案桉桌桐桑桔档桥桧桨桩桶梁梅梓梢梦梧梨梭梯械梳梵检棉棋棍棒棕棘棚棠森棵棺椅植椎椒椰椹楂楊楓楚楞楠楦業極楷楸楼楽概榄榉榔榕榛榜榨榴榻槍槐様槛槟槽樂樊樓標樟模樣横樱樹橄橋橘橙機橡橱橹檀檐檫檬櫥權次欢欣欧欲欺款歆歇歌歓歡止正此步武歪歲歳歷歸死殁殊残殖段殺殼殿毁毂毅母每毒比毕毛毡毫毯氏民氓气気氙氛氟氢氣氧氨氮氯水永汀汁求汇汉汐汕汗汛汝汞江池污汤汪汶決汽沁沂沃沈沉沌沏沐沖沙沟没沥沪沫沱河沸油治沽沾沿泄泉泊泌法泛泞泡波泣泥注泪泰泳泵泷泸泻泼泽泾洁洋洒洗洙洛洞津洪洱洲活洽派流浅浆浇浊测济浏浑浒浓浙浚浣浦浩浪浮浴海浸涂涇消涌涔涛涟涡涤润涧涨涩涮涯液涵淀淄淇淋淑淘淚淡淤淨淮深淳混添淼渄清渍渐渔渗減渝渠渡渣温測渭港渴游湃湓湖湘湛湯湾湿溃溅溉源溜溢溪溫溶溺滇滋滑滔滕滙滚满滢滤滨滩滴滿漂漆漏漓演漢漩漫漯漱漳漸漾潇潍潔潘潜潤潭潮澄澈澎澜澡澤澳激濃濠濯瀚瀛灌灣火灭灯灰灵灶灸灼灾灿炀炉炎炒炔炖炜炣炫炭炮炸点為炼炽烁烂烈烊烏烘烙烛烟烤烦烧烨烩烫热烯烷烹烽焉焊焕焖焗焙無焦焰焱然焼煊煌煎煖煙煜煞照煮煲煸熊熏熔熙熟熠熥熨熬熱熳燃燈燊燕燙營燥燳爆爪爬爱爲爵父爷爸爹爽爾片版牌牙牛牡牢牧物牵特犀犬犯状犸犹狀狂狄狐狒狗狠独狭狮狱狸狼猎猛猜猩猪猫猬献猴獅獨獭獸獻玄率玉王玖玛玥玩玫玮环现玲玷玺玻珀珂珊珍珏珑珞珠班珺現球琅理琐琢琥琦琪琬琮琯琳琴琼瑄瑕瑙瑚瑜瑞瑟瑪瑰瑶瑾璀璃璇璋璎璐璟璨環瓜瓣瓦瓮瓶瓷甄甘甜生產産用甩甫甬田由甲申电男甸町画畅界畔留畜略番畸疆疏疑疗疙疝疣疤疫疮疯疲疵疹疼疾病症痒痔痕痘痛痞痣痧痰痱痴瘊瘤瘦瘩療癣癫登發白百皂的皆皇皎皓皖皙皮皱皲皴皺盅盆盈益盎盏盐监盒盔盖盗盘盛盜盟盤目盲直相盼盾省眉看眞真眠眸眼着睁睛睡督睦睫睿瞅瞌瞎瞬瞰瞳瞿矣知矩矫短石矶矿码砂砍研砖砧砭破砸础硅硒硕硝硫硬确硼碁碌碍碎碑碗碘碟碧碩碰碱碳碼磁磅磊磐磕磨磷磺示礼社祁祈祖祛祝神祠祥票祸祺禁禅福禦禧禪禮禹离禾秀私秉秋种科秒秘租秤秦秧秩积称移稀程稍税稔稚稠種稳稻稿穆積穗穴究穷空穿突窃窄窈窍窕窖窗窘窝窥窦立竖站竞竟章童竭端竹竿笈笋笑笔笙笛符笨第笼筆等筋筏筐筑筒答策筛筝筠筷签简箍箐箔算管箭箱箸節範篇篮篷簡簧籁籍米类籽粉粒粕粗粘粤粥粧粪粮粵粹粽精糊糕糖糙糜糝糟糠糯系紀約紅紊紐純紗級素索紧紫累細紹終組給絮絲絶綁經綫維綱網緑緖緣編緯縣縫縮總繁織繹纂纇續纠红纤约级纪纫纬纯纱纲纳纵纶纷纸纹纺纽线练组绅细织终绍绎经绑绒结绕绗绘给绚络绝绞统绢绣继绨绩绪续绮绰绳维绵绷绸绻综绽绿缀缅缆缇缈缉缎缓缔缕编缘缚缝缠缤缥缦缨缩缴缸缺罂罄罐网罗罘罚罩罪置署羅羊美羔羙羚羞羡群義羽翁翅翌翎習翔翘翟翠翡翰翻翼耀老考者而耍耐耕耗耦耳耻聆聊职联聖聘聚聪聯聲聿肃肆肉肌肖肘肚肝肠股肢肤肥肩肪肯育肴肺肽肾肿胀胃胆背胎胖胚胜胞胡胤胧胰胶胸胺能脂脆脉脊脏脐脑脓脖脚脱脸脾腊腋腌腐腔腕腩腮腰腱腹腺腻腾腿膀膏膚膛膜膝膠膨膳臀臂臣臥臨自臭至致臺臻舆與興舊舌舍舒舔舖舜舞舟航般舰舱舵舶船艇艦良色艳艶艷艺艾节芃芊芋芍芒芙芜芝芥芦芩芪芬芭芮芯花芳芷芸芹芽苇苍苏苑苓苔苗苡若苦苪苯英苹茂范茄茉茗茜茧茨茬茯茵茶茸荀荆草荐荒荔荞荟荡荣荧药荷莆莉莊莎莓莞莫莱莲获莹菀菁菇菊菌菜菠菩華菱菲萃萄萌萍萎萝萤营萧萨萬萱落葆葉著葛葡董葫葱葵蒂蒄蒋蒙蒜蒟蒲蒸蒻蓄蓉蓓蓝蓬蔓蔗蔘蔡蔬蔻蔽蕉蕊蕒蕨蕲蕴蕾薄薇薈薏薦薩薪薬薯薰藍藏藓藕藝藤藥藻藿蘆蘇蘋蘑蘭虎虐虑虚虞號虫虱虹虽虾蚀蚁蚂蚊蚌蚕蚤蚪蛀蛇蛊蛋蛎蛔蛛蛤蛭蛮蛰蛲蛳蜀蜂蜕蜗蜘蜜蜡蜱蝇蝌蝎蝠蝴蝶螂融螨螯螺蟀蟆蟋蟑蟹蠔蠕蠶血衆行衍術衔街衛衡衣补表衫衬衰袁袆袋袍袖袜被袭裁裂装裆裔裕裘裙補裝裤裱裳裸裹製褂褐褚褥褪褲褶襟西要覆見親觀见观规觅视览觉角解触言訂計訊記訣設詢試詩該誉誌誓語誠說説調論諾謢謹證識警護讀计订认让训议讯记讲许论设访证评识诉诊词译试诗诚话诞诠询该详语误诱说诵请诸诺读课谁调谅谈谊谋谐谓谜谢谣谦谧谨谭谱谷豆豇豌豐豚象豪豫豹豺貂貅貌貔貝財貨販貳貴費貼貿賀資賞賠賣質賴購贅贈贝负贡财责贤败账货质贩购贮贯贰贱贴贵贸费贺贼贾赁资赋赌赏赐赔赖赚赛赞赠赢赣赤赫走赴赵赶起超越趋趣足跃跌跑距跟跨跪路跳践踏踩踪蹄蹈蹭蹲躁身躲躺
車軍軒軽載輔輕輝輪輼车轧轨轩转轮软轴轻载轿较辅辆辈辉辐辑输辛辜辞辟辣辦辧辨辩辫辰边辽达迁迅过迈迎运近返还这进远违连迟迪迫迭述迷迹追退送适逃逅逆选逊逍透逐递途逗這通逛逝速造逢連進逸逹逺逻逼遂遇運遍過道達遗遠遥適遮遵避邀邂還邢那邦邪邮邯邵邹邻郁郊郎郑郝部郭郵郸都鄂鄉酉酊配酐酒酚酥酪酬酮酯酱酵酶酷酸酿醇醉醋醒醛醫醬醯釀采釉释里重野量金釜針鉄鉅鉴銀銑銷鋒鋪鋼錄錯錶鍋鍵鍾鎏鎖鎢鏜鏡鐘鐡鐵鑫针钉钊钒钓钙钛钜钞钟钠钢钣钥钦钧钨钩钮钰钱钳钻钼钽钾铁铂铃铅铆铎铛铜铝铠铡铣铬铭铰铲银铸铺链销锁锂锅锆锈锋锌锐锗错锡锣锤锥锦键锯锰锳锵锶锻镀镁镂镇镊镌镍镜镭镰镶長长門閃開間閟関閣閩關门闪闭问闯闰闲间闷闸闹闺闻闽阀阁阅阎阔阙阜队阱防阳阴阵阶阻阿陀附际陆陇陈陌降限陕院除险陪陰陳陵陶陷陽隅隆隋随隐隔隙際障难雀雁雄雅集雇雌雍雑雕雙雜離難雨雪雯雲零雷電雾需霆震霉霏霓霖霜霞露霸霾靈靑青靓靖静靜非靠靡面革靴鞋鞘鞭韓韦韧韩韬音韵韶順頓頭頰顆題額顏顔類顯页顶项顺须顽顾顿颁颂预领颈颊颐频颓颖颗题颜额颠風飄风飘飛飞食飮飴飼飾餅養餐餓餠館饅饥饪饭饮饰饱饲饵饶饼饿馅馆馈馏馒首香馥馨馬馴騷驗马驭驰驱驴驶驻驼驾驿骄骆骊验骏骐骑骗骚骤骨骷骼髅體高髙髪髮鬃鬓鬚鬼魁魂魄魅魏魔魚鮮鯉鯽鱸鱼鲁鲍鲜鲟鲤鲨鲩鲶鲷鲸鳄鳍鳕鳞鳥鳯鴨鴻鵬鷄鸟鸡鸢鸣鸥鸦鸭鸯鸳鸽鸾鸿鹃鹄鹅鹉鹏鹤鹦鹭鹰鹿麒麓麗麝麥麦麯麺麻黃黄黎黏黑黒黔默黛點黯鼎鼓鼙鼠鼻齊齐齿龄龈龋龍龙龚龟 -------------------------------------------------------------------------------- /assets/LEMMA-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/assets/LEMMA-pipeline.png -------------------------------------------------------------------------------- /assets/Supplementary_material_for_LEMMA.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/assets/Supplementary_material_for_LEMMA.pdf -------------------------------------------------------------------------------- /assets/qualitative-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/assets/qualitative-comparison.png -------------------------------------------------------------------------------- /assets/quantitative-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/assets/quantitative-comparison.png -------------------------------------------------------------------------------- /config/super_resolution.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | 3 | 4 | TRAIN: 5 | pretrained_trans: '/data/guohang/pretrained/pretrain_transformer.pth' 6 | train_data_dir: [ 7 | '/data/guohang/dataset/TextZoom/train1', 8 | '/data/guohang/dataset/TextZoom/train2', 9 | ] 10 | batch_size: 64 11 | width: 128 12 | height: 32 13 | epochs: 500 14 | cuda: True 15 | ngpu: 2 16 | workers: 0 17 | resume: '/data/guohang/LEMMA/ckpt/demo/LEMMA-release.pth' 18 | ckpt_dir: './ckpt/with_test' 19 | voc_type: 'all' #'digits lower upper all' 20 | saveInterval: 200 21 | displayInterval: 50 #display loss 22 | adadelta: False 23 | lr: 0.001 24 | adam: True 25 | optimizer: "Adam" 26 | beta1: 0.5 27 | manualSeed: 1234 28 | max_len: 100 29 | keep_ratio: False 30 | down_sample_scale: 2 31 | 32 | VAL: 33 | val_data_dir: [ 34 | # '/data/guohang/dataset/IC15', 35 | # '/data/guohang/dataset/CUTE80', 36 | # '/data/guohang/dataset/SVTP', 37 | # '/data/guohang/dataset/SVT', 38 | '/data/guohang/dataset/TextZoom/test/easy', 39 | '/data/guohang/dataset/TextZoom/test/medium', 40 | '/data/guohang/dataset/TextZoom/test/hard', 41 | ] 42 | n_vis: 10 43 | vis_dir: 'demo' 44 | valInterval: 400 45 | rec_pretrained: '/data/guohang/pretrained/aster.pth.tar' 46 | moran_pretrained: '/data/guohang/pretrained/moran.pth' 47 | crnn_pretrained: '/data/guohang/pretrained/crnn.pth' 48 | 49 | TEST: 50 | checkpoint: '' 51 | test_data_dir: [ 52 | ] 53 | 54 | CONVERT: 55 | image_dir: 56 | lmdb_dir: 57 | n_convert: 10 58 | 59 | 60 | PositionAware: 61 | 
dataset_max_length: 25 62 | dataset_charset_path: './dataset/charset_36.txt' 63 | model_vision_attention_mode: 'nearest' 64 | vision: { 65 | checkpoint: '/data/guohang/pretrained/ABINet-pretrained/pretrain-vision-model/best-pretrain-vision-model.pth', 66 | loss_weight: 1., 67 | attention: 'position', 68 | backbone: 'transformer', 69 | backbone_ln: 3, 70 | d_model: 512 71 | } 72 | language: { 73 | checkpoint: '/data/guohang/pretrained/ABINet-pretrained/pretrain-language-model/pretrain-language-model.pth', 74 | num_layers: 4, 75 | loss_weight: 1., 76 | detach: True, 77 | use_self_attn: False 78 | } 79 | 80 | 81 | ABINet: 82 | dataset_max_length: 25 83 | dataset_charset_path: './dataset/charset_36.txt' 84 | model_vision_attention_mode: 'nearest' 85 | full_ckpt: '/data/guohang/pretrained/ABINet-pretrained/train-abinet/best-train-abinet.pth' 86 | vision: { 87 | checkpoint: '/data/guohang/pretrained/ABINet-pretrained/pretrain-vision-model/best-pretrain-vision-model.pth', 88 | loss_weight: 1., 89 | attention: 'position', 90 | backbone: 'transformer', 91 | backbone_ln: 3, 92 | d_model: 512 93 | } 94 | language: { 95 | checkpoint: '/data/guohang/pretrained/ABINet-pretrained/pretrain-language-model/pretrain-language-model.pth', 96 | num_layers: 4, 97 | loss_weight: 1., 98 | detach: True, 99 | use_self_attn: False 100 | } 101 | 102 | 103 | 104 | MATRN: 105 | dataset_charset_path: './dataset/charset_36.txt' 106 | dataset_max_length: 25 107 | model_vision_attention_mode: 'nearest' 108 | full_ckpt: '/data/guohang/pretrained/ABINet-pretrained/best-train-matrn.pth' 109 | vision: { 110 | checkpoint: , 111 | attention: 'position', 112 | backbone: 'transformer', 113 | backbone_ln: 3, 114 | d_model: 512 115 | } 116 | language: { 117 | checkpoint: , 118 | num_layers: 4, 119 | detach: True, 120 | use_self_attn: False 121 | } 122 | 123 | PARSeq: 124 | full_ckpt: '/data/guohang/pretrained/PARSeq.pth' 125 | img_size: [32,128] 126 | patch_size: [4,8] 127 | embed_dim: 384 128 | enc_depth: 12 129 | enc_num_heads: 6 130 | enc_mlp_ratio: 4 131 | 132 | self.max_label_length: 25 133 | self.decode_ar: True 134 | self.refine_iters: 1 135 | 136 | dec_num_heads: 12 137 | dec_mlp_ratio: 4 138 | dropout: 0.1 139 | dec_depth: 1 140 | perm_num: 6 141 | perm_mirrored: True 142 | max_label_length: 25 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | -------------------------------------------------------------------------------- /dataset/charset_36.txt: -------------------------------------------------------------------------------- 1 | 0 a 2 | 1 b 3 | 2 c 4 | 3 d 5 | 4 e 6 | 5 f 7 | 6 g 8 | 7 h 9 | 8 i 10 | 9 j 11 | 10 k 12 | 11 l 13 | 12 m 14 | 13 n 15 | 14 o 16 | 15 p 17 | 16 q 18 | 17 r 19 | 18 s 20 | 19 t 21 | 20 u 22 | 21 v 23 | 22 w 24 | 23 x 25 | 24 y 26 | 25 z 27 | 26 1 28 | 27 2 29 | 28 3 30 | 29 4 31 | 30 5 32 | 31 6 33 | 32 7 34 | 33 8 35 | 34 9 36 | 35 0 -------------------------------------------------------------------------------- /dataset/confuse.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/dataset/confuse.pkl -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 
import argparse 3 | import os 4 | from easydict import EasyDict 5 | from interfaces.super_resolution import TextSR 6 | from setup import Logger 7 | 8 | 9 | def main(config, args): 10 | Mission = TextSR(config, args) 11 | if args.test: 12 | Mission.test() 13 | else: 14 | Mission.train() 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser(description='') 19 | parser.add_argument('--config', type=str, default='super_resolution.yaml') 20 | parser.add_argument('--test', action='store_true', default=False) 21 | parser.add_argument('--STN', action='store_true', default=True, help='') 22 | parser.add_argument('--srb', type=int, default=5, help='') 23 | parser.add_argument('--mask', action='store_true', default=True, help='') 24 | parser.add_argument('--demo_dir', type=str, default='./demo') 25 | parser.add_argument('--test_model', type=str, default='CRNN', choices=['ASTER', "CRNN", "MORAN",'ABINet','MATRN','PARSeq']) 26 | parser.add_argument('--learning_rate', type=float, default=0.001, help='') 27 | parser.add_argument('--lr_position', type=float, default=1e-4, help='fine tune for position aware module') 28 | args = parser.parse_args() 29 | config_path = os.path.join('config', args.config) 30 | config = yaml.load(open(config_path, 'rb'), Loader=yaml.Loader) 31 | config = EasyDict(config) 32 | config.TRAIN.lr = args.learning_rate 33 | parser_TPG = argparse.ArgumentParser() 34 | Logger.init('logs', 'LEMMA', 'train') 35 | Logger.enable_file() 36 | main(config, args) 37 | -------------------------------------------------------------------------------- /model/ABINet/__pycache__/abinet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/ABINet/__pycache__/abinet.cpython-38.pyc -------------------------------------------------------------------------------- /model/ABINet/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/ABINet/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /model/ABINet/__pycache__/backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/ABINet/__pycache__/backbone.cpython-38.pyc -------------------------------------------------------------------------------- /model/ABINet/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/ABINet/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/ABINet/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/ABINet/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /model/ABINet/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .transformer import PositionalEncoding 4 | 5 | 
class Attention(nn.Module): 6 | def __init__(self, in_channels=512, max_length=25, n_feature=256): 7 | super().__init__() 8 | self.max_length = max_length 9 | 10 | self.f0_embedding = nn.Embedding(max_length, in_channels) 11 | self.w0 = nn.Linear(max_length, n_feature) 12 | self.wv = nn.Linear(in_channels, in_channels) 13 | self.we = nn.Linear(in_channels, max_length) 14 | 15 | self.active = nn.Tanh() 16 | self.softmax = nn.Softmax(dim=2) 17 | 18 | def forward(self, enc_output): 19 | enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) 20 | reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) 21 | reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) 22 | reading_order_embed = self.f0_embedding(reading_order) # b,25,512 23 | 24 | t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 25 | t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 26 | 27 | attn = self.we(t) # b,256,25 28 | attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 29 | g_output = torch.bmm(attn, enc_output) # b,25,512 30 | return g_output, attn.view(*attn.shape[:2], 8, 32) 31 | 32 | 33 | def encoder_layer(in_c, out_c, k=3, s=2, p=1): 34 | return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p), 35 | nn.BatchNorm2d(out_c), 36 | nn.ReLU(True)) 37 | 38 | def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): 39 | align_corners = None if mode=='nearest' else True 40 | return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, 41 | mode=mode, align_corners=align_corners), 42 | nn.Conv2d(in_c, out_c, k, s, p), 43 | nn.BatchNorm2d(out_c), 44 | nn.ReLU(True)) 45 | 46 | 47 | class PositionAttention(nn.Module): 48 | def __init__(self, max_length, in_channels=512, num_channels=64, 49 | h=8, w=32, mode='nearest', **kwargs): 50 | super().__init__() 51 | self.max_length = max_length 52 | self.k_encoder = nn.Sequential( 53 | encoder_layer(in_channels, num_channels, s=(1, 2)), 54 | encoder_layer(num_channels, num_channels, s=(2, 2)), 55 | encoder_layer(num_channels, num_channels, s=(2, 2)), 56 | encoder_layer(num_channels, num_channels, s=(2, 2)) 57 | ) 58 | self.k_decoder = nn.Sequential( 59 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 60 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 61 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 62 | decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) 63 | ) 64 | 65 | self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length) 66 | self.project = nn.Linear(in_channels, in_channels) 67 | 68 | def forward(self, x): 69 | N, E, H, W = x.size() 70 | k, v = x, x # (N, E, H, W) 71 | 72 | # calculate key vector 73 | features = [] 74 | for i in range(0, len(self.k_encoder)): 75 | k = self.k_encoder[i](k) 76 | features.append(k) 77 | for i in range(0, len(self.k_decoder) - 1): 78 | k = self.k_decoder[i](k) 79 | k = k + features[len(self.k_decoder) - 2 - i] 80 | k = self.k_decoder[-1](k) 81 | 82 | # calculate query vector 83 | # TODO q=f(q,k) 84 | zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) 85 | q = self.pos_encoder(zeros) # (T, N, E) 86 | q = q.permute(1, 0, 2) # (N, T, E) 87 | q = self.project(q) # (N, T, E) 88 | 89 | # calculate attention 90 | attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) 91 | attn_scores = attn_scores / (E ** 0.5) 92 | attn_scores = torch.softmax(attn_scores, dim=-1) 93 | 94 | v = v.permute(0, 2, 3, 
1).view(N, -1, E) # (N, (H*W), E) 95 | attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) 96 | 97 | return attn_vecs, attn_scores.view(N, -1, H, W) 98 | -------------------------------------------------------------------------------- /model/ABINet/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.ABINet.resnet import resnet45 5 | from model.ABINet.transformer import (PositionalEncoding, 6 | TransformerEncoder, 7 | TransformerEncoderLayer) 8 | 9 | _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024 10 | dropout=0.1, activation='relu',num_layers=3) 11 | 12 | 13 | class ResTranformer(nn.Module): 14 | def __init__(self, config): 15 | super().__init__() 16 | self.resnet = resnet45() 17 | 18 | self.d_model = _default_tfmer_cfg['d_model'] 19 | nhead = _default_tfmer_cfg['nhead'] 20 | d_inner = _default_tfmer_cfg['d_inner'] 21 | dropout = _default_tfmer_cfg['dropout'] 22 | activation = _default_tfmer_cfg['activation'] 23 | num_layers = _default_tfmer_cfg['num_layers'] 24 | 25 | self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32) 26 | encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, 27 | dim_feedforward=d_inner, dropout=dropout, activation=activation) 28 | self.transformer = TransformerEncoder(encoder_layer, num_layers) 29 | 30 | def forward(self, images,label_strs): 31 | # Resnet45 + Transformer Encoder 32 | # 相当于用一个小的Resnet+Trans 来替代原来大的backbone Resnet 33 | feature = self.resnet(images,label_strs) 34 | n, c, h, w = feature.shape 35 | feature = feature.view(n, c, -1).permute(2, 0, 1) 36 | feature = self.pos_encoder(feature) # add PE 37 | feature = self.transformer(feature) # encoder 38 | feature = feature.permute(1, 2, 0).view(n, c, h, w) 39 | return feature 40 | 41 | 42 | -------------------------------------------------------------------------------- /model/ABINet/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | def conv1x1(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv1x1(inplanes, planes) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes, stride) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class ResNet(nn.Module): 51 | 52 | def __init__(self, block, layers): 53 | self.inplanes = 32 54 | super(ResNet, self).__init__() 55 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, 56 | bias=False) 57 | self.bn1 = 
nn.BatchNorm2d(32) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 61 | self.layer2 = self._make_layer(block, 64, layers[1], stride=1) 62 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 63 | self.layer4 = self._make_layer(block, 256, layers[3], stride=1) 64 | self.layer5 = self._make_layer(block, 512, layers[4], stride=1) 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. / n)) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | def _make_layer(self, block, planes, blocks, stride=1): 75 | downsample = None 76 | if stride != 1 or self.inplanes != planes * block.expansion: 77 | downsample = nn.Sequential( 78 | nn.Conv2d(self.inplanes, planes * block.expansion, 79 | kernel_size=1, stride=stride, bias=False), 80 | nn.BatchNorm2d(planes * block.expansion), 81 | ) 82 | 83 | layers = [] 84 | layers.append(block(self.inplanes, planes, stride, downsample)) 85 | self.inplanes = planes * block.expansion 86 | for i in range(1, blocks): 87 | layers.append(block(self.inplanes, planes)) 88 | 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x,label_strs): 92 | x = self.conv1(x) 93 | x = self.bn1(x) 94 | x = self.relu(x) 95 | x = self.layer1(x) 96 | x = self.layer2(x) 97 | x = self.layer3(x) 98 | x = self.layer4(x) 99 | if label_strs is not None: 100 | for bid in range(x.shape[0]): 101 | visulize_all_channel_into_one(x[bid],label_strs[bid]) 102 | x = self.layer5(x) 103 | return x 104 | 105 | 106 | def resnet45(): 107 | return ResNet(BasicBlock, [3, 4, 6, 6, 3]) 108 | 109 | 110 | def visulize_all_channel_into_one(feature_map,label): 111 | import numpy as np 112 | import matplotlib.pyplot as plt 113 | output = feature_map 114 | 115 | output = output.data.squeeze() 116 | output = output.cpu().numpy() 117 | 118 | output = np.mean(output, axis=0) 119 | 120 | height, width = 32, 128 121 | times = height / float(width) 122 | plt.rcParams["figure.figsize"] = (1, times) 123 | plt.axis('off') 124 | plt.imshow(output, cmap='jet', interpolation='bilinear') 125 | label=label.split('/')[0]#special case 126 | if "The" in label or '11:00am' in label : 127 | return 128 | plt.savefig('E:/feature-viz/C3/{}.png'.format(label), dpi=3 * height) -------------------------------------------------------------------------------- /model/MATRN/__pycache__/matrn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/MATRN/__pycache__/matrn.cpython-38.pyc -------------------------------------------------------------------------------- /model/MATRN/__pycache__/sematic_visual_backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/MATRN/__pycache__/sematic_visual_backbone.cpython-38.pyc -------------------------------------------------------------------------------- /model/MATRN/matrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.ABINet.abinet import BaseVision,BaseAlignment,BCNLanguage,Model 4 | from model.MATRN.sematic_visual_backbone import BaseSemanticVisual_backbone_feature 5 | import 
torch.nn.functional as F 6 | import logging 7 | 8 | class MATRN(Model): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | self.iter_size = 3 12 | self.test_bh = None 13 | self.vision = BaseVision(config) 14 | self.language = BCNLanguage(config) 15 | self.semantic_visual = BaseSemanticVisual_backbone_feature(config) 16 | self.max_length = config.dataset_max_length + 1 # additional stop token 17 | self.mean = torch.tensor([0.485, 0.456, 0.406]) 18 | self.std = torch.tensor([0.229, 0.224, 0.225]) 19 | # load full model--> Vision Language Align 20 | if config.full_ckpt is not None: 21 | logging.info(f'Read full ckpt model from {config.full_ckpt}.') 22 | self.load(config.full_ckpt) 23 | 24 | 25 | def forward(self, images,input_lr=False,normalize=True): 26 | device = images.device 27 | if images.shape[2] == 16: 28 | images = F.interpolate(images, scale_factor=2, mode='bicubic', align_corners=True) 29 | self.mean = self.mean.to(device) 30 | self.std = self.std.to(device) 31 | if normalize: 32 | images =(images-self.mean[..., None, None] ) / self.std[..., None, None] 33 | 34 | v_res = self.vision(images) 35 | a_res = v_res 36 | for _ in range(self.iter_size): 37 | tokens = torch.softmax(a_res['logits'], dim=-1) 38 | lengths = a_res['pt_lengths'] 39 | lengths.clamp_(2, self.max_length) 40 | l_res = self.language(tokens, lengths) 41 | 42 | lengths_l = l_res['pt_lengths'] 43 | lengths_l.clamp_(2, self.max_length) 44 | 45 | v_attn_input = v_res['attn_scores'].clone().detach() 46 | l_logits_input = None 47 | texts_input = None 48 | 49 | a_res = self.semantic_visual(l_res['feature'], v_res['backbone_feature'], lengths_l=lengths_l, v_attn=v_attn_input, l_logits=l_logits_input, texts=texts_input, training=self.training) 50 | 51 | # TODO As in ABINet, MATRN directly swaps in the fused logits here 52 | v_res['logits'] = a_res['logits'] 53 | 54 | return v_res 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /model/MATRN/sematic_visual_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from model.ABINet.attention import * 7 | from model.ABINet.abinet import Model, _default_tfmer_cfg 8 | from model.ABINet.transformer import (PositionalEncoding, 9 | TransformerEncoder, 10 | TransformerEncoderLayer) 11 | 12 | # This module cannot be used off-the-shelf; the modifications made here are the key point
13 | class BaseSemanticVisual_backbone_feature(Model): 14 | def __init__(self, config): 15 | super().__init__(config) 16 | d_model = _default_tfmer_cfg['d_model'] 17 | nhead = _default_tfmer_cfg['nhead'] 18 | d_inner = _default_tfmer_cfg['d_inner'] 19 | dropout = _default_tfmer_cfg['dropout'] 20 | activation = _default_tfmer_cfg['activation'] 21 | num_layers = 2 22 | self.mask_example_prob = 0.9 23 | self.mask_candidate_prob = 0.9 24 | self.num_vis_mask = 10 25 | self.nhead = nhead 26 | 27 | self.d_model = d_model 28 | self.use_self_attn = False 29 | self.loss_weight = 1.0 30 | self.max_length = config.dataset_max_length + 1 # additional stop token 31 | self.debug = False 32 | 33 | encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, 34 | dim_feedforward=d_inner, dropout=dropout, activation=activation) 35 | self.model1 = TransformerEncoder(encoder_layer, num_layers) 36 | self.pos_encoder_tfm = PositionalEncoding(d_model, dropout=0, max_len=8*32) 37 | 38 | mode = 'nearest' 39 | self.model2_vis = PositionAttention( 40 | max_length=config.dataset_max_length + 1, # additional stop token 41 | mode=mode 42 | ) # converts the fused visual sequence into the same modality as the text, just as a text recognizer decodes an image into text 43 | self.cls_vis = nn.Linear(d_model, self.charset.num_classes) 44 | self.cls_sem = nn.Linear(d_model, self.charset.num_classes) 45 | self.w_att = nn.Linear(2 * d_model, d_model) 46 | 47 | v_token = torch.empty((1, d_model)) 48 | self.v_token = nn.Parameter(v_token) 49 | torch.nn.init.uniform_(self.v_token, -0.001, 0.001) 50 | 51 | self.cls = nn.Linear(d_model, self.charset.num_classes) 52 | 53 | def forward(self, l_feature, v_feature, lengths_l=None, v_attn=None, l_logits=None, texts=None, training=True): 54 | """ 55 | Args: 56 | l_feature: (N, T, E) where T is length, N is batch size and E is dim of model 57 | v_feature: (N, E, H, W) 58 | lengths_l: (N,) 59 | v_attn: (N, T, H, W) 60 | l_logits: (N, T, C) 61 | texts: (N, T, C) 62 | """ 63 | padding_mask = self._get_padding_mask(lengths_l, self.max_length) 64 | 65 | l_feature = l_feature.permute(1, 0, 2) # (T, N, E) 66 | N, E, H, W = v_feature.size() 67 | v_feature = v_feature.view(N, E, H*W).contiguous().permute(2, 0, 1) # (H*W, N, E) 68 | # ========== Mask the input visual features so that the model is more robust to occlusion ========== 69 | if training: # visual masking; this could also be reused to strengthen robustness to occlusion
70 | n, t, h, w = v_attn.shape 71 | v_attn = v_attn.view(n, t, -1) # (N, T, H*W) 72 | for idx, length in enumerate(lengths_l): 73 | if np.random.random() <= self.mask_example_prob: 74 | l_idx = np.random.randint(int(length)) 75 | v_random_idx = v_attn[idx, l_idx].argsort(descending=True).cpu().numpy()[:self.num_vis_mask,] # use argsort to find the visual regions most related to the chosen character and take the top num_vis_mask of them as mask candidates 76 | v_random_idx = v_random_idx[np.random.random(v_random_idx.shape) <= self.mask_candidate_prob] # not every candidate is masked; a random subset of them is kept for masking 77 | v_feature[v_random_idx, idx] = self.v_token # fill the corresponding positions with the empty (mask) token 78 | 79 | if len(v_attn.shape) == 4: 80 | n, t, h, w = v_attn.shape 81 | v_attn = v_attn.view(n, t, -1) # (N, T, H*W) 82 | 83 | # ============= Convert the visual positional encoding into a textual positional encoding via the visual-semantic attention map ================ 84 | zeros = v_feature.new_zeros((h*w, n, E)) # (H*W, N, E) 85 | base_pos = self.pos_encoder_tfm(zeros) # (H*W, N, E) 86 | base_pos = base_pos.permute(1, 0, 2) # (N, H*W, E) 87 | 88 | base_pos = torch.bmm(v_attn, base_pos) # (N, T, E) 89 | base_pos = base_pos.permute(1, 0, 2) # (T, N, E) 90 | 91 | # ================ Add the aligned positional encoding to the semantic features: one use of the attention map ========================= 92 | l_feature = l_feature + base_pos 93 | 94 | # =============== Multi-modal Transformer: simply concatenate, run one self-attention stage, then split again ========= 95 | sv_feature = torch.cat((v_feature, l_feature), dim=0) # (H*W+T, N, E) 96 | sv_feature = self.model1(sv_feature) # (H*W+T, N, E) 97 | sv_to_v_feature = sv_feature[:H*W] # (H*W, N, E) 98 | sv_to_s_feature = sv_feature[H*W:] # (T, N, E) 99 | 100 | # ============= Decode the visual features to obtain features in the same modality as the text ===== 101 | sv_to_v_feature = sv_to_v_feature.permute(1, 2, 0).view(N, E, H, W) 102 | sv_to_v_feature, _ = self.model2_vis(sv_to_v_feature) # (N, T, E) 103 | sv_to_v_logits = self.cls_vis(sv_to_v_feature) # (N, T, C) 104 | pt_v_lengths = self._get_length(sv_to_v_logits) # (N,) 105 | 106 | sv_to_s_feature = sv_to_s_feature.permute(1, 0, 2) # (N, T, E) 107 | sv_to_s_logits = self.cls_sem(sv_to_s_feature) # (N, T, C) 108 | pt_s_lengths = self._get_length(sv_to_s_logits) # (N,) 109 | # ==================gate fusion ================= 110 | f = torch.cat((sv_to_v_feature, sv_to_s_feature), dim=2) 111 | f_att = torch.sigmoid(self.w_att(f)) 112 | output = f_att * sv_to_v_feature + (1 - f_att) * sv_to_s_feature 113 | 114 | logits = self.cls(output) # (N, T, C) 115 | pt_lengths = self._get_length(logits) 116 | 117 | return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight*3, 118 | 'v_logits': sv_to_v_logits, 'pt_v_lengths': pt_v_lengths, 119 | 's_logits': sv_to_s_logits, 'pt_s_lengths': pt_s_lengths, 120 | 'name': 'alignment'} 121 | -------------------------------------------------------------------------------- /model/MPNCOV/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/MPNCOV/__init__.py -------------------------------------------------------------------------------- /model/MPNCOV/python/MPNCOV.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @file: MPNCOV.py 3 | @author: Jiangtao Xie 4 | @author: Peihua Li 5 | 6 | Copyright (C) 2018 Peihua Li and Jiangtao Xie 7 | 8 | All rights reserved.
9 | ''' 10 | import torch 11 | import numpy as np 12 | from torch.autograd import Function 13 | 14 | class Covpool(Function): 15 | @staticmethod 16 | def forward(ctx, input): 17 | x = input 18 | batchSize = x.data.shape[0] 19 | dim = x.data.shape[1] 20 | h = x.data.shape[2] 21 | w = x.data.shape[3] 22 | M = h*w 23 | x = x.reshape(batchSize,dim,M) 24 | I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device) 25 | I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype) 26 | y = x.bmm(I_hat).bmm(x.transpose(1,2)) 27 | ctx.save_for_backward(input,I_hat) 28 | return y 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | input,I_hat = ctx.saved_tensors 32 | x = input 33 | batchSize = x.data.shape[0] 34 | dim = x.data.shape[1] 35 | h = x.data.shape[2] 36 | w = x.data.shape[3] 37 | M = h*w 38 | x = x.reshape(batchSize,dim,M) 39 | grad_input = grad_output + grad_output.transpose(1,2) 40 | grad_input = grad_input.bmm(x).bmm(I_hat) 41 | grad_input = grad_input.reshape(batchSize,dim,h,w) 42 | return grad_input 43 | 44 | class Sqrtm(Function): 45 | @staticmethod 46 | def forward(ctx, input, iterN): 47 | x = input 48 | batchSize = x.data.shape[0] 49 | dim = x.data.shape[1] 50 | dtype = x.dtype 51 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 52 | normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1) 53 | A = x.div(normA.view(batchSize,1,1).expand_as(x)) 54 | Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device) 55 | Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1) 56 | if iterN < 2: 57 | ZY = 0.5*(I3 - A) 58 | Y[:,0,:,:] = A.bmm(ZY) 59 | else: 60 | ZY = 0.5*(I3 - A) 61 | Y[:,0,:,:] = A.bmm(ZY) 62 | Z[:,0,:,:] = ZY 63 | for i in range(1, iterN-1): 64 | ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:])) 65 | Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY) 66 | Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:]) 67 | ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:])) 68 | y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 69 | ctx.save_for_backward(input, A, ZY, normA, Y, Z) 70 | ctx.iterN = iterN 71 | return y 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | input, A, ZY, normA, Y, Z = ctx.saved_tensors 75 | iterN = ctx.iterN 76 | x = input 77 | batchSize = x.data.shape[0] 78 | dim = x.data.shape[1] 79 | dtype = x.dtype 80 | der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 81 | der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA)) 82 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 83 | if iterN < 2: 84 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace)) 85 | else: 86 | dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) - 87 | Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom)) 88 | dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:]) 89 | for i in range(iterN-3, -1, -1): 90 | YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:]) 91 | ZY = Z[:,i,:,:].bmm(Y[:,i,:,:]) 92 | dldY_ = 0.5*(dldY.bmm(YZ) - 93 | Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - 94 | ZY.bmm(dldY)) 95 | dldZ_ = 0.5*(YZ.bmm(dldZ) - 96 | Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) - 97 | dldZ.bmm(ZY)) 98 | dldY = dldY_ 99 | dldZ = dldZ_ 100 | der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY)) 101 | grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x)) 102 | grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1) 103 
| for i in range(batchSize): 104 | grad_input[i,:,:] += (der_postComAux[i] \ 105 | - grad_aux[i] / (normA[i] * normA[i])) \ 106 | *torch.ones(dim,device = x.device).diag() 107 | return grad_input, None 108 | 109 | class Triuvec(Function): 110 | @staticmethod 111 | def forward(ctx, input): 112 | x = input 113 | batchSize = x.data.shape[0] 114 | dim = x.data.shape[1] 115 | dtype = x.dtype 116 | x = x.reshape(batchSize, dim*dim) 117 | I = torch.ones(dim,dim).triu().t().reshape(dim*dim) 118 | index = I.nonzero() 119 | y = torch.zeros(batchSize,dim*(dim+1)/2,device = x.device) 120 | for i in range(batchSize): 121 | y[i, :] = x[i, index].t() 122 | ctx.save_for_backward(input,index) 123 | return y 124 | @staticmethod 125 | def backward(ctx, grad_output): 126 | input,index = ctx.saved_tensors 127 | x = input 128 | batchSize = x.data.shape[0] 129 | dim = x.data.shape[1] 130 | dtype = x.dtype 131 | grad_input = torch.zeros(batchSize,dim,dim,device = x.device,requires_grad=False) 132 | grad_input = grad_input.reshape(batchSize,dim*dim) 133 | for i in range(batchSize): 134 | grad_input[i,index] = grad_output[i,:].reshape(index.size(),1) 135 | grad_input = grad_input.reshape(batchSize,dim,dim) 136 | return grad_input 137 | 138 | def CovpoolLayer(var): 139 | return Covpool.apply(var) 140 | 141 | def SqrtmLayer(var, iterN): 142 | return Sqrtm.apply(var, iterN) 143 | 144 | def TriuvecLayer(var): 145 | return Triuvec.apply(var) 146 | -------------------------------------------------------------------------------- /model/MPNCOV/python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/MPNCOV/python/__init__.py -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # Author: Hang Guo 2 | # Tsinghua University -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .transformer import PositionalEncoding 4 | 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, in_channels=512, max_length=25, n_feature=256): 9 | super().__init__() 10 | self.max_length = max_length 11 | 12 | self.f0_embedding = nn.Embedding(max_length, in_channels) 13 | self.w0 = nn.Linear(max_length, n_feature) 14 | self.wv = nn.Linear(in_channels, in_channels) 15 | self.we = nn.Linear(in_channels, max_length) 16 | 17 | self.active = nn.Tanh() 18 | self.softmax = nn.Softmax(dim=2) 19 | 20 | def forward(self, enc_output): 21 | enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2) 22 | reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device) 23 | reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S) 24 | reading_order_embed = self.f0_embedding(reading_order) # b,25,512 25 | 26 | t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256 27 | t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512 28 | 29 | attn = self.we(t) # b,256,25 30 | attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256 31 | g_output = torch.bmm(attn, enc_output) # b,25,512 32 | return g_output, attn.view(*attn.shape[:2], 8, 32) 33 | 34 | 35 | def encoder_layer(in_c, out_c, k=3, s=2, p=1): 36 | return nn.Sequential(nn.Conv2d(in_c, 
out_c, k, s, p), 37 | nn.BatchNorm2d(out_c), 38 | nn.ReLU(True)) 39 | 40 | def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None): 41 | align_corners = None if mode=='nearest' else True 42 | return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor, 43 | mode=mode, align_corners=align_corners), 44 | nn.Conv2d(in_c, out_c, k, s, p), 45 | nn.BatchNorm2d(out_c), 46 | nn.ReLU(True)) 47 | 48 | 49 | class PositionAttention(nn.Module): 50 | def __init__(self, max_length, in_channels=512, num_channels=64, 51 | h=8, w=32, mode='nearest', **kwargs): 52 | super().__init__() 53 | self.max_length = max_length # len of alphbet -- 26 54 | self.k_encoder = nn.Sequential( 55 | encoder_layer(in_channels, num_channels, s=(1, 2)), 56 | encoder_layer(num_channels, num_channels, s=(2, 2)), 57 | encoder_layer(num_channels, num_channels, s=(2, 2)), 58 | encoder_layer(num_channels, num_channels, s=(2, 2)) 59 | ) # conv - bn - relu 60 | self.k_decoder = nn.Sequential( 61 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 62 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 63 | decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode), 64 | decoder_layer(num_channels, in_channels, size=(h, w), mode=mode) 65 | )# upsample - conv - bn -relu 66 | self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length) 67 | self.project = nn.Linear(in_channels, in_channels) # 对位置编码再过一层Linear 68 | 69 | def forward(self, x, q=None): 70 | # x is img feat 71 | # q is text PE 72 | N, E, H, W = x.size() 73 | k, v = x, x # (N, E, H, W) 74 | 75 | # calculate key vector from x 76 | # mini U-Net 77 | features = [] 78 | for i in range(0, len(self.k_encoder)): 79 | k = self.k_encoder[i](k) 80 | features.append(k) 81 | for i in range(0, len(self.k_decoder) - 1): 82 | k = self.k_decoder[i](k) 83 | k = k + features[len(self.k_decoder) - 2 - i] 84 | k = self.k_decoder[-1](k) 85 | 86 | # calculate query vector 87 | # TODO q=f(q,k) 88 | if q is None: 89 | zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E) 90 | q = self.pos_encoder(zeros) # (T, N, E) 91 | q = q.permute(1, 0, 2) # (N, T, E) 92 | q = self.project(q) # (N, T, E) 93 | 94 | # calculate attention 95 | attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W)) 96 | attn_scores = attn_scores / (E ** 0.5) 97 | attn_scores = torch.softmax(attn_scores, dim=-1) 98 | 99 | v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E) 100 | attn_vecs = torch.bmm(attn_scores, v) # (N, T, E) 101 | 102 | return attn_vecs, attn_scores.view(N, -1, H, W) 103 | -------------------------------------------------------------------------------- /model/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.resnet import resnet45 5 | from model.transformer import (PositionalEncoding, 6 | TransformerEncoder, 7 | TransformerEncoderLayer) 8 | 9 | _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024 10 | dropout=0.1, activation='relu',num_layers=3) 11 | 12 | 13 | class ResTranformer(nn.Module): 14 | def __init__(self, config): 15 | super().__init__() 16 | self.resnet = resnet45() 17 | 18 | self.d_model = _default_tfmer_cfg['d_model'] 19 | nhead = _default_tfmer_cfg['nhead'] 20 | d_inner = _default_tfmer_cfg['d_inner'] 21 | dropout = _default_tfmer_cfg['dropout'] 22 | activation = _default_tfmer_cfg['activation'] 23 | num_layers = _default_tfmer_cfg['num_layers'] 24 | 25 | 
self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32) 26 | encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead, 27 | dim_feedforward=d_inner, dropout=dropout, activation=activation) 28 | self.transformer = TransformerEncoder(encoder_layer, num_layers) 29 | 30 | def forward(self, images): 31 | # Resnet45 + Transformer Encoder 32 | # 相当于用一个小的Resnet+Trans 来替代原来大的backbone Resnet 33 | feature = self.resnet(images) 34 | n, c, h, w = feature.shape 35 | feature = feature.view(n, c, -1).permute(2, 0, 1) 36 | feature = self.pos_encoder(feature) # add PE 37 | feature = self.transformer(feature) # encoder 38 | feature = feature.permute(1, 2, 0).view(n, c, h, w) 39 | return feature 40 | -------------------------------------------------------------------------------- /model/crnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .crnn import CRNN, CRNN_ResNet18 2 | from .model import Model -------------------------------------------------------------------------------- /model/crnn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/crnn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/crnn/__pycache__/crnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/crnn.cpython-36.pyc -------------------------------------------------------------------------------- /model/crnn/__pycache__/crnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/crnn.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/__pycache__/crnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/crnn.cpython-38.pyc -------------------------------------------------------------------------------- /model/crnn/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/__pycache__/model.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /model/crnn/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import torch.nn as nn 18 | 19 | from .modules.transformation import TPS_SpatialTransformerNetwork 20 | from .modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor 21 | from .modules.sequence_modeling import BidirectionalLSTM 22 | from .modules.prediction import Attention 23 | import numpy as np 24 | 25 | class Model(nn.Module): 26 | 27 | def __init__(self, opt): 28 | super(Model, self).__init__() 29 | self.opt = opt 30 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 31 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 32 | 33 | """ Transformation """ 34 | if opt.Transformation == 'TPS': 35 | self.Transformation = TPS_SpatialTransformerNetwork( 36 | F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) 37 | else: 38 | print('No Transformation module specified') 39 | 40 | """ FeatureExtraction """ 41 | if opt.FeatureExtraction == 'VGG': 42 | self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) 43 | elif opt.FeatureExtraction == 'RCNN': 44 | self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) 45 | elif opt.FeatureExtraction == 'ResNet': 46 | self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) 47 | else: 48 | raise Exception('No FeatureExtraction module specified') 49 | self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 50 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 51 | 52 | """ Sequence modeling""" 53 | if opt.SequenceModeling == 'BiLSTM': 54 | self.SequenceModeling = nn.Sequential( 55 | BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), 56 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) 57 | self.SequenceModeling_output = opt.hidden_size 58 | else: 59 | print('No SequenceModeling module specified') 60 | self.SequenceModeling_output = self.FeatureExtraction_output 61 | 62 | """ Prediction """ 63 | if opt.Prediction == 'CTC': 64 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 65 | elif opt.Prediction == 'Attn': 66 | self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) 67 | else: 68 | raise Exception('Prediction is neither CTC or Attn') 69 | 70 | def forward(self, input, text=None, is_train=True): 71 | """ Transformation stage """ 72 | if not 
self.stages['Trans'] == "None": 73 | input = self.Transformation(input) 74 | 75 | """ Feature extraction stage """ 76 | visual_feature = self.FeatureExtraction(input) 77 | visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] 78 | visual_feature = visual_feature.squeeze(3) 79 | 80 | """ Sequence modeling stage """ 81 | if self.stages['Seq'] == 'BiLSTM': 82 | contextual_feature = self.SequenceModeling(visual_feature) 83 | else: 84 | contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM 85 | 86 | """ Prediction stage """ 87 | if self.stages['Pred'] == 'CTC': 88 | prediction = self.Prediction(contextual_feature.contiguous()) 89 | prediction = prediction.permute(1, 0, 2) 90 | else: 91 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) 92 | 93 | # print("prediction:", prediction.shape) 94 | 95 | return prediction 96 | -------------------------------------------------------------------------------- /model/crnn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__init__.py -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/feature_extraction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/feature_extraction.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/feature_extraction.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/feature_extraction.cpython-38.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/prediction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/prediction.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/prediction.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/prediction.cpython-38.pyc 
-------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/sequence_modeling.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/sequence_modeling.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/sequence_modeling.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/sequence_modeling.cpython-38.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/transformation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/transformation.cpython-37.pyc -------------------------------------------------------------------------------- /model/crnn/modules/__pycache__/transformation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/crnn/modules/__pycache__/transformation.cpython-38.pyc -------------------------------------------------------------------------------- /model/crnn/modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class VGG_FeatureExtractor(nn.Module): 6 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 7 | 8 | def __init__(self, input_channel, output_channel=512): 9 | super(VGG_FeatureExtractor, self).__init__() 10 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 11 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 12 | self.ConvNet = nn.Sequential( 13 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 14 | nn.MaxPool2d(2, 2), # 64x16x50 15 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 16 | nn.MaxPool2d(2, 2), # 128x8x25 17 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 18 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 19 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 20 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 21 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 22 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 23 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 24 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 25 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 26 | 27 | def forward(self, input): 28 | return self.ConvNet(input) 29 | 30 | 31 | class RCNN_FeatureExtractor(nn.Module): 32 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 33 | 34 | def __init__(self, input_channel, output_channel=512): 35 | super(RCNN_FeatureExtractor, self).__init__() 36 | self.output_channel = [int(output_channel / 8), int(output_channel 
/ 4), 37 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 38 | self.ConvNet = nn.Sequential( 39 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 40 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 41 | GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), 42 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 43 | GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), 44 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 45 | GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), 46 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 47 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 48 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 49 | 50 | def forward(self, input): 51 | return self.ConvNet(input) 52 | 53 | 54 | class ResNet_FeatureExtractor(nn.Module): 55 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 56 | 57 | def __init__(self, input_channel, output_channel=512): 58 | super(ResNet_FeatureExtractor, self).__init__() 59 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 60 | 61 | def forward(self, input): 62 | return self.ConvNet(input) 63 | 64 | 65 | # For Gated RCNN 66 | class GRCL(nn.Module): 67 | 68 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 69 | super(GRCL, self).__init__() 70 | self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) 71 | self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) 72 | self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) 73 | self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) 74 | 75 | self.BN_x_init = nn.BatchNorm2d(output_channel) 76 | 77 | self.num_iteration = num_iteration 78 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 79 | self.GRCL = nn.Sequential(*self.GRCL) 80 | 81 | def forward(self, input): 82 | """ The input of GRCL is consistant over time t, which is denoted by u(0) 83 | thus wgf_u / wf_u is also consistant over time t. 
84 | """ 85 | wgf_u = self.wgf_u(input) 86 | wf_u = self.wf_u(input) 87 | x = F.relu(self.BN_x_init(wf_u)) 88 | 89 | for i in range(self.num_iteration): 90 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 91 | 92 | return x 93 | 94 | 95 | class GRCL_unit(nn.Module): 96 | 97 | def __init__(self, output_channel): 98 | super(GRCL_unit, self).__init__() 99 | self.BN_gfu = nn.BatchNorm2d(output_channel) 100 | self.BN_grx = nn.BatchNorm2d(output_channel) 101 | self.BN_fu = nn.BatchNorm2d(output_channel) 102 | self.BN_rx = nn.BatchNorm2d(output_channel) 103 | self.BN_Gx = nn.BatchNorm2d(output_channel) 104 | 105 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 106 | G_first_term = self.BN_gfu(wgf_u) 107 | G_second_term = self.BN_grx(wgr_x) 108 | G = F.sigmoid(G_first_term + G_second_term) 109 | 110 | x_first_term = self.BN_fu(wf_u) 111 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 112 | x = F.relu(x_first_term + x_second_term) 113 | 114 | return x 115 | 116 | 117 | class BasicBlock(nn.Module): 118 | expansion = 1 119 | 120 | def __init__(self, inplanes, planes, stride=1, downsample=None): 121 | super(BasicBlock, self).__init__() 122 | self.conv1 = self._conv3x3(inplanes, planes) 123 | self.bn1 = nn.BatchNorm2d(planes) 124 | self.conv2 = self._conv3x3(planes, planes) 125 | self.bn2 = nn.BatchNorm2d(planes) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.downsample = downsample 128 | self.stride = stride 129 | 130 | def _conv3x3(self, in_planes, out_planes, stride=1): 131 | "3x3 convolution with padding" 132 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 133 | padding=1, bias=False) 134 | 135 | def forward(self, x): 136 | residual = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | 145 | if self.downsample is not None: 146 | residual = self.downsample(x) 147 | out += residual 148 | out = self.relu(out) 149 | 150 | return out 151 | 152 | 153 | class ResNet(nn.Module): 154 | 155 | def __init__(self, input_channel, output_channel, block, layers): 156 | super(ResNet, self).__init__() 157 | 158 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 159 | 160 | self.inplanes = int(output_channel / 8) 161 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 162 | kernel_size=3, stride=1, padding=1, bias=False) 163 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 164 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 165 | kernel_size=3, stride=1, padding=1, bias=False) 166 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 167 | self.relu = nn.ReLU(inplace=True) 168 | 169 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 170 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 171 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 172 | 0], kernel_size=3, stride=1, padding=1, bias=False) 173 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 174 | 175 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 176 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 177 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 178 | 1], kernel_size=3, stride=1, padding=1, bias=False) 179 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 180 | 181 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), 
padding=(0, 1)) 182 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 183 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 184 | 2], kernel_size=3, stride=1, padding=1, bias=False) 185 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 186 | 187 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 188 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 189 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 190 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 191 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 192 | 3], kernel_size=2, stride=1, padding=0, bias=False) 193 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 194 | 195 | def _make_layer(self, block, planes, blocks, stride=1): 196 | downsample = None 197 | if stride != 1 or self.inplanes != planes * block.expansion: 198 | downsample = nn.Sequential( 199 | nn.Conv2d(self.inplanes, planes * block.expansion, 200 | kernel_size=1, stride=stride, bias=False), 201 | nn.BatchNorm2d(planes * block.expansion), 202 | ) 203 | 204 | layers = [] 205 | layers.append(block(self.inplanes, planes, stride, downsample)) 206 | self.inplanes = planes * block.expansion 207 | for i in range(1, blocks): 208 | layers.append(block(self.inplanes, planes)) 209 | 210 | return nn.Sequential(*layers) 211 | 212 | def forward(self, x): 213 | x = self.conv0_1(x) 214 | x = self.bn0_1(x) 215 | x = self.relu(x) 216 | x = self.conv0_2(x) 217 | x = self.bn0_2(x) 218 | x = self.relu(x) 219 | 220 | x = self.maxpool1(x) 221 | x = self.layer1(x) 222 | x = self.conv1(x) 223 | x = self.bn1(x) 224 | x = self.relu(x) 225 | 226 | x = self.maxpool2(x) 227 | x = self.layer2(x) 228 | x = self.conv2(x) 229 | x = self.bn2(x) 230 | x = self.relu(x) 231 | 232 | x = self.maxpool3(x) 233 | x = self.layer3(x) 234 | x = self.conv3(x) 235 | x = self.bn3(x) 236 | x = self.relu(x) 237 | 238 | x = self.layer4(x) 239 | x = self.conv4_1(x) 240 | x = self.bn4_1(x) 241 | x = self.relu(x) 242 | x = self.conv4_2(x) 243 | x = self.bn4_2(x) 244 | x = self.relu(x) 245 | 246 | return x 247 | -------------------------------------------------------------------------------- /model/crnn/modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | 8 | def __init__(self, input_size, hidden_size, num_classes): 9 | super(Attention, self).__init__() 10 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 11 | self.hidden_size = hidden_size 12 | self.num_classes = num_classes 13 | self.generator = nn.Linear(hidden_size, num_classes) 14 | 15 | def _char_to_onehot(self, input_char, onehot_dim=38): 16 | input_char = input_char.unsqueeze(1) 17 | batch_size = input_char.size(0) 18 | one_hot = torch.cuda.FloatTensor(batch_size, onehot_dim).zero_() 19 | one_hot = one_hot.scatter_(1, input_char, 1) 20 | return one_hot 21 | 22 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 23 | """ 24 | input: 25 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes] 26 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 
27 | output: probability distribution at each step [batch_size x num_steps x num_classes] 28 | """ 29 | batch_size = batch_H.size(0) 30 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 31 | 32 | output_hiddens = torch.cuda.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0) 33 | hidden = (torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0), 34 | torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0)) 35 | 36 | if is_train: 37 | for i in range(num_steps): 38 | # one-hot vectors for a i-th char. in a batch 39 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 40 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 41 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 42 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 43 | probs = self.generator(output_hiddens) 44 | 45 | else: 46 | targets = torch.cuda.LongTensor(batch_size).fill_(0) # [GO] token 47 | probs = torch.cuda.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0) 48 | 49 | for i in range(num_steps): 50 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 51 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 52 | probs_step = self.generator(hidden[0]) 53 | probs[:, i, :] = probs_step 54 | _, next_input = probs_step.max(1) 55 | targets = next_input 56 | 57 | return probs # batch_size x num_steps x num_classes 58 | 59 | 60 | class AttentionCell(nn.Module): 61 | 62 | def __init__(self, input_size, hidden_size, num_embeddings): 63 | super(AttentionCell, self).__init__() 64 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 65 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 66 | self.score = nn.Linear(hidden_size, 1, bias=False) 67 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 68 | self.hidden_size = hidden_size 69 | 70 | def forward(self, prev_hidden, batch_H, char_onehots): 71 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 72 | batch_H_proj = self.i2h(batch_H) 73 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 74 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 75 | 76 | alpha = F.softmax(e, dim=1) 77 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 78 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 79 | cur_hidden = self.rnn(concat_context, prev_hidden) 80 | return cur_hidden, alpha 81 | -------------------------------------------------------------------------------- /model/crnn/modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, input_size, hidden_size, output_size): 7 | super(BidirectionalLSTM, self).__init__() 8 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 9 | self.linear = nn.Linear(hidden_size * 2, output_size) 10 | 11 | def forward(self, input): 12 | """ 13 | input : visual feature [batch_size x T x input_size] 14 | output : contextual feature [batch_size x T x output_size] 15 | """ 16 | self.rnn.flatten_parameters() 17 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 18 | output = 
self.linear(recurrent) # batch_size x T x output_size 19 | return output 20 | -------------------------------------------------------------------------------- /model/crnn/modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class TPS_SpatialTransformerNetwork(nn.Module): 8 | """ Rectification Network of RARE, namely TPS based STN """ 9 | 10 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 11 | """ Based on RARE TPS 12 | input: 13 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 14 | I_size : (height, width) of the input image I 15 | I_r_size : (height, width) of the rectified image I_r 16 | I_channel_num : the number of channels of the input image I 17 | output: 18 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 19 | """ 20 | super(TPS_SpatialTransformerNetwork, self).__init__() 21 | self.F = F 22 | self.I_size = I_size 23 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 24 | self.I_channel_num = I_channel_num 25 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 26 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 27 | 28 | def forward(self, batch_I): 29 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 30 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 31 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 32 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 33 | 34 | return batch_I_r 35 | 36 | 37 | class LocalizationNetwork(nn.Module): 38 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 39 | 40 | def __init__(self, F, I_channel_num): 41 | super(LocalizationNetwork, self).__init__() 42 | self.F = F 43 | self.I_channel_num = I_channel_num 44 | self.conv = nn.Sequential( 45 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 46 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 47 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 48 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 49 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 50 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 51 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 52 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 53 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 54 | ) 55 | 56 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 57 | self.localization_fc2 = nn.Linear(256, self.F * 2) 58 | 59 | # Init fc2 in LocalizationNetwork 60 | self.localization_fc2.weight.data.fill_(0) 61 | """ see RARE paper Fig. 
6 (a) """ 62 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 63 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 64 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 65 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 66 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 67 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 68 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 69 | 70 | def forward(self, batch_I): 71 | """ 72 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 73 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 74 | """ 75 | batch_size = batch_I.size(0) 76 | features = self.conv(batch_I).view(batch_size, -1) 77 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 78 | return batch_C_prime 79 | 80 | 81 | class GridGenerator(nn.Module): 82 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 83 | 84 | def __init__(self, F, I_r_size): 85 | """ Generate P_hat and inv_delta_C for later """ 86 | super(GridGenerator, self).__init__() 87 | self.eps = 1e-6 88 | self.I_r_height, self.I_r_width = I_r_size 89 | self.F = F 90 | self.C = self._build_C(self.F) # F x 2 91 | self.P = self._build_P(self.I_r_width, self.I_r_height) 92 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 93 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 94 | 95 | def _build_C(self, F): 96 | """ Return coordinates of fiducial points in I_r; C """ 97 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 98 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 99 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 100 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 101 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 102 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 103 | return C # F x 2 104 | 105 | def _build_inv_delta_C(self, F, C): 106 | """ Return inv_delta_C which is needed to calculate T """ 107 | hat_C = np.zeros((F, F), dtype=float) # F x F 108 | for i in range(0, F): 109 | for j in range(i, F): 110 | r = np.linalg.norm(C[i] - C[j]) 111 | hat_C[i, j] = r 112 | hat_C[j, i] = r 113 | np.fill_diagonal(hat_C, 1) 114 | hat_C = (hat_C ** 2) * np.log(hat_C) 115 | # print(C.shape, hat_C.shape) 116 | delta_C = np.concatenate( # F+3 x F+3 117 | [ 118 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 119 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 120 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 121 | ], 122 | axis=0 123 | ) 124 | inv_delta_C = np.linalg.inv(delta_C) 125 | return inv_delta_C # F+3 x F+3 126 | 127 | def _build_P(self, I_r_width, I_r_height): 128 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 129 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 130 | P = np.stack( # self.I_r_width x self.I_r_height x 2 131 | np.meshgrid(I_r_grid_x, I_r_grid_y), 132 | axis=2 133 | ) 134 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 135 | 136 | def _build_P_hat(self, F, C, P): 137 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 138 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n 
x F x 2 139 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 140 | P_diff = P_tile - C_tile # n x F x 2 141 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 142 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 143 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 144 | return P_hat # n x F+3 145 | 146 | def build_P_prime(self, batch_C_prime): 147 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 148 | batch_size = batch_C_prime.size(0) 149 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 150 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 151 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 152 | batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2 153 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 154 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 155 | return batch_P_prime # batch_size x n x 2 156 | -------------------------------------------------------------------------------- /model/lemma.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from model.Position_aware_module import PositionAwareModule,Location_enhancement_Multimodal_alignment 6 | from .tps_spatial_transformer import TPSSpatialTransformer 7 | from .stn_head import STNHead 8 | from torchvision import transforms 9 | SHUT_BN = False 10 | 11 | def showPIL(img,batch_id=0): 12 | img = img[batch_id,:3,...]*255 13 | img = torch.as_tensor(img.detach().cpu(),dtype=torch.uint8).numpy() 14 | img = transforms.functional.to_pil_image(img.transpose((1,2,0))) 15 | img.show() 16 | 17 | 18 | class AffineModulate(nn.Module): 19 | def __init__(self, channel=64): 20 | super(AffineModulate, self).__init__() 21 | self.BN = nn.BatchNorm2d(channel) 22 | self.conv1x1_1 = nn.Sequential( 23 | nn.Conv2d(channel*2, channel, 1), 24 | nn.BatchNorm2d(channel), 25 | nn.PReLU(), 26 | nn.Conv2d(channel, channel, 1) 27 | ) 28 | self.conv1x1_2 = nn.Sequential( 29 | nn.Conv2d(channel*2, channel, 1), 30 | nn.BatchNorm2d(channel), 31 | nn.PReLU(), 32 | nn.Conv2d(channel, channel, 1) 33 | ) 34 | self.conv1x1_3 = nn.Sequential( 35 | nn.Conv2d(channel*2, channel, 1), 36 | nn.BatchNorm2d(channel), 37 | nn.PReLU(), 38 | nn.Conv2d(channel, channel, 1) 39 | ) 40 | self.channel_attn = nn.Sequential( 41 | # Linear 42 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(16, 64), groups=64), # global depthwise-separable convolution (one 16x64 kernel per channel) 43 | nn.Conv2d(channel, channel, 1), 44 | nn.ReLU(inplace=True), 45 | nn.Conv2d(channel, channel, 1), 46 | nn.ReLU(inplace=True), 47 | nn.Conv2d(channel, channel, 1), 48 | nn.Sigmoid() 49 | ) 50 | self.aggregation = nn.Sequential( 51 | nn.Conv2d(channel,channel,1), 52 | nn.BatchNorm2d(channel), 53 | nn.PReLU(), 54 | nn.Conv2d(channel,channel,1) 55 | ) 56 | self.BN2 = nn.BatchNorm2d(channel) 57 | 58 | def forward(self,image_feature,tp_map): 59 | tp_map = self.BN(tp_map) 60 | cat_feature = torch.cat([image_feature,tp_map],dim=1) 61 | x1,x2,x3 = self.conv1x1_1(cat_feature),self.conv1x1_2(cat_feature),self.conv1x1_3(cat_feature) 62 | x1 = self.channel_attn(x1) 63 | x2 = self.aggregation(x1*x2) 64 | return self.BN2(x3+x2) 65 | 66 | 67 | class LEMMA(nn.Module): 68 | def __init__(self, 69 | scale_factor=2, 70 | width=128, 71 | height=32, 72 | STN=False, 73 | srb_nums=5, 74 | mask=True, 75 | hidden_units=32, 76 | word_vec_d=300, 77 | text_emb=37, 78 |
out_text_channels=64, 79 | feature_rotate=False, 80 | rotate_train=3., 81 | cfg=None): 82 | super(LEMMA, self).__init__() 83 | in_planes = 3 84 | if mask: 85 | in_planes = 4 86 | assert math.log(scale_factor, 2) % 1 == 0 87 | upsample_block_num = int(math.log(scale_factor, 2)) 88 | self.block1 = nn.Sequential( 89 | nn.Conv2d(in_planes, 2 * hidden_units, kernel_size=9, padding=4), # 256 out feature 90 | nn.PReLU() 91 | ) 92 | self.srb_nums = srb_nums 93 | for i in range(srb_nums): 94 | setattr(self, 'block%d' % (i + 2), RecurrentResidualBlockAffine(2 * hidden_units, out_text_channels)) 95 | self.feature_rotate = feature_rotate 96 | self.rotate_train = rotate_train 97 | if not SHUT_BN: 98 | setattr(self, 'block%d' % (srb_nums + 2), 99 | nn.Sequential( 100 | nn.Conv2d(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1), 101 | nn.BatchNorm2d(2 * hidden_units) 102 | )) 103 | else: 104 | setattr(self, 'block%d' % (srb_nums + 2), 105 | nn.Sequential( 106 | nn.Conv2d(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1), 107 | # nn.BatchNorm2d(2 * hidden_units) 108 | )) 109 | 110 | block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)] 111 | block_.append(nn.Conv2d(2 * hidden_units, in_planes, kernel_size=9, padding=4)) 112 | setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*block_)) 113 | self.tps_inputsize = [height // scale_factor, width // scale_factor] 114 | tps_outputsize = [height // scale_factor, width // scale_factor] 115 | num_control_points = 20 116 | tps_margins = [0.05, 0.05] 117 | self.stn = STN 118 | if self.stn: 119 | self.tps = TPSSpatialTransformer( 120 | output_image_size=tuple(tps_outputsize), 121 | num_control_points=num_control_points, 122 | margins=tuple(tps_margins)) 123 | 124 | self.stn_head = STNHead( 125 | in_planes=in_planes, 126 | num_ctrlpoints=num_control_points, 127 | activation='none', 128 | input_size=self.tps_inputsize) 129 | 130 | self.block_range = [k for k in range(2, self.srb_nums + 2)] 131 | self.position_prior = PositionAwareModule(cfg.PositionAware) 132 | # We implement both Location Enhancement Module and Multi-modal Alignment Module here 133 | self.guidanceGen = Location_enhancement_Multimodal_alignment(cfg.PositionAware) 134 | 135 | 136 | def forward(self, x): 137 | if self.stn and self.training: 138 | _, ctrl_points_x = self.stn_head(x) 139 | x, _ = self.tps(x, ctrl_points_x) 140 | 141 | x_ = x.clone()[:,:3,:,:].detach() 142 | x_ = F.interpolate(x_,scale_factor=2,mode='bicubic',align_corners=True) 143 | pos_prior = self.position_prior(x_) # 'attn_map, 'text_feature', 'text_logits', 'pt_lengths' 144 | 145 | block = {'1': self.block1(x)} 146 | padding_feature = block['1'] 147 | tp_map,pr_weights = self.guidanceGen(pos_prior,padding_feature) 148 | for i in range(self.srb_nums + 1): 149 | if i + 2 in self.block_range: 150 | block[str(i + 2)] = getattr(self, 'block%d' % (i + 2))(block[str(i + 1)], tp_map) 151 | else: 152 | block[str(i + 2)] = getattr(self, 'block%d' % (i + 2))(block[str(i + 1)]) 153 | 154 | block[str(self.srb_nums + 3)] = getattr(self, 'block%d' % (self.srb_nums + 3)) \ 155 | ((block['1'] + block[str(self.srb_nums + 2)])) 156 | output = torch.tanh(block[str(self.srb_nums + 3)]) 157 | self.block = block 158 | return output, pos_prior 159 | 160 | 161 | class RecurrentResidualBlock(nn.Module): 162 | def __init__(self, channels): 163 | super(RecurrentResidualBlock, self).__init__() 164 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 165 | self.bn1 = 
nn.BatchNorm2d(channels) 166 | self.gru1 = GruBlock(channels, channels) 167 | # self.prelu = nn.ReLU() 168 | self.prelu = mish() 169 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 170 | self.bn2 = nn.BatchNorm2d(channels) 171 | self.gru2 = GruBlock(channels, channels) 172 | 173 | def forward(self, x): 174 | residual = self.conv1(x) 175 | residual = self.bn1(residual) 176 | residual = self.prelu(residual) 177 | residual = self.conv2(residual) 178 | residual = self.bn2(residual) 179 | residual = self.gru1(residual.transpose(-1, -2)).transpose(-1, -2) 180 | 181 | return self.gru2(x + residual) 182 | 183 | 184 | 185 | 186 | class RecurrentResidualBlockAffine(nn.Module): 187 | def __init__(self, channels,text_channel=None): 188 | # channls = 64 189 | super(RecurrentResidualBlockAffine, self).__init__() 190 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 191 | self.bn1 = nn.BatchNorm2d(channels) 192 | self.gru1 = GruBlock(channels, channels) 193 | self.prelu = mish() 194 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 195 | self.bn2 = nn.BatchNorm2d(channels) 196 | self.gru2 = GruBlock(channels, channels) # + text_channels 197 | self.affine = AffineModulate(channel=channels) 198 | self.BN = nn.BatchNorm2d(channels) 199 | 200 | 201 | def forward(self, x, tp_map): 202 | residual = self.conv1(x) 203 | if not SHUT_BN: 204 | residual = self.bn1(residual) 205 | residual = self.prelu(residual) 206 | residual = self.conv2(residual) 207 | if not SHUT_BN: 208 | residual = self.bn2(residual) 209 | # AdaFM 210 | residual = self.affine(residual,tp_map) 211 | # SRB 212 | residual = self.gru1(residual.transpose(-1, -2)).transpose(-1, -2) 213 | 214 | return self.gru2(self.BN(x + residual)) 215 | 216 | 217 | 218 | 219 | class UpsampleBLock(nn.Module): 220 | def __init__(self, in_channels, up_scale): 221 | super(UpsampleBLock, self).__init__() 222 | self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1) 223 | 224 | self.pixel_shuffle = nn.PixelShuffle(up_scale) 225 | # self.prelu = nn.ReLU() 226 | self.prelu = mish() 227 | 228 | def forward(self, x): 229 | x = self.conv(x) 230 | x = self.pixel_shuffle(x) 231 | x = self.prelu(x) 232 | return x 233 | 234 | 235 | class mish(nn.Module): 236 | def __init__(self, ): 237 | super(mish, self).__init__() 238 | self.activated = True 239 | 240 | def forward(self, x): 241 | if self.activated: 242 | x = x * (torch.tanh(F.softplus(x))) 243 | return x 244 | 245 | 246 | class GruBlock(nn.Module): 247 | def __init__(self, in_channels, out_channels): 248 | super(GruBlock, self).__init__() 249 | assert out_channels % 2 == 0 250 | self.gru = nn.GRU(out_channels, out_channels // 2, bidirectional=True, batch_first=True) 251 | 252 | def forward(self, x): 253 | x = x.permute(0, 2, 3, 1).contiguous() 254 | b = x.size() 255 | x = x.view(b[0] * b[1], b[2], b[3]) 256 | self.gru.flatten_parameters() 257 | x, _ = self.gru(x) 258 | x = x.view(b[0], b[1], b[2], b[3]) 259 | x = x.permute(0, 3, 1, 2) 260 | return x 261 | -------------------------------------------------------------------------------- /model/moran/__init__.py: -------------------------------------------------------------------------------- 1 | from .moran import MORAN 2 | -------------------------------------------------------------------------------- /model/moran/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/asrn_res.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/asrn_res.cpython-36.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/asrn_res.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/asrn_res.cpython-37.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/asrn_res.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/asrn_res.cpython-38.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/fracPickup.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/fracPickup.cpython-36.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/fracPickup.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/fracPickup.cpython-37.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/fracPickup.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/fracPickup.cpython-38.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/moran.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/moran.cpython-36.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/moran.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/moran.cpython-37.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/moran.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/moran.cpython-38.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/morn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/morn.cpython-36.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/morn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/morn.cpython-37.pyc -------------------------------------------------------------------------------- /model/moran/__pycache__/morn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/moran/__pycache__/morn.cpython-38.pyc -------------------------------------------------------------------------------- /model/moran/asrn_res.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from torch.nn.parameter import Parameter 7 | from .fracPickup import fracPickup 8 | 9 | class BidirectionalLSTM(nn.Module): 10 | 11 | def __init__(self, nIn, nHidden, nOut): 12 | super(BidirectionalLSTM, self).__init__() 13 | 14 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, dropout=0.3) 15 | self.embedding = nn.Linear(nHidden * 2, nOut) 16 | 17 | def forward(self, input): 18 | recurrent, _ = self.rnn(input) 19 | T, b, h = recurrent.size() 20 | t_rec = recurrent.view(T * b, h) 21 | 22 | output = self.embedding(t_rec) # [T * b, nOut] 23 | output = output.view(T, b, -1) 24 | 25 | return output 26 | 27 | class AttentionCell(nn.Module): 28 | def __init__(self, input_size, hidden_size, num_embeddings=128, CUDA=True): 29 | super(AttentionCell, self).__init__() 30 | self.i2h = nn.Linear(input_size, hidden_size,bias=False) 31 | self.h2h = nn.Linear(hidden_size, hidden_size) 32 | self.score = nn.Linear(hidden_size, 1, bias=False) 33 | self.rnn = nn.GRUCell(input_size+num_embeddings, hidden_size) 34 | self.hidden_size = hidden_size 35 | self.input_size = input_size 36 | self.num_embeddings = num_embeddings 37 | self.fracPickup = fracPickup(CUDA=CUDA) 38 | 39 | def forward(self, prev_hidden, feats, cur_embeddings, test=False): 40 | nT = feats.size(0) 41 | nB = feats.size(1) 42 | nC = feats.size(2) 43 | hidden_size = self.hidden_size 44 | 45 | feats_proj = self.i2h(feats.view(-1,nC)) 46 | prev_hidden_proj = self.h2h(prev_hidden).view(1,nB, hidden_size).expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size) 47 | emition = self.score(F.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT,nB) 48 | 49 | alpha = F.softmax(emition, 0) # nT * nB 50 | 51 | if not test: 52 | 
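# Training branch: fracPickup (see fracPickup.py below) randomly blends one pair of neighbouring
# attention positions, a small "fractional pickup" perturbation, so the context vector here is
# computed from a slightly smoothed attention map; the test branch further down uses the raw alpha.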
alpha_fp = self.fracPickup(alpha.transpose(0,1).contiguous().unsqueeze(1).unsqueeze(2)).squeeze() 53 | context = (feats * alpha_fp.transpose(0,1).contiguous().view(nT,nB,1).expand(nT, nB, nC)).sum(0).squeeze(0) # nB * nC 54 | if len(context.size()) == 1: 55 | context = context.unsqueeze(0) 56 | context = torch.cat([context, cur_embeddings], 1) 57 | cur_hidden = self.rnn(context, prev_hidden) 58 | return cur_hidden, alpha_fp 59 | else: 60 | context = (feats * alpha.view(nT,nB,1).expand(nT, nB, nC)).sum(0).squeeze(0) # nB * nC 61 | if len(context.size()) == 1: 62 | context = context.unsqueeze(0) 63 | context = torch.cat([context, cur_embeddings], 1) 64 | cur_hidden = self.rnn(context, prev_hidden) 65 | return cur_hidden, alpha 66 | 67 | class Attention(nn.Module): 68 | def __init__(self, input_size, hidden_size, num_classes, num_embeddings=128, CUDA=True): 69 | super(Attention, self).__init__() 70 | self.attention_cell = AttentionCell(input_size, hidden_size, num_embeddings, CUDA=CUDA) 71 | self.input_size = input_size 72 | self.hidden_size = hidden_size 73 | self.generator = nn.Linear(hidden_size, num_classes) 74 | self.char_embeddings = Parameter(torch.randn(num_classes+1, num_embeddings)) 75 | self.num_embeddings = num_embeddings 76 | self.num_classes = num_classes 77 | self.cuda = CUDA 78 | 79 | # targets is nT * nB 80 | def forward(self, feats, text_length, text, test=False): 81 | 82 | nT = feats.size(0) 83 | nB = feats.size(1) 84 | nC = feats.size(2) 85 | hidden_size = self.hidden_size 86 | input_size = self.input_size 87 | # from IPython import embed 88 | # embed() 89 | assert(input_size == nC) 90 | assert(nB == text_length.numel()) 91 | 92 | num_steps = text_length.data.max() 93 | num_labels = text_length.data.sum() 94 | 95 | if not test: 96 | 97 | targets = torch.zeros(nB, num_steps+1).long() 98 | if self.cuda: 99 | targets = targets.cuda() 100 | start_id = 0 101 | 102 | for i in range(nB): 103 | targets[i][1:1+text_length.data[i]] = text.data[start_id:start_id+text_length.data[i]]+1 104 | start_id = start_id+text_length.data[i] 105 | targets = Variable(targets.transpose(0,1).contiguous()) 106 | 107 | output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data)) 108 | hidden = Variable(torch.zeros(nB,hidden_size).type_as(feats.data)) 109 | 110 | for i in range(num_steps): 111 | cur_embeddings = self.char_embeddings.index_select(0, targets[i]) 112 | hidden, alpha = self.attention_cell(hidden, feats, cur_embeddings, test) 113 | output_hiddens[i] = hidden 114 | 115 | new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data)) 116 | b = 0 117 | start = 0 118 | 119 | for length in text_length.data: 120 | new_hiddens[start:start+length] = output_hiddens[0:length,b,:] 121 | start = start + length 122 | b = b + 1 123 | 124 | probs = self.generator(new_hiddens) 125 | return probs 126 | 127 | else: 128 | 129 | hidden = Variable(torch.zeros(nB,hidden_size).type_as(feats.data)) 130 | targets_temp = Variable(torch.zeros(nB).long().contiguous()) 131 | probs = Variable(torch.zeros(nB*num_steps, self.num_classes)) 132 | if self.cuda: 133 | targets_temp = targets_temp.cuda() 134 | probs = probs.cuda() 135 | 136 | for i in range(num_steps): 137 | cur_embeddings = self.char_embeddings.index_select(0, targets_temp) 138 | hidden, alpha = self.attention_cell(hidden, feats, cur_embeddings, test) 139 | hidden2class = self.generator(hidden) 140 | probs[i*nB:(i+1)*nB] = hidden2class 141 | _, targets_temp = hidden2class.max(1) 142 | targets_temp += 1 143 | 144 | 
probs = probs.view(num_steps, nB, self.num_classes).permute(1, 0, 2).contiguous() 145 | probs = probs.view(-1, self.num_classes).contiguous() 146 | probs_res = Variable(torch.zeros(num_labels, self.num_classes).type_as(feats.data)) 147 | b = 0 148 | start = 0 149 | 150 | for length in text_length.data: 151 | probs_res[start:start+length] = probs[b*num_steps:b*num_steps+length] 152 | start = start + length 153 | b = b + 1 154 | 155 | return probs_res 156 | 157 | class Residual_block(nn.Module): 158 | def __init__(self, c_in, c_out, stride): 159 | super(Residual_block, self).__init__() 160 | self.downsample = None 161 | flag = False 162 | if isinstance(stride, tuple): 163 | if stride[0] > 1: 164 | self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 3, stride, 1),nn.BatchNorm2d(c_out, momentum=0.01)) 165 | flag = True 166 | else: 167 | if stride > 1: 168 | self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 3, stride, 1),nn.BatchNorm2d(c_out, momentum=0.01)) 169 | flag = True 170 | if flag: 171 | self.conv1 = nn.Sequential(nn.Conv2d(c_in, c_out, 3, stride, 1), 172 | nn.BatchNorm2d(c_out, momentum=0.01)) 173 | else: 174 | self.conv1 = nn.Sequential(nn.Conv2d(c_in, c_out, 1, stride, 0), 175 | nn.BatchNorm2d(c_out, momentum=0.01)) 176 | self.conv2 = nn.Sequential(nn.Conv2d(c_out, c_out, 3, 1, 1), 177 | nn.BatchNorm2d(c_out, momentum=0.01)) 178 | self.relu = nn.ReLU() 179 | 180 | def forward(self,x): 181 | residual = x 182 | conv1 = self.conv1(x) 183 | conv2 = self.conv2(conv1) 184 | if self.downsample is not None: 185 | residual = self.downsample(residual) 186 | return self.relu(residual + conv2) 187 | 188 | class ResNet(nn.Module): 189 | def __init__(self,c_in): 190 | super(ResNet,self).__init__() 191 | self.block0 = nn.Sequential(nn.Conv2d(c_in, 32, 3, 1, 1),nn.BatchNorm2d(32, momentum=0.01)) 192 | self.block1 = self._make_layer(32, 32, 2, 3) 193 | self.block2 = self._make_layer(32, 64, 2, 4) 194 | self.block3 = self._make_layer(64, 128, (2,1), 6) 195 | self.block4 = self._make_layer(128, 256, (2,1), 6) 196 | self.block5 = self._make_layer(256, 512, (2,1), 3) 197 | 198 | def _make_layer(self,c_in,c_out,stride,repeat=3): 199 | layers = [] 200 | layers.append(Residual_block(c_in, c_out, stride)) 201 | for i in range(repeat - 1): 202 | layers.append(Residual_block(c_out, c_out, 1)) 203 | return nn.Sequential(*layers) 204 | 205 | def forward(self,x): 206 | block0 = self.block0(x) 207 | block1 = self.block1(block0) 208 | block2 = self.block2(block1) 209 | block3 = self.block3(block2) 210 | block4 = self.block4(block3) 211 | block5 = self.block5(block4) 212 | return block5 213 | 214 | class ASRN(nn.Module): 215 | 216 | def __init__(self, imgH, nc, nclass, nh, BidirDecoder=False, CUDA=True): 217 | super(ASRN, self).__init__() 218 | assert imgH % 16 == 0, 'imgH must be a multiple of 16' 219 | 220 | self.cnn = ResNet(nc) 221 | 222 | self.rnn = nn.Sequential( 223 | BidirectionalLSTM(512, nh, nh), 224 | BidirectionalLSTM(nh, nh, nh), 225 | ) 226 | 227 | self.BidirDecoder = BidirDecoder 228 | if self.BidirDecoder: 229 | self.attentionL2R = Attention(nh, nh, nclass, 256, CUDA=CUDA) 230 | self.attentionR2L = Attention(nh, nh, nclass, 256, CUDA=CUDA) 231 | else: 232 | self.attention = Attention(nh, nh, nclass, 256, CUDA=CUDA) 233 | 234 | for m in self.modules(): 235 | if isinstance(m, nn.Conv2d): 236 | nn.init.kaiming_normal(m.weight, mode='fan_out', a=0) 237 | elif isinstance(m, nn.BatchNorm2d): 238 | nn.init.constant(m.weight, 1) 239 | nn.init.constant(m.bias, 0) 240 | 241 | def forward(self, input, 
length, text, text_rev, test=False): 242 | # conv features 243 | conv = self.cnn(input) 244 | 245 | b, c, h, w = conv.size() 246 | assert h == 1, "the height of conv must be 1" 247 | conv = conv.squeeze(2) 248 | conv = conv.permute(2, 0, 1).contiguous() # [w, b, c] 249 | 250 | # rnn features 251 | rnn = self.rnn(conv) 252 | 253 | if self.BidirDecoder: 254 | outputL2R = self.attentionL2R(rnn, length, text, test) 255 | outputR2L = self.attentionR2L(rnn, length, text_rev, test) 256 | return outputL2R, outputR2L 257 | else: 258 | output = self.attention(rnn, length, text, test) 259 | return output 260 | -------------------------------------------------------------------------------- /model/moran/fracPickup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import numpy.random as npr 6 | 7 | class fracPickup(nn.Module): 8 | 9 | def __init__(self, CUDA=True): 10 | super(fracPickup, self).__init__() 11 | self.cuda = CUDA 12 | 13 | def forward(self, x): 14 | x_shape = x.size() 15 | assert len(x_shape) == 4 16 | assert x_shape[2] == 1 17 | 18 | fracPickup_num = 1 19 | 20 | h_list = 1. 21 | w_list = np.arange(x_shape[3])*2./(x_shape[3]-1)-1 22 | for i in range(fracPickup_num): 23 | idx = int(npr.rand()*len(w_list)) 24 | if idx <= 0 or idx >= x_shape[3]-1: 25 | continue 26 | beta = npr.rand()/4. 27 | value0 = (beta*w_list[idx] + (1-beta)*w_list[idx-1]) 28 | value1 = (beta*w_list[idx-1] + (1-beta)*w_list[idx]) 29 | w_list[idx-1] = value0 30 | w_list[idx] = value1 31 | 32 | grid = np.meshgrid( 33 | w_list, 34 | h_list, 35 | indexing='ij' 36 | ) 37 | grid = np.stack(grid, axis=-1) 38 | grid = np.transpose(grid, (1, 0, 2)) 39 | grid = np.expand_dims(grid, 0) 40 | grid = np.tile(grid, [x_shape[0], 1, 1, 1]) 41 | grid = torch.from_numpy(grid).type(x.data.type()) 42 | if self.cuda: 43 | grid = grid.cuda() 44 | self.grid = Variable(grid, requires_grad=False) 45 | 46 | x_offset = nn.functional.grid_sample(x, self.grid) 47 | 48 | return x_offset 49 | -------------------------------------------------------------------------------- /model/moran/moran.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .morn import MORN 3 | from .asrn_res import ASRN 4 | 5 | 6 | class MORAN(nn.Module): 7 | 8 | def __init__(self, nc, nclass, nh, targetH, targetW, BidirDecoder=False, 9 | inputDataType='torch.cuda.FloatTensor', maxBatch=256, CUDA=True): 10 | super(MORAN, self).__init__() 11 | self.MORN = MORN(nc, targetH, targetW, inputDataType, maxBatch, CUDA) 12 | self.ASRN = ASRN(targetH, nc, nclass, nh, BidirDecoder, CUDA) 13 | 14 | def forward(self, x, length, text, text_rev, test=False, debug=False): 15 | if debug: 16 | x_rectified, demo = self.MORN(x, test, debug=debug) 17 | preds = self.ASRN(x_rectified, length, text, text_rev, test) 18 | return preds, demo 19 | else: 20 | x_rectified = self.MORN(x, test, debug=debug) 21 | preds = self.ASRN(x_rectified, length, text, text_rev, test) 22 | return preds 23 | -------------------------------------------------------------------------------- /model/moran/morn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | class MORN(nn.Module): 7 | def __init__(self, nc, targetH, targetW, inputDataType='torch.cuda.FloatTensor', maxBatch=256, 
CUDA=True): 8 | super(MORN, self).__init__() 9 | self.targetH = targetH 10 | self.targetW = targetW 11 | self.inputDataType = inputDataType 12 | self.maxBatch = maxBatch 13 | self.cuda = CUDA 14 | 15 | self.cnn = nn.Sequential( 16 | nn.MaxPool2d(2, 2), 17 | nn.Conv2d(nc, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True), nn.MaxPool2d(2, 2), 18 | nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True), nn.MaxPool2d(2, 2), 19 | nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(True), 20 | nn.Conv2d(64, 16, 3, 1, 1), nn.BatchNorm2d(16), nn.ReLU(True), 21 | nn.Conv2d(16, 1, 3, 1, 1), nn.BatchNorm2d(1) 22 | ) 23 | 24 | self.pool = nn.MaxPool2d(2, 1) 25 | 26 | h_list = np.arange(self.targetH)*2./(self.targetH-1)-1 27 | w_list = np.arange(self.targetW)*2./(self.targetW-1)-1 28 | 29 | grid = np.meshgrid( 30 | w_list, 31 | h_list, 32 | indexing='ij' 33 | ) 34 | grid = np.stack(grid, axis=-1) 35 | grid = np.transpose(grid, (1, 0, 2)) 36 | grid = np.expand_dims(grid, 0) 37 | grid = np.tile(grid, [maxBatch, 1, 1, 1]) 38 | grid = torch.from_numpy(grid).type(self.inputDataType) 39 | if self.cuda: 40 | grid = grid.cuda() 41 | 42 | self.grid = Variable(grid, requires_grad=False) 43 | self.grid_x = self.grid[:, :, :, 0].unsqueeze(3) 44 | self.grid_y = self.grid[:, :, :, 1].unsqueeze(3) 45 | 46 | def forward(self, x, test, enhance=1, debug=False): 47 | 48 | if not test and np.random.random() > 0.5: 49 | return nn.functional.upsample(x, size=(self.targetH, self.targetW), mode='bilinear') 50 | if not test: 51 | enhance = 0 52 | 53 | assert x.size(0) <= self.maxBatch 54 | assert x.data.type() == self.inputDataType 55 | 56 | grid = self.grid[:x.size(0)].to(x.device) 57 | grid_x = self.grid_x[:x.size(0)].to(x.device) 58 | grid_y = self.grid_y[:x.size(0)].to(x.device) 59 | x_small = nn.functional.upsample(x, size=(self.targetH, self.targetW), mode='bilinear') 60 | 61 | offsets = self.cnn(x_small) 62 | offsets_posi = nn.functional.relu(offsets, inplace=False) 63 | offsets_nega = nn.functional.relu(-offsets, inplace=False) 64 | offsets_pool = self.pool(offsets_posi) - self.pool(offsets_nega) 65 | 66 | offsets_grid = nn.functional.grid_sample(offsets_pool, grid) 67 | offsets_grid = offsets_grid.permute(0, 2, 3, 1).contiguous() 68 | offsets_x = torch.cat([grid_x, grid_y + offsets_grid], 3) 69 | x_rectified = nn.functional.grid_sample(x, offsets_x) 70 | 71 | # print("x device:", x.device, offsets_grid.device) 72 | 73 | for iteration in range(enhance): 74 | offsets = self.cnn(x_rectified) 75 | 76 | offsets_posi = nn.functional.relu(offsets, inplace=False) 77 | offsets_nega = nn.functional.relu(-offsets, inplace=False) 78 | offsets_pool = self.pool(offsets_posi) - self.pool(offsets_nega) 79 | 80 | offsets_grid += nn.functional.grid_sample(offsets_pool, grid).permute(0, 2, 3, 1).contiguous() 81 | offsets_x = torch.cat([grid_x, grid_y + offsets_grid], 3) 82 | x_rectified = nn.functional.grid_sample(x, offsets_x) 83 | 84 | if debug: 85 | 86 | offsets_mean = torch.mean(offsets_grid.view(x.size(0), -1), 1) 87 | offsets_max, _ = torch.max(offsets_grid.view(x.size(0), -1), 1) 88 | offsets_min, _ = torch.min(offsets_grid.view(x.size(0), -1), 1) 89 | 90 | import matplotlib.pyplot as plt 91 | from colour import Color 92 | from torchvision import transforms 93 | import cv2 94 | 95 | alpha = 0.7 96 | density_range = 256 97 | color_map = np.empty([self.targetH, self.targetW, 3], dtype=int) 98 | cmap = plt.get_cmap("rainbow") 99 | blue = Color("blue") 100 | hex_colors = list(blue.range_to(Color("red"), density_range)) 
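# Debug-only visualization: the code below normalizes the predicted offset field to [0, 1], maps it
# through the colour ramp built above, alpha-blends the resulting heat map onto the down-sampled
# input, and tiles input / offset heat map / rectified output side by side for inspection.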
101 | rgb_colors = [[rgb * 255 for rgb in color.rgb] for color in hex_colors][::-1] 102 | to_pil_image = transforms.ToPILImage() 103 | 104 | for i in range(x.size(0)): 105 | 106 | img_small = x_small[i].data.cpu().mul_(0.5).add_(0.5) 107 | img = to_pil_image(img_small) 108 | img = np.array(img) 109 | if len(img.shape) == 2: 110 | img = cv2.merge([img.copy()]*3) 111 | img_copy = img.copy() 112 | 113 | v_max = offsets_max.data[i] 114 | v_min = offsets_min.data[i] 115 | img_offsets = (offsets_grid[i]).view(1, self.targetH, self.targetW).add_(-v_min).mul_(1./(v_max-v_min)).data.cpu() 116 | img_offsets = to_pil_image(img_offsets) 117 | img_offsets = np.array(img_offsets) 118 | color_map = np.empty([self.targetH, self.targetW, 3], dtype=int) 119 | for h_i in range(self.targetH): 120 | for w_i in range(self.targetW): 121 | color_map[h_i][w_i] = rgb_colors[int(img_offsets[h_i, w_i]/256.*density_range)] 122 | color_map = color_map.astype(np.uint8) 123 | cv2.addWeighted(color_map, alpha, img_copy, 1-alpha, 0, img_copy) 124 | 125 | img_processed = x_rectified[i].data.cpu().mul_(0.5).add_(0.5) 126 | img_processed = to_pil_image(img_processed) 127 | img_processed = np.array(img_processed) 128 | if len(img_processed.shape) == 2: 129 | img_processed = cv2.merge([img_processed.copy()]*3) 130 | 131 | total_img = np.ones([self.targetH, self.targetW*3+10, 3], dtype=int)*255 132 | total_img[0:self.targetH, 0:self.targetW] = img 133 | total_img[0:self.targetH, self.targetW+5:2*self.targetW+5] = img_copy 134 | total_img[0:self.targetH, self.targetW*2+10:3*self.targetW+10] = img_processed 135 | total_img = cv2.resize(total_img.astype(np.uint8), (300, 50)) 136 | # cv2.imshow("Input_Offsets_Output", total_img) 137 | # cv2.waitKey() 138 | 139 | return x_rectified, total_img 140 | 141 | return x_rectified 142 | -------------------------------------------------------------------------------- /model/parseq/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/parseq/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /model/parseq/__pycache__/parseq.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/parseq/__pycache__/parseq.cpython-38.pyc -------------------------------------------------------------------------------- /model/parseq/__pycache__/parseq_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/parseq/__pycache__/parseq_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /model/parseq/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import nn as nn, Tensor 6 | from torch.nn import functional as F 7 | from torch.nn.modules import transformer 8 | 9 | from timm.models.vision_transformer import VisionTransformer, PatchEmbed 10 | 11 | 12 | class DecoderLayer(nn.Module): 13 | """A Transformer decoder layer supporting two-stream attention (XLNet) 14 | This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" 15 | 
16 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', 17 | layer_norm_eps=1e-5): 18 | super().__init__() 19 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 20 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 21 | # Implementation of Feedforward model 22 | self.linear1 = nn.Linear(d_model, dim_feedforward) 23 | self.dropout = nn.Dropout(dropout) 24 | self.linear2 = nn.Linear(dim_feedforward, d_model) 25 | 26 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 27 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 28 | self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) 29 | self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) 30 | self.dropout1 = nn.Dropout(dropout) 31 | self.dropout2 = nn.Dropout(dropout) 32 | self.dropout3 = nn.Dropout(dropout) 33 | 34 | self.activation = transformer._get_activation_fn(activation) 35 | 36 | def __setstate__(self, state): 37 | if 'activation' not in state:#fasle 38 | state['activation'] = F.gelu 39 | super().__setstate__(state) 40 | 41 | def forward_stream(self, tgt: Tensor, tgt_norm: Tensor, tgt_kv: Tensor, memory: Tensor, tgt_mask: Optional[Tensor], 42 | tgt_key_padding_mask: Optional[Tensor]): 43 | """Forward pass for a single stream (i.e. content or query) 44 | tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. 45 | Both tgt_kv and memory are expected to be LayerNorm'd too. 46 | memory is LayerNorm'd by ViT. 47 | """ 48 | tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, 49 | key_padding_mask=tgt_key_padding_mask) 50 | tgt = tgt + self.dropout1(tgt2) 51 | 52 | tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) 53 | tgt = tgt + self.dropout2(tgt2) 54 | 55 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt))))) 56 | tgt = tgt + self.dropout3(tgt2) 57 | return tgt, sa_weights, ca_weights 58 | 59 | def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, 60 | content_key_padding_mask: Optional[Tensor] = None, update_content: bool = True): 61 | query_norm = self.norm_q(query)#self-attn: q 62 | content_norm = self.norm_c(content)#self_attn: k v 63 | query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0] 64 | if update_content:#false 65 | content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, 66 | content_key_padding_mask)[0] 67 | return query, content 68 | 69 | 70 | class Decoder(nn.Module): 71 | __constants__ = ['norm'] 72 | 73 | def __init__(self, decoder_layer, num_layers, norm): 74 | super().__init__() 75 | self.layers = transformer._get_clones(decoder_layer, num_layers) 76 | self.num_layers = num_layers 77 | self.norm = norm 78 | 79 | def forward(self, query, content, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None, 80 | content_key_padding_mask: Optional[Tensor] = None): 81 | for i, mod in enumerate(self.layers): 82 | last = i == len(self.layers) - 1 83 | query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, 84 | update_content=not last) 85 | query = self.norm(query) 86 | return query 87 | 88 | 89 | class Encoder(VisionTransformer): 90 | 91 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 92 | qkv_bias=True, 
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed): 93 | super().__init__(img_size, patch_size, in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads, 94 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 95 | drop_path_rate=drop_path_rate, embed_layer=embed_layer, 96 | num_classes=0, global_pool='', class_token=False) # these disable the classifier head 97 | 98 | def forward(self, x): 99 | # Return all tokens 100 | return self.forward_features(x) # 直接使用timm的vision transformer 101 | 102 | 103 | class TokenEmbedding(nn.Module): 104 | 105 | def __init__(self, charset_size: int, embed_dim: int): 106 | super().__init__() 107 | self.embedding = nn.Embedding(charset_size, embed_dim) 108 | self.embed_dim = embed_dim 109 | 110 | def forward(self, tokens: torch.Tensor): 111 | return math.sqrt(self.embed_dim) * self.embedding(tokens) 112 | -------------------------------------------------------------------------------- /model/parseq/parseq_tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | from abc import ABC, abstractmethod 3 | from itertools import groupby 4 | from typing import List, Optional, Tuple 5 | import torch 6 | from torch import Tensor 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | 10 | 11 | class BaseTokenizer(ABC): 12 | 13 | def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: 14 | self._itos = specials_first + tuple(charset) + specials_last 15 | self._stoi = {s: i for i, s in enumerate(self._itos)} 16 | 17 | def __len__(self): 18 | return len(self._itos) 19 | 20 | def _tok2ids(self, tokens: str) -> List[int]: 21 | return [self._stoi[s] for s in tokens] 22 | 23 | def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: 24 | tokens = [self._itos[i] for i in token_ids] 25 | return ''.join(tokens) if join else tokens 26 | 27 | @abstractmethod 28 | def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: 29 | """Encode a batch of labels to a representation suitable for the model. 30 | 31 | Args: 32 | labels: List of labels. Each can be of arbitrary length. 33 | device: Create tensor on this device. 34 | 35 | Returns: 36 | Batched tensor representation padded to the max label length. Shape: N, L 37 | """ 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: 42 | """Internal method which performs the necessary filtering prior to decoding.""" 43 | raise NotImplementedError 44 | 45 | def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: 46 | """Decode a batch of token distributions. 47 | 48 | Args: 49 | token_dists: softmax probabilities over the token distribution. 
Shape: N, L, C 50 | raw: return unprocessed labels (will return list of list of strings) 51 | 52 | Returns: 53 | list of string labels (arbitrary length) and 54 | their corresponding sequence probabilities as a list of Tensors 55 | """ 56 | batch_tokens = [] 57 | batch_probs = [] 58 | for dist in token_dists: 59 | probs, ids = dist.max(-1) # greedy selection 60 | if not raw: 61 | probs, ids = self._filter(probs, ids) 62 | tokens = self._ids2tok(ids, not raw) 63 | batch_tokens.append(tokens) 64 | batch_probs.append(probs) 65 | return batch_tokens, batch_probs 66 | 67 | 68 | class Tokenizer(BaseTokenizer): 69 | BOS = '[B]' 70 | EOS = '[E]' 71 | PAD = '[P]' 72 | 73 | def __init__(self, charset: str) -> None: 74 | specials_first = (self.EOS,) 75 | specials_last = (self.BOS, self.PAD) 76 | super().__init__(charset, specials_first, specials_last) 77 | self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] 78 | 79 | def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: 80 | batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device) 81 | for y in labels] 82 | return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) 83 | 84 | def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: 85 | ids = ids.tolist() 86 | try: 87 | eos_idx = ids.index(self.eos_id) 88 | except ValueError: 89 | eos_idx = len(ids) # Nothing to truncate. 90 | # Truncate after EOS 91 | ids = ids[:eos_idx] 92 | probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists) 93 | return probs, ids 94 | 95 | 96 | def get_parseq_tokenize(): 97 | charset_train = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" 98 | tokenizer = Tokenizer(charset_train) 99 | return tokenizer 100 | 101 | -------------------------------------------------------------------------------- /model/recognizer/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .recognizer_builder import RecognizerBuilder 3 | from .resnet_aster import * 4 | 5 | __factory = { 6 | 'ResNet_ASTER': ResNet_ASTER, 7 | } 8 | 9 | def names(): 10 | return sorted(__factory.keys()) 11 | 12 | 13 | def create(name, *args, **kwargs): 14 | """Create a model instance. 15 | 16 | Parameters 17 | ---------- 18 | name: str 19 | Model name. One of __factory 20 | pretrained: bool, optional 21 | If True, will use ImageNet pretrained model. Default: True 22 | num_classes: int, optional 23 | If positive, will change the original classifier the fit the new classifier with num_classes. Default: True 24 | with_words: bool, optional 25 | If True, the input of this model is the combination of image and word. 
Default: False 26 | """ 27 | if name not in __factory: 28 | raise KeyError('Unknown model:', name) 29 | return __factory[name](*args, **kwargs) -------------------------------------------------------------------------------- /model/recognizer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/attention_recognition_head.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/attention_recognition_head.cpython-36.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/attention_recognition_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/attention_recognition_head.cpython-37.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/attention_recognition_head.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/attention_recognition_head.cpython-38.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/recognizer_builder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/recognizer_builder.cpython-36.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/recognizer_builder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/recognizer_builder.cpython-37.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/recognizer_builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/recognizer_builder.cpython-38.pyc 
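For reference, a minimal round-trip sketch for the PARSeq tokenizer defined in `model/parseq/parseq_tokenizer.py` above (this snippet is illustrative and not part of the repository): the one-hot "distributions" only stand in for real model output, and the import assumes the repository root is on `PYTHONPATH`.

import torch
import torch.nn.functional as F
from model.parseq.parseq_tokenizer import get_parseq_tokenize

tokenizer = get_parseq_tokenize()
# encode() prepends [B], appends [E] and pads to the longest label in the batch.
targets = tokenizer.encode(['hello', 'world'])            # LongTensor of shape [2, 7]
# Fake per-step distributions: one-hot vectors over the charset. The [B] column is dropped because
# decode() expects model outputs that start at the first character, not at the BOS token.
fake_dists = F.one_hot(targets[:, 1:], num_classes=len(tokenizer)).float()
labels, probs = tokenizer.decode(fake_dists)
print(labels)  # ['hello', 'world'] -- decode() takes the argmax per step and truncates at [E]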
-------------------------------------------------------------------------------- /model/recognizer/__pycache__/resnet_aster.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/resnet_aster.cpython-36.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/resnet_aster.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/resnet_aster.cpython-37.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/resnet_aster.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/resnet_aster.cpython-38.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/sequenceCrossEntropyLoss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/sequenceCrossEntropyLoss.cpython-36.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/sequenceCrossEntropyLoss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/sequenceCrossEntropyLoss.cpython-37.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/sequenceCrossEntropyLoss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/sequenceCrossEntropyLoss.cpython-38.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/stn_head.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/stn_head.cpython-36.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/stn_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/stn_head.cpython-37.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/stn_head.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/stn_head.cpython-38.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/tps_spatial_transformer.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/tps_spatial_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/tps_spatial_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/tps_spatial_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /model/recognizer/__pycache__/tps_spatial_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/LEMMA/ad25df9c71229ad6b5ac8d05e0bbce7e50940701/model/recognizer/__pycache__/tps_spatial_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /model/recognizer/attention_recognition_head.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.nn import init 9 | from IPython import embed 10 | 11 | class AttentionRecognitionHead(nn.Module): 12 | """ 13 | input: [b x 16 x 64 x in_planes] 14 | output: probability sequence: [b x T x num_classes] 15 | """ 16 | def __init__(self, num_classes, in_planes, sDim, attDim, max_len_labels): 17 | super(AttentionRecognitionHead, self).__init__() 18 | self.num_classes = num_classes # this is the output classes. So it includes the . 19 | self.in_planes = in_planes 20 | self.sDim = sDim 21 | self.attDim = attDim 22 | self.max_len_labels = max_len_labels 23 | 24 | self.decoder = DecoderUnit(sDim=sDim, xDim=in_planes, yDim=num_classes, attDim=attDim) 25 | 26 | def forward(self, x): 27 | x, targets, lengths = x 28 | batch_size = x.size(0) 29 | # Decoder 30 | state = torch.zeros(1, batch_size, self.sDim).cuda() 31 | outputs = [] 32 | 33 | for i in range(max(lengths)): 34 | if i == 0: 35 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes).cuda() # the last one is used as the . 36 | else: 37 | y_prev = targets[:,i-1].cuda() 38 | 39 | output, state = self.decoder(x, state, y_prev) 40 | outputs.append(output) 41 | outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1) 42 | return outputs 43 | 44 | # inference stage. 
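# sample() below performs greedy decoding: at each of max_len_labels steps it feeds the previous
# prediction back into the GRU decoder and keeps the argmax character. beam_search() instead keeps
# beam_width hypotheses per image (scores are accumulated log-softmax values) and backtracks through
# the stored predecessors to return the best-scoring sequence.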
45 | def sample(self, x): 46 | x, _, _ = x 47 | batch_size = x.size(0) 48 | # Decoder 49 | state = torch.zeros(1, batch_size, self.sDim) 50 | 51 | predicted_ids, predicted_scores = [], [] 52 | for i in range(self.max_len_labels): 53 | if i == 0: 54 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes) 55 | else: 56 | y_prev = predicted 57 | 58 | output, state = self.decoder(x, state, y_prev) 59 | output = F.softmax(output, dim=1) 60 | score, predicted = output.max(1) 61 | predicted_ids.append(predicted.unsqueeze(1)) 62 | predicted_scores.append(score.unsqueeze(1)) 63 | predicted_ids = torch.cat(predicted_ids, 1) 64 | predicted_scores = torch.cat(predicted_scores, 1) 65 | # return predicted_ids.squeeze(), predicted_scores.squeeze() 66 | return predicted_ids, predicted_scores 67 | 68 | def beam_search(self, x, beam_width, eos): 69 | 70 | def _inflate(tensor, times, dim): 71 | repeat_dims = [1] * tensor.dim() 72 | repeat_dims[dim] = times 73 | return tensor.repeat(*repeat_dims) 74 | 75 | # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py 76 | batch_size, l, d = x.size() 77 | # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC 78 | inflated_encoder_feats = x.unsqueeze(1).permute((1,0,2,3)).repeat((beam_width,1,1,1)).permute((1,0,2,3)).contiguous().view(-1, l, d) 79 | 80 | # Initialize the decoder 81 | state = torch.zeros(1, batch_size * beam_width, self.sDim).cuda() 82 | pos_index = (torch.Tensor(range(batch_size)) * beam_width).long().view(-1, 1).cuda() 83 | 84 | # Initialize the scores 85 | sequence_scores = torch.Tensor(batch_size * beam_width, 1).cuda() 86 | sequence_scores.fill_(-float('Inf')) 87 | sequence_scores.index_fill_(0, torch.Tensor([i * beam_width for i in range(0, batch_size)]).long().cuda(), 0.0) 88 | # sequence_scores.fill_(0.0) 89 | 90 | # Initialize the input vector 91 | y_prev = torch.zeros((batch_size * beam_width)).fill_(self.num_classes).cuda() 92 | 93 | # Store decisions for backtracking 94 | stored_scores = list() 95 | stored_predecessors = list() 96 | stored_emitted_symbols = list() 97 | 98 | for i in range(self.max_len_labels): 99 | output, state = self.decoder(inflated_encoder_feats, state, y_prev) 100 | log_softmax_output = F.log_softmax(output, dim=1) 101 | 102 | sequence_scores = _inflate(sequence_scores, self.num_classes, 1) 103 | sequence_scores += log_softmax_output 104 | scores, candidates = sequence_scores.view(batch_size, -1).topk(beam_width, dim=1) 105 | 106 | # Reshape input = (bk, 1) and sequence_scores = (bk, 1) 107 | y_prev = (candidates % self.num_classes).view(batch_size * beam_width) 108 | sequence_scores = scores.view(batch_size * beam_width, 1) 109 | 110 | # Update fields for next timestep 111 | predecessors = (candidates / self.num_classes + pos_index.expand_as(candidates)).view(batch_size * beam_width, 1).long() 112 | 113 | # print("state:", self.num_classes, ) 114 | 115 | state = state.index_select(1, predecessors.squeeze()) 116 | 117 | # Update sequence socres and erase scores for symbol so that they aren't expanded 118 | stored_scores.append(sequence_scores.clone()) 119 | eos_indices = y_prev.view(-1, 1).eq(eos) 120 | if eos_indices.nonzero().dim() > 0: 121 | sequence_scores.masked_fill_(eos_indices, -float('inf')) 122 | 123 | # Cache results for backtracking 124 | stored_predecessors.append(predecessors) 125 | stored_emitted_symbols.append(y_prev) 126 | 127 | # Do backtracking to return the optimal values 128 | #====== 
backtrak ======# 129 | # Initialize return variables given different types 130 | p = list() 131 | l = [[self.max_len_labels] * beam_width for _ in range(batch_size)] # Placeholder for lengths of top-k sequences 132 | 133 | # the last step output of the beams are not sorted 134 | # thus they are sorted here 135 | sorted_score, sorted_idx = stored_scores[-1].view(batch_size, beam_width).topk(beam_width) 136 | # initialize the sequence scores with the sorted last step beam scores 137 | s = sorted_score.clone() 138 | 139 | batch_eos_found = [0] * batch_size # the number of EOS found 140 | # in the backward loop below for each batch 141 | t = self.max_len_labels - 1 142 | # initialize the back pointer with the sorted order of the last step beams. 143 | # add pos_index for indexing variable with b*k as the first dimension. 144 | t_predecessors = (sorted_idx + pos_index.expand_as(sorted_idx)).view(batch_size * beam_width) 145 | while t >= 0: 146 | # Re-order the variables with the back pointer 147 | current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors) 148 | t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).squeeze() 149 | eos_indices = stored_emitted_symbols[t].eq(eos).nonzero() 150 | if eos_indices.dim() > 0: 151 | for i in range(eos_indices.size(0)-1, -1, -1): 152 | # Indices of the EOS symbol for both variables 153 | # with b*k as the first dimension, and b, k for 154 | # the first two dimensions 155 | idx = eos_indices[i] 156 | b_idx = int(idx[0] / beam_width) 157 | # The indices of the replacing position 158 | # according to the replacement strategy noted above 159 | res_k_idx = beam_width - (batch_eos_found[b_idx] % beam_width) - 1 160 | batch_eos_found[b_idx] += 1 161 | res_idx = b_idx * beam_width + res_k_idx 162 | 163 | # Replace the old information in return variables 164 | # with the new ended sequence information 165 | t_predecessors[res_idx] = stored_predecessors[t][idx[0]] 166 | current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]] 167 | s[b_idx, res_k_idx] = stored_scores[t][idx[0], [0]] 168 | l[b_idx][res_k_idx] = t + 1 169 | 170 | # record the back tracked results 171 | p.append(current_symbol) 172 | 173 | t -= 1 174 | 175 | # Sort and re-order again as the added ended sequences may change 176 | # the order (very unlikely) 177 | s, re_sorted_idx = s.topk(beam_width) 178 | for b_idx in range(batch_size): 179 | l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx,:]] 180 | 181 | re_sorted_idx = (re_sorted_idx + pos_index.expand_as(re_sorted_idx)).view(batch_size*beam_width) 182 | 183 | # Reverse the sequences and re-order at the same time 184 | # It is reversed because the backtracking happens in reverse time order 185 | p = [step.index_select(0, re_sorted_idx).view(batch_size, beam_width, -1) for step in reversed(p)] 186 | p = torch.cat(p, -1)[:,0,:] 187 | return p, torch.ones_like(p) 188 | 189 | 190 | class AttentionUnit(nn.Module): 191 | def __init__(self, sDim, xDim, attDim): 192 | super(AttentionUnit, self).__init__() 193 | 194 | self.sDim = sDim 195 | self.xDim = xDim 196 | self.attDim = attDim 197 | 198 | self.sEmbed = nn.Linear(sDim, attDim) 199 | self.xEmbed = nn.Linear(xDim, attDim) 200 | self.wEmbed = nn.Linear(attDim, 1) 201 | 202 | # self.init_weights() 203 | 204 | def init_weights(self): 205 | init.normal_(self.sEmbed.weight, std=0.01) 206 | init.constant_(self.sEmbed.bias, 0) 207 | init.normal_(self.xEmbed.weight, std=0.01) 208 | init.constant_(self.xEmbed.bias, 0) 209 | init.normal_(self.wEmbed.weight, 
std=0.01) 210 | init.constant_(self.wEmbed.bias, 0) 211 | 212 | def forward(self, x, sPrev): 213 | sPrev = sPrev.cuda() 214 | batch_size, T, _ = x.size() # [b x T x xDim] 215 | x = x.contiguous().view(-1, self.xDim) # [(b x T) x xDim] 216 | xProj = self.xEmbed(x) # [(b x T) x attDim] 217 | xProj = xProj.view(batch_size, T, -1) # [b x T x attDim] 218 | 219 | sPrev = sPrev.squeeze(0) 220 | from IPython import embed; 221 | # embed() 222 | sProj = self.sEmbed(sPrev) # [b x attDim] 223 | sProj = torch.unsqueeze(sProj, 1) # [b x 1 x attDim] 224 | sProj = sProj.expand(batch_size, T, self.attDim) # [b x T x attDim] 225 | 226 | sumTanh = torch.tanh(sProj + xProj) 227 | sumTanh = sumTanh.view(-1, self.attDim) 228 | 229 | vProj = self.wEmbed(sumTanh) # [(b x T) x 1] 230 | vProj = vProj.view(batch_size, T) 231 | 232 | alpha = F.softmax(vProj, dim=1) # attention weights for each sample in the minibatch 233 | 234 | return alpha 235 | 236 | 237 | class DecoderUnit(nn.Module): 238 | def __init__(self, sDim, xDim, yDim, attDim): 239 | super(DecoderUnit, self).__init__() 240 | self.sDim = sDim 241 | self.xDim = xDim 242 | self.yDim = yDim 243 | self.attDim = attDim 244 | self.emdDim = attDim 245 | 246 | self.attention_unit = AttentionUnit(sDim, xDim, attDim) 247 | self.tgt_embedding = nn.Embedding(yDim+1, self.emdDim) # the last is used for 248 | self.gru = nn.GRU(input_size=xDim+self.emdDim, hidden_size=sDim, batch_first=True) 249 | self.fc = nn.Linear(sDim, yDim) 250 | 251 | # self.init_weights() 252 | 253 | def init_weights(self): 254 | init.normal_(self.tgt_embedding.weight, std=0.01) 255 | init.normal_(self.fc.weight, std=0.01) 256 | init.constant_(self.fc.bias, 0) 257 | 258 | def forward(self, x, sPrev, yPrev): 259 | sPrev = sPrev.cuda() 260 | # x: feature sequence from the image decoder. 261 | batch_size, T, _ = x.size() 262 | alpha = self.attention_unit(x, sPrev) 263 | context = torch.bmm(alpha.unsqueeze(1), x).squeeze(1) 264 | yPrev = yPrev.cuda() 265 | yProj = self.tgt_embedding(yPrev.long()) 266 | self.gru.flatten_parameters() 267 | output, state = self.gru(torch.cat([yProj, context], 1).unsqueeze(1), sPrev) 268 | output = output.squeeze(1) 269 | 270 | output = self.fc(output) 271 | return output, state -------------------------------------------------------------------------------- /model/recognizer/recognizer_builder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from PIL import Image 4 | import numpy as np 5 | from collections import OrderedDict 6 | import sys 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torch.nn import init 12 | sys.path.append('./') 13 | from .resnet_aster import * 14 | from .attention_recognition_head import AttentionRecognitionHead 15 | sys.path.append('../') 16 | from .sequenceCrossEntropyLoss import SequenceCrossEntropyLoss 17 | from .tps_spatial_transformer import TPSSpatialTransformer 18 | from .stn_head import STNHead 19 | 20 | tps_inputsize = [32, 64] 21 | tps_outputsize = [32, 100] 22 | num_control_points = 20 23 | tps_margins = [0.05, 0.05] 24 | beam_width = 5 25 | 26 | 27 | class RecognizerBuilder(nn.Module): 28 | """ 29 | This is the integrated model. 
30 | """ 31 | def __init__(self, arch, rec_num_classes, sDim = 512, attDim = 512, max_len_labels = 100, eos = 'EOS', STN_ON = True): 32 | super(RecognizerBuilder, self).__init__() 33 | 34 | self.arch = arch 35 | self.rec_num_classes = rec_num_classes 36 | self.sDim = sDim 37 | self.attDim = attDim 38 | self.max_len_labels = max_len_labels 39 | self.eos = eos 40 | self.STN_ON = STN_ON 41 | 42 | self.tps_inputsize = tps_inputsize 43 | 44 | self.encoder = ResNet_ASTER(self.arch) 45 | encoder_out_planes = self.encoder.out_planes 46 | 47 | self.decoder = AttentionRecognitionHead( 48 | num_classes=rec_num_classes, 49 | in_planes=encoder_out_planes, 50 | sDim=sDim, 51 | attDim=attDim, 52 | max_len_labels=max_len_labels) 53 | self.rec_crit = SequenceCrossEntropyLoss() 54 | 55 | if self.STN_ON: 56 | self.tps = TPSSpatialTransformer( 57 | output_image_size=tuple(tps_outputsize), 58 | num_control_points=num_control_points, 59 | margins=tuple(tps_margins)) 60 | self.stn_head = STNHead( 61 | in_planes=3, 62 | num_ctrlpoints=num_control_points, 63 | activation='none') 64 | 65 | def forward(self, input_dict): 66 | return_dict = {} 67 | return_dict['losses'] = {} 68 | return_dict['output'] = {} 69 | 70 | x, rec_targets, rec_lengths = input_dict['images'], \ 71 | input_dict['rec_targets'], \ 72 | input_dict['rec_lengths'] 73 | 74 | # rectification 75 | if self.STN_ON: 76 | # input images are downsampled before being fed into stn_head. 77 | stn_input = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True) 78 | stn_img_feat, ctrl_points = self.stn_head(stn_input) 79 | x, _ = self.tps(x, ctrl_points) 80 | # if not self.training: 81 | # # save for visualization 82 | # return_dict['output']['ctrl_points'] = ctrl_points 83 | # return_dict['output']['rectified_images'] = x 84 | 85 | encoder_feats = self.encoder(x) 86 | encoder_feats = encoder_feats.contiguous() 87 | 88 | if self.training: 89 | rec_pred = self.decoder([encoder_feats, rec_targets, rec_lengths]) 90 | loss_rec = self.rec_crit(rec_pred, rec_targets, rec_lengths) 91 | return_dict['losses']['loss_rec'] = loss_rec 92 | else: 93 | rec_pred, rec_pred_scores = self.decoder.beam_search(encoder_feats, beam_width, self.eos) 94 | rec_pred_ = self.decoder([encoder_feats, rec_targets, rec_lengths]) 95 | loss_rec = self.rec_crit(rec_pred_, rec_targets, rec_lengths) 96 | return_dict['losses']['loss_rec'] = loss_rec 97 | return_dict['output']['pred_rec'] = rec_pred 98 | return_dict['output']['pred_rec_score'] = rec_pred_scores 99 | 100 | # pytorch0.4 bug on gathering scalar(0-dim) tensors 101 | for k, v in return_dict['losses'].items(): 102 | return_dict['losses'][k] = v.unsqueeze(0) 103 | 104 | return return_dict 105 | 106 | 107 | if __name__ == '__main__': 108 | from IPython import embed 109 | embed() 110 | -------------------------------------------------------------------------------- /model/recognizer/resnet_aster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | import sys 6 | import math 7 | # 8 | # from config import get_args 9 | # global_args = get_args(sys.argv[1:]) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | def conv1x1(in_planes, out_planes, stride=1): 19 | """1x1 convolution""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 
bias=False) 21 | 22 | 23 | def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000): 24 | # [n_position] 25 | positions = torch.arange(0, n_position)#.cuda() 26 | # [feat_dim] 27 | dim_range = torch.arange(0, feat_dim)#.cuda() 28 | dim_range = torch.pow(wave_length, 2 * (dim_range // 2) / feat_dim) 29 | # [n_position, feat_dim] 30 | angles = positions.unsqueeze(1) / dim_range.unsqueeze(0) 31 | angles = angles.float() 32 | angles[:, 0::2] = torch.sin(angles[:, 0::2]) 33 | angles[:, 1::2] = torch.cos(angles[:, 1::2]) 34 | return angles 35 | 36 | 37 | class AsterBlock(nn.Module): 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None): 40 | super(AsterBlock, self).__init__() 41 | self.conv1 = conv1x1(inplanes, planes, stride) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | out += residual 60 | out = self.relu(out) 61 | return out 62 | 63 | 64 | class ResNet_ASTER(nn.Module): 65 | """For aster or crnn""" 66 | 67 | def __init__(self, with_lstm=False, n_group=1): 68 | super(ResNet_ASTER, self).__init__() 69 | self.with_lstm = with_lstm 70 | self.n_group = n_group 71 | 72 | in_channels = 3 73 | self.layer0 = nn.Sequential( 74 | nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False), 75 | nn.BatchNorm2d(32), 76 | nn.ReLU(inplace=True)) 77 | 78 | self.inplanes = 32 79 | self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50] 80 | self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25] 81 | self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25] 82 | self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25] 83 | self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25] 84 | 85 | if with_lstm: 86 | self.rnn = nn.LSTM(512, 256, bidirectional=True, num_layers=2, batch_first=True) 87 | 88 | self.out_planes = 2 * 256 89 | else: 90 | self.out_planes = 512 91 | 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d): 94 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 95 | elif isinstance(m, nn.BatchNorm2d): 96 | nn.init.constant_(m.weight, 1) 97 | nn.init.constant_(m.bias, 0) 98 | 99 | def _make_layer(self, planes, blocks, stride): 100 | downsample = None 101 | if stride != [1, 1] or self.inplanes != planes: 102 | downsample = nn.Sequential( 103 | conv1x1(self.inplanes, planes, stride), 104 | nn.BatchNorm2d(planes)) 105 | 106 | layers = [] 107 | layers.append(AsterBlock(self.inplanes, planes, stride, downsample)) 108 | self.inplanes = planes 109 | for _ in range(1, blocks): 110 | layers.append(AsterBlock(self.inplanes, planes)) 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | 115 | x0 = self.layer0(x) 116 | x1 = self.layer1(x0) 117 | x2 = self.layer2(x1) 118 | x3 = self.layer3(x2) 119 | x4 = self.layer4(x3) 120 | x5 = self.layer5(x4) 121 | 122 | cnn_feat = x5.squeeze(2) # [N, c, w] 123 | cnn_feat = cnn_feat.transpose(2, 1) 124 | if self.with_lstm: 125 | # print("shit") 126 | # self.rnn.flatten_parameters() 127 | 128 | if not hasattr(self, '_flattened'): 129 | self.rnn.flatten_parameters() 130 | setattr(self, '_flattened', True) 131 | 132 | rnn_feat, _ = self.rnn(cnn_feat) 
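# cnn_feat has shape [N, W, 512]; the 2-layer bidirectional LSTM (hidden size 256 per direction)
# preserves the width dimension, so rnn_feat is also [N, W, 512], matching self.out_planes above.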
133 | return rnn_feat 134 | else: 135 | return cnn_feat 136 | 137 | 138 | if __name__ == "__main__": 139 | x = torch.randn(7, 3, 32, 280) 140 | net = ResNet_ASTER(with_lstm=True, n_group=1) 141 | encoder_feat = net(x) 142 | print(encoder_feat.size()) -------------------------------------------------------------------------------- /model/recognizer/sequenceCrossEntropyLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | 8 | def to_contiguous(tensor): 9 | if tensor.is_contiguous(): 10 | return tensor 11 | else: 12 | return tensor.contiguous() 13 | 14 | def _assert_no_grad(variable): 15 | assert not variable.requires_grad, \ 16 | "nn criterions don't compute the gradient w.r.t. targets - please " \ 17 | "mark these variables as not requiring gradients" 18 | 19 | class SequenceCrossEntropyLoss(nn.Module): 20 | def __init__(self, 21 | weight=None, 22 | size_average=True, 23 | ignore_index=-100, 24 | sequence_normalize=False, 25 | sample_normalize=True): 26 | super(SequenceCrossEntropyLoss, self).__init__() 27 | self.weight = weight 28 | self.size_average = size_average 29 | self.ignore_index = ignore_index 30 | self.sequence_normalize = sequence_normalize 31 | self.sample_normalize = sample_normalize 32 | 33 | assert (sequence_normalize and sample_normalize) == False 34 | 35 | def forward(self, input, target, length): 36 | _assert_no_grad(target) 37 | # length to mask 38 | batch_size, def_max_length = target.size(0), target.size(1) 39 | mask = torch.zeros(batch_size, def_max_length) 40 | for i in range(batch_size): 41 | mask[i,:length[i]].fill_(1) 42 | mask = mask.type_as(input) 43 | # truncate to the same size 44 | max_length = max(length) 45 | assert max_length == input.size(1) 46 | target = target[:, :max_length] 47 | mask = mask[:, :max_length] 48 | input = to_contiguous(input).view(-1, input.size(2)) 49 | input = F.log_softmax(input, dim=1) 50 | target = to_contiguous(target).view(-1, 1) 51 | mask = to_contiguous(mask).view(-1, 1) 52 | output = - input.gather(1, target.long()) * mask 53 | # if self.size_average: 54 | # output = torch.sum(output) / torch.sum(mask) 55 | # elif self.reduce: 56 | # output = torch.sum(output) 57 | ## 58 | output = torch.sum(output) 59 | if self.sequence_normalize: 60 | output = output / torch.sum(mask) 61 | if self.sample_normalize: 62 | output = output / batch_size 63 | 64 | return output -------------------------------------------------------------------------------- /model/recognizer/stn_head.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import numpy as np 5 | import sys 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.nn import init 11 | from IPython import embed 12 | 13 | 14 | def conv3x3_block(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1) 17 | 18 | block = nn.Sequential( 19 | conv_layer, 20 | nn.BatchNorm2d(out_planes), 21 | nn.ReLU(inplace=True), 22 | ) 23 | return block 24 | 25 | 26 | class STNHead(nn.Module): 27 | def __init__(self, in_planes, num_ctrlpoints, activation='none'): 28 | super(STNHead, self).__init__() 29 | 30 | self.in_planes = in_planes 31 | self.num_ctrlpoints = 
num_ctrlpoints 32 | self.activation = activation 33 | self.stn_convnet = nn.Sequential( 34 | conv3x3_block(in_planes, 32), # 32*64 35 | nn.MaxPool2d(kernel_size=2, stride=2), 36 | conv3x3_block(32, 64), # 16*32 37 | nn.MaxPool2d(kernel_size=2, stride=2), 38 | conv3x3_block(64, 128), # 8*16 39 | nn.MaxPool2d(kernel_size=2, stride=2), 40 | conv3x3_block(128, 256), # 4*8 41 | nn.MaxPool2d(kernel_size=2, stride=2), 42 | conv3x3_block(256, 256), # 2*4, 43 | nn.MaxPool2d(kernel_size=2, stride=2), 44 | conv3x3_block(256, 256)) # 1*2 45 | 46 | self.stn_fc1 = nn.Sequential( 47 | nn.Linear(2*256, 512), 48 | nn.BatchNorm1d(512), 49 | nn.ReLU(inplace=True)) 50 | self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2) 51 | 52 | self.init_weights(self.stn_convnet) 53 | self.init_weights(self.stn_fc1) 54 | self.init_stn(self.stn_fc2) 55 | 56 | def init_weights(self, module): 57 | for m in module.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 60 | m.weight.data.normal_(0, math.sqrt(2. / n)) 61 | if m.bias is not None: 62 | m.bias.data.zero_() 63 | elif isinstance(m, nn.BatchNorm2d): 64 | m.weight.data.fill_(1) 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.Linear): 67 | m.weight.data.normal_(0, 0.001) 68 | m.bias.data.zero_() 69 | 70 | def init_stn(self, stn_fc2): 71 | margin = 0.01 72 | sampling_num_per_side = int(self.num_ctrlpoints / 2) 73 | ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side) 74 | ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin 75 | ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin) 76 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 77 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 78 | ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) 79 | if self.activation == 'none': 80 | pass 81 | elif self.activation == 'sigmoid': 82 | ctrl_points = -np.log(1. / ctrl_points - 1.) 
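# -np.log(1/p - 1) is the inverse sigmoid (logit) of p. Together with the zeroed fc2 weights below,
# the head initially outputs exactly these canonical top/bottom control points (after the sigmoid,
# if used), so the TPS rectification starts out as an identity-like mapping regardless of the input.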
83 | stn_fc2.weight.data.zero_() 84 | stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1) 85 | 86 | def forward(self, x): 87 | x = self.stn_convnet(x) 88 | batch_size, _, h, w = x.size() 89 | x = x.view(batch_size, -1) 90 | # embed() 91 | img_feat = self.stn_fc1(x) 92 | x = self.stn_fc2(0.1 * img_feat) 93 | if self.activation == 'sigmoid': 94 | x = F.sigmoid(x) 95 | x = x.view(-1, self.num_ctrlpoints, 2) 96 | return img_feat, x 97 | 98 | 99 | if __name__ == "__main__": 100 | in_planes = 3 101 | num_ctrlpoints = 20 102 | activation='none' # 'sigmoid' 103 | stn_head = STNHead(in_planes, num_ctrlpoints, activation) 104 | input = torch.randn(10, 3, 32, 128) 105 | control_points = stn_head(input) 106 | # print(control_points.size()) 107 | embed() -------------------------------------------------------------------------------- /model/recognizer/tps_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import numpy as np 5 | import itertools 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from IPython import embed 11 | 12 | 13 | def grid_sample(input, grid, canvas = None): 14 | output = F.grid_sample(input, grid) 15 | if canvas is None: 16 | return output 17 | else: 18 | input_mask = input.data.new(input.size()).fill_(1) 19 | output_mask = F.grid_sample(input_mask, grid) 20 | padded_output = output * output_mask + canvas * (1 - output_mask) 21 | return padded_output 22 | 23 | 24 | # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 25 | def compute_partial_repr(input_points, control_points): 26 | N = input_points.size(0) 27 | M = control_points.size(0) 28 | pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) 29 | # original implementation, very slow 30 | # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance 31 | pairwise_diff_square = pairwise_diff * pairwise_diff 32 | pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1] 33 | repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) 34 | # fix numerical error for 0 * log(0), substitute all nan with 0 35 | mask = repr_matrix != repr_matrix 36 | repr_matrix.masked_fill_(mask, 0) 37 | return repr_matrix 38 | 39 | 40 | # output_ctrl_pts are specified, according to our task. 
41 | def build_output_control_points(num_control_points, margins): 42 | margin_x, margin_y = margins 43 | num_ctrl_pts_per_side = num_control_points // 2 44 | ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) 45 | ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y 46 | ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) 47 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 48 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 49 | # ctrl_pts_top = ctrl_pts_top[1:-1,:] 50 | # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] 51 | output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 52 | output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) 53 | return output_ctrl_pts 54 | 55 | 56 | # demo: ~/test/models/test_tps_transformation.py 57 | class TPSSpatialTransformer(nn.Module): 58 | 59 | def __init__(self, output_image_size=None, num_control_points=None, margins=None): 60 | super(TPSSpatialTransformer, self).__init__() 61 | self.output_image_size = output_image_size 62 | self.num_control_points = num_control_points 63 | self.margins = margins 64 | 65 | self.target_height, self.target_width = output_image_size 66 | target_control_points = build_output_control_points(num_control_points, margins) 67 | N = num_control_points 68 | # N = N - 4 69 | 70 | # create padded kernel matrix 71 | forward_kernel = torch.zeros(N + 3, N + 3) 72 | target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points) 73 | forward_kernel[:N, :N].copy_(target_control_partial_repr) 74 | forward_kernel[:N, -3].fill_(1) 75 | forward_kernel[-3, :N].fill_(1) 76 | forward_kernel[:N, -2:].copy_(target_control_points) 77 | forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) 78 | # compute inverse matrix 79 | inverse_kernel = torch.inverse(forward_kernel) 80 | 81 | # create target cordinate matrix 82 | HW = self.target_height * self.target_width 83 | target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width))) 84 | target_coordinate = torch.Tensor(target_coordinate) # HW x 2 85 | Y, X = target_coordinate.split(1, dim = 1) 86 | Y = Y / (self.target_height - 1) 87 | X = X / (self.target_width - 1) 88 | target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y) 89 | target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points) 90 | target_coordinate_repr = torch.cat([ 91 | target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate 92 | ], dim = 1) 93 | 94 | # register precomputed matrices 95 | self.register_buffer('inverse_kernel', inverse_kernel) 96 | self.register_buffer('padding_matrix', torch.zeros(3, 2)) 97 | self.register_buffer('target_coordinate_repr', target_coordinate_repr) 98 | self.register_buffer('target_control_points', target_control_points) 99 | 100 | def forward(self, input, source_control_points): 101 | assert source_control_points.ndimension() == 3 102 | assert source_control_points.size(1) == self.num_control_points 103 | assert source_control_points.size(2) == 2 104 | batch_size = source_control_points.size(0) 105 | 106 | Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1) 107 | mapping_matrix = torch.matmul(self.inverse_kernel, Y) 108 | source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix) 109 | 110 | grid = source_coordinate.view(-1, self.target_height, self.target_width, 2) 111 | grid = torch.clamp(grid, 0, 
1) # the source_control_points may be out of [0, 1]. 112 | # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] 113 | grid = 2.0 * grid - 1.0 114 | output_maps = grid_sample(input, grid, canvas=None) 115 | return output_maps, source_coordinate 116 | 117 | 118 | if __name__=='__main__': 119 | embed() -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | def conv1x1(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv1x1(inplanes, planes) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes, stride) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class ResNet(nn.Module): 51 | 52 | def __init__(self, block, layers): 53 | self.inplanes = 32 54 | super(ResNet, self).__init__() 55 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, 56 | bias=False) 57 | self.bn1 = nn.BatchNorm2d(32) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.layer1 = self._make_layer(block, 32, layers[0], stride=2) 61 | self.layer2 = self._make_layer(block, 64, layers[1], stride=1) 62 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 63 | self.layer4 = self._make_layer(block, 256, layers[3], stride=1) 64 | self.layer5 = self._make_layer(block, 512, layers[4], stride=1) 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | def _make_layer(self, block, planes, blocks, stride=1): 75 | downsample = None 76 | if stride != 1 or self.inplanes != planes * block.expansion: 77 | downsample = nn.Sequential( 78 | nn.Conv2d(self.inplanes, planes * block.expansion, 79 | kernel_size=1, stride=stride, bias=False), 80 | nn.BatchNorm2d(planes * block.expansion), 81 | ) 82 | 83 | layers = [] 84 | layers.append(block(self.inplanes, planes, stride, downsample)) 85 | self.inplanes = planes * block.expansion 86 | for i in range(1, blocks): 87 | layers.append(block(self.inplanes, planes)) 88 | 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | x = self.conv1(x) 93 | x = self.bn1(x) 94 | x = self.relu(x) 95 | x = self.layer1(x) 96 | x = self.layer2(x) 97 | x = self.layer3(x) 98 | x = self.layer4(x) 99 | x = self.layer5(x) 100 | return x 101 | 102 | 103 | def resnet45(): 104 | return ResNet(BasicBlock, [3, 4, 6, 6, 3]) 105 | -------------------------------------------------------------------------------- /model/stn_head.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import numpy as np 5 | import sys 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.nn import init 11 | 12 | 13 | def conv3x3_block(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1) 16 | 17 | block = nn.Sequential( 18 | conv_layer, 19 | nn.BatchNorm2d(out_planes), 20 | nn.ReLU(inplace=True), 21 | ) 22 | return block 23 | 24 | 25 | class STNHead(nn.Module): 26 | def __init__(self, in_planes, num_ctrlpoints, activation='none', input_size=(16, 64)): 27 | super(STNHead, self).__init__() 28 | 29 | self.in_planes = in_planes 30 | self.num_ctrlpoints = num_ctrlpoints 31 | self.activation = activation 32 | self.stn_convnet = nn.Sequential( 33 | # conv3x3_block(in_planes, 32), # 32*128 34 | # nn.MaxPool2d(kernel_size=2, stride=2), 35 | conv3x3_block(in_planes, 32), # 16*64 36 | nn.MaxPool2d(kernel_size=2, stride=2), 37 | conv3x3_block(32, 64), # 8*32 38 | nn.MaxPool2d(kernel_size=2, stride=2), 39 | conv3x3_block(64, 128), # 4*16 40 | nn.MaxPool2d(kernel_size=2, stride=2), 41 | conv3x3_block(128, 256), # 2*8 42 | nn.MaxPool2d(kernel_size=2, stride=2), 43 | conv3x3_block(256, 256), # 1*4, 44 | nn.MaxPool2d(kernel_size=(1,2), stride=(1,2)), 45 | conv3x3_block(256, 256)) # 1*2 46 | 47 | flatten_width = int(input_size[1] / 32) 48 | # print("flw:", input_size[1] / 32) 49 | self.stn_fc1 = nn.Sequential( 50 | nn.Linear(512, 512), #flatten_width*256 51 | nn.BatchNorm1d(512), 52 | nn.ReLU(inplace=True)) 53 | self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2) 54 | 55 | self.init_weights(self.stn_convnet) 56 | self.init_weights(self.stn_fc1) 57 | self.init_stn(self.stn_fc2) 58 | 59 | def init_weights(self, module): 60 | for m in module.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 63 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.BatchNorm2d): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | elif isinstance(m, nn.Linear): 70 | m.weight.data.normal_(0, 0.001) 71 | m.bias.data.zero_() 72 | 73 | def init_stn(self, stn_fc2): 74 | margin = 0.01 75 | sampling_num_per_side = int(self.num_ctrlpoints / 2) 76 | ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side) 77 | ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin 78 | ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin) 79 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 80 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 81 | ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) 82 | # print(ctrl_points.shape) 83 | if self.activation == 'none': 84 | pass 85 | elif self.activation == 'sigmoid': 86 | ctrl_points = -np.log(1. / ctrl_points - 1.) 87 | elif self.activation == 'relu': 88 | ctrl_points = F.relu(torch.Tensor(ctrl_points)) 89 | stn_fc2.weight.data.zero_() 90 | stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1) 91 | 92 | def forward(self, x): 93 | x = self.stn_convnet(x) 94 | batch_size, _, h, w = x.size() 95 | x = x.view(batch_size, -1) 96 | 97 | # print("x:", x.shape) 98 | 99 | img_feat = self.stn_fc1(x) 100 | x = self.stn_fc2(0.1 * img_feat) 101 | if self.activation == 'sigmoid': 102 | x = torch.sigmoid(x) 103 | if self.activation == 'relu': 104 | x = F.relu(x) 105 | x = x.view(-1, self.num_ctrlpoints, 2) 106 | return img_feat, x 107 | 108 | 109 | if __name__ == "__main__": 110 | in_planes = 3 111 | num_ctrlpoints = 20 112 | activation='none' # 'sigmoid' 113 | stn_head = STNHead(in_planes, num_ctrlpoints, activation) 114 | input = torch.randn(10, 3, 32, 64) 115 | control_points = stn_head(input) 116 | print(control_points.size()) 117 | -------------------------------------------------------------------------------- /model/tps_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import itertools 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | def grid_sample(input, grid, canvas = None): 11 | output = F.grid_sample(input, grid) 12 | if canvas is None: 13 | return output 14 | else: 15 | input_mask = input.data.new(input.size()).fill_(1) 16 | output_mask = F.grid_sample(input_mask, grid) 17 | padded_output = output * output_mask + canvas * (1 - output_mask) 18 | return padded_output 19 | 20 | 21 | # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 22 | def compute_partial_repr(input_points, control_points): 23 | N = input_points.size(0) 24 | M = control_points.size(0) 25 | pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) 26 | # original implementation, very slow 27 | # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance 28 | pairwise_diff_square = pairwise_diff * pairwise_diff 29 | pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1] 30 | repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) 31 | # fix numerical error for 0 * log(0), substitute all nan with 0 32 | mask = repr_matrix != repr_matrix 33 | repr_matrix.masked_fill_(mask, 0) 34 | return repr_matrix 35 | 36 | 37 | # output_ctrl_pts are specified, according to our task. 
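# compute_partial_repr() above evaluates the thin-plate-spline radial basis phi(r) = r^2 * log(r); since 0.5 * d^2 * log(d^2) = d^2 * log(d) for the squared pairwise distance d^2, the implementation avoids an explicit square root, and the masked_fill_ enforces the 0 * log(0) = 0 convention at coincident points.
# build_output_control_points() below places num_control_points // 2 target fiducials evenly along the top margin and the same number along the bottom margin of the normalized [0, 1] x [0, 1] rectified image, matching the layout that STNHead.init_stn() uses to initialize its control-point regressor.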
38 | def build_output_control_points(num_control_points, margins): 39 | margin_x, margin_y = margins 40 | num_ctrl_pts_per_side = num_control_points // 2 41 | ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) 42 | ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y 43 | ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) 44 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 45 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 46 | # ctrl_pts_top = ctrl_pts_top[1:-1,:] 47 | # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] 48 | output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 49 | output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) 50 | return output_ctrl_pts 51 | 52 | 53 | # demo: ~/test/models/test_tps_transformation.py 54 | class TPSSpatialTransformer(nn.Module): 55 | 56 | def __init__(self, output_image_size=None, num_control_points=None, margins=None): 57 | super(TPSSpatialTransformer, self).__init__() 58 | self.output_image_size = output_image_size 59 | self.num_control_points = num_control_points 60 | self.margins = margins 61 | 62 | self.target_height, self.target_width = output_image_size 63 | target_control_points = build_output_control_points(num_control_points, margins) 64 | N = num_control_points 65 | # N = N - 4 66 | 67 | # create padded kernel matrix 68 | forward_kernel = torch.zeros(N + 3, N + 3) 69 | target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points) 70 | forward_kernel[:N, :N].copy_(target_control_partial_repr) 71 | forward_kernel[:N, -3].fill_(1) 72 | forward_kernel[-3, :N].fill_(1) 73 | forward_kernel[:N, -2:].copy_(target_control_points) 74 | forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) 75 | # compute inverse matrix 76 | inverse_kernel = torch.inverse(forward_kernel) 77 | 78 | # create target cordinate matrix 79 | HW = self.target_height * self.target_width 80 | target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width))) 81 | target_coordinate = torch.Tensor(target_coordinate) # HW x 2 82 | Y, X = target_coordinate.split(1, dim = 1) 83 | Y = Y / (self.target_height - 1) 84 | X = X / (self.target_width - 1) 85 | target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y) 86 | target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points) 87 | target_coordinate_repr = torch.cat([ 88 | target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate 89 | ], dim = 1) 90 | 91 | # register precomputed matrices 92 | self.register_buffer('inverse_kernel', inverse_kernel) 93 | self.register_buffer('padding_matrix', torch.zeros(3, 2)) 94 | self.register_buffer('target_coordinate_repr', target_coordinate_repr) 95 | self.register_buffer('target_control_points', target_control_points) 96 | 97 | def forward(self, input, source_control_points): 98 | assert source_control_points.ndimension() == 3 99 | assert source_control_points.size(1) == self.num_control_points 100 | assert source_control_points.size(2) == 2 101 | batch_size = source_control_points.size(0) 102 | 103 | Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1) 104 | mapping_matrix = torch.matmul(self.inverse_kernel, Y) 105 | source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix) 106 | 107 | grid = source_coordinate.view(-1, self.target_height, self.target_width, 2) 108 | grid = torch.clamp(grid, 0, 1) # 
the source_control_points may be out of [0, 1]. 109 | # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] 110 | grid = 2.0 * grid - 1.0 111 | output_maps = grid_sample(input, grid, canvas=None) 112 | return output_maps, source_coordinate -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import yaml 8 | import datetime 9 | from matplotlib import colors 10 | from matplotlib import pyplot as plt 11 | from torch import Tensor, nn 12 | from torch.utils.data import ConcatDataset 13 | 14 | class CharsetMapper(object): 15 | """A simple class to map ids into strings. 16 | 17 | It works only when the character set is 1:1 mapping between individual 18 | characters and individual ids. 19 | """ 20 | 21 | def __init__(self, 22 | filename='./dataset/charset_36.txt', 23 | max_length=30, 24 | null_char=u'\u2591'): 25 | """Creates a lookup table. 26 | 27 | Args: 28 | filename: Path to charset file which maps characters to ids. 29 | max_sequence_length: The max length of ids and string. 30 | null_char: A unicode character used to replace '' character. 31 | the default value is a light shade block '░'. 32 | """ 33 | self.null_char = null_char 34 | self.max_length = max_length 35 | 36 | self.label_to_char = self._read_charset(filename) 37 | self.char_to_label = dict(map(reversed, self.label_to_char.items())) 38 | self.num_classes = len(self.label_to_char) 39 | 40 | def _read_charset(self, filename): 41 | """Reads a charset definition from a tab separated text file. 42 | 43 | Args: 44 | filename: a path to the charset file. 45 | 46 | Returns: 47 | a dictionary with keys equal to character codes and values - unicode 48 | characters. 49 | """ 50 | import re 51 | pattern = re.compile(r'(\d+)\t(.+)') 52 | charset = {} 53 | self.null_label = 0 54 | charset[self.null_label] = self.null_char 55 | with open(filename, 'r') as f: 56 | for i, line in enumerate(f): 57 | m = pattern.match(line) 58 | assert m, f'Incorrect charset file. line #{i}: {line}' 59 | label = int(m.group(1)) + 1 60 | char = m.group(2) 61 | charset[label] = char 62 | return charset 63 | 64 | def trim(self, text): 65 | assert isinstance(text, str) 66 | return text.replace(self.null_char, '') 67 | 68 | def get_text(self, labels, length=None, padding=True, trim=False): 69 | """ Returns a string corresponding to a sequence of character ids. 70 | """ 71 | length = length if length else self.max_length 72 | labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels] 73 | if padding: 74 | labels = labels + [self.null_label] * (length-len(labels)) 75 | text = ''.join([self.label_to_char[label] for label in labels]) 76 | if trim: text = self.trim(text) 77 | return text 78 | 79 | def get_labels(self, text, length=None, padding=True, case_sensitive=False): 80 | """ Returns the labels of the corresponding text. 
81 | """ 82 | length = length if length else self.max_length 83 | if padding: 84 | text = text + self.null_char * (length - len(text)) 85 | if not case_sensitive: 86 | text = text.lower() 87 | labels = [self.char_to_label[char] if char in self.char_to_label.keys() else 0 for char in text] 88 | 89 | return labels[:self.max_length] 90 | 91 | def pad_labels(self, labels, length=None): 92 | length = length if length else self.max_length 93 | 94 | return labels + [self.null_label] * (length - len(labels)) 95 | 96 | @property 97 | def digits(self): 98 | return '0123456789' 99 | 100 | @property 101 | def digit_labels(self): 102 | return self.get_labels(self.digits, padding=False) 103 | 104 | @property 105 | def alphabets(self): 106 | all_chars = list(self.char_to_label.keys()) 107 | valid_chars = [] 108 | for c in all_chars: 109 | if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ': 110 | valid_chars.append(c) 111 | return ''.join(valid_chars) 112 | 113 | @property 114 | def alphabet_labels(self): 115 | return self.get_labels(self.alphabets, padding=False) 116 | 117 | 118 | class Timer(object): 119 | """A simple timer.""" 120 | def __init__(self): 121 | self.data_time = 0. 122 | self.data_diff = 0. 123 | self.data_total_time = 0. 124 | self.data_call = 0 125 | self.running_time = 0. 126 | self.running_diff = 0. 127 | self.running_total_time = 0. 128 | self.running_call = 0 129 | 130 | def tic(self): 131 | self.start_time = time.time() 132 | self.running_time = self.start_time 133 | 134 | def toc_data(self): 135 | self.data_time = time.time() 136 | self.data_diff = self.data_time - self.running_time 137 | self.data_total_time += self.data_diff 138 | self.data_call += 1 139 | 140 | def toc_running(self): 141 | self.running_time = time.time() 142 | self.running_diff = self.running_time - self.data_time 143 | self.running_total_time += self.running_diff 144 | self.running_call += 1 145 | 146 | def total_time(self): 147 | return self.data_total_time + self.running_total_time 148 | 149 | def average_time(self): 150 | return self.average_data_time() + self.average_running_time() 151 | 152 | def average_data_time(self): 153 | return self.data_total_time / (self.data_call or 1) 154 | 155 | def average_running_time(self): 156 | return self.running_total_time / (self.running_call or 1) 157 | 158 | 159 | class Logger(object): 160 | _handle = None 161 | _root = None 162 | 163 | @staticmethod 164 | def init(output_dir, name, phase): 165 | time_str = datetime.datetime.now().strftime('%Y%m%d%H%M') 166 | format = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \ 167 | '%(message)s'.format(name) 168 | logging.basicConfig(level=logging.INFO, format=format) 169 | 170 | try: os.makedirs(output_dir) 171 | except: pass 172 | config_path = os.path.join(output_dir,"{}_{}.txt".format(time_str,phase)) 173 | Logger._handle = logging.FileHandler(config_path) 174 | Logger._root = logging.getLogger() 175 | 176 | @staticmethod 177 | def enable_file(): 178 | if Logger._handle is None or Logger._root is None: 179 | raise Exception('Invoke Logger.init() first!') 180 | Logger._root.addHandler(Logger._handle) 181 | 182 | @staticmethod 183 | def disable_file(): 184 | if Logger._handle is None or Logger._root is None: 185 | raise Exception('Invoke Logger.init() first!') 186 | Logger._root.removeHandler(Logger._handle) 187 | 188 | 189 | class Config(object): 190 | 191 | def __init__(self, config_path, host=True): 192 | def __dict2attr(d, prefix=''): 193 | for k, v in d.items(): 194 | if isinstance(v, dict): 195 | 
__dict2attr(v, f'{prefix}{k}_') 196 | else: 197 | if k == 'phase': 198 | assert v in ['train', 'test'] 199 | if k == 'stage': 200 | assert v in ['pretrain-vision', 'pretrain-language', 201 | 'train-semi-super', 'train-super'] 202 | self.__setattr__(f'{prefix}{k}', v) 203 | 204 | assert os.path.exists(config_path), '%s does not exists!' % config_path 205 | with open(config_path) as file: 206 | config_dict = yaml.load(file, Loader=yaml.FullLoader) 207 | with open('configs/template.yaml') as file: 208 | default_config_dict = yaml.load(file, Loader=yaml.FullLoader) 209 | __dict2attr(default_config_dict) 210 | __dict2attr(config_dict) 211 | 212 | def __getattr__(self, item): 213 | attr = self.__dict__.get(item) 214 | if attr is None: 215 | attr = dict() 216 | prefix = f'{item}_' 217 | for k, v in self.__dict__.items(): 218 | if k.startswith(prefix): 219 | n = k.replace(prefix, '') 220 | attr[n] = v 221 | return attr if len(attr) > 0 else None 222 | else: 223 | return attr 224 | 225 | def __repr__(self): 226 | str = 'ModelConfig(\n' 227 | for i, (k, v) in enumerate(sorted(vars(self).items())): 228 | str += f'\t({i}): {k} = {v}\n' 229 | str += ')' 230 | return str 231 | 232 | def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0): 233 | # normalize mask 234 | mask = (mask-mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps) 235 | if mask.shape != image.shape: 236 | mask = cv2.resize(mask,(image.shape[1], image.shape[0])) 237 | # get color map 238 | color_map = plt.get_cmap(cmap) 239 | mask = color_map(mask)[:,:,:3] 240 | # convert float to uint8 241 | mask = (mask * 255).astype(dtype=np.uint8) 242 | 243 | # set the basic color 244 | basic_color = np.array(colors.to_rgb(color)) * 255 245 | basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1]) 246 | basic_color = basic_color.astype(dtype=np.uint8) 247 | # blend with basic color 248 | blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1-color_alpha, 0) 249 | # blend with mask 250 | blended_img = cv2.addWeighted(blended_img, alpha, mask, 1-alpha, 0) 251 | 252 | return blended_img 253 | 254 | def onehot(label, depth, device=None): 255 | """ 256 | Args: 257 | label: shape (n1, n2, ..., ) 258 | depth: a scalar 259 | 260 | Returns: 261 | onehot: (n1, n2, ..., depth) 262 | """ 263 | if not isinstance(label, torch.Tensor): 264 | label = torch.tensor(label, device=device) 265 | onehot = torch.zeros(label.size() + torch.Size([depth]), device=device) 266 | onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1) 267 | 268 | return onehot 269 | 270 | class MyDataParallel(nn.DataParallel): 271 | 272 | def gather(self, outputs, target_device): 273 | r""" 274 | Gathers tensors from different GPUs on a specified device 275 | (-1 means the CPU). 
276 | """ 277 | def gather_map(outputs): 278 | out = outputs[0] 279 | if isinstance(out, (str, int, float)): 280 | return out 281 | if isinstance(out, list) and isinstance(out[0], str): 282 | return [o for out in outputs for o in out] 283 | if isinstance(out, torch.Tensor): 284 | return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs) 285 | if out is None: 286 | return None 287 | if isinstance(out, dict): 288 | if not all((len(out) == len(d) for d in outputs)): 289 | raise ValueError('All dicts must have the same number of keys') 290 | return type(out)(((k, gather_map([d[k] for d in outputs])) 291 | for k in out)) 292 | return type(out)(map(gather_map, zip(*outputs))) 293 | 294 | # Recursive function calls like this create reference cycles. 295 | # Setting the function to None clears the refcycle. 296 | try: 297 | res = gather_map(outputs) 298 | finally: 299 | gather_map = None 300 | return res 301 | 302 | 303 | class MyConcatDataset(ConcatDataset): 304 | def __getattr__(self, k): 305 | return getattr(self.datasets[0], k) 306 | 307 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .labelmaps import * 3 | from .util import str_filt 4 | 5 | -------------------------------------------------------------------------------- /utils/labelmaps.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import string 4 | 5 | 6 | def get_vocabulary(voc_type, EOS='EOS', PADDING='PADDING', UNKNOWN='UNKNOWN'): 7 | ''' 8 | voc_type: str: one of 'LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS' 9 | ''' 10 | voc = None 11 | types = ['digit', 'lower', 'upper', 'all'] 12 | if voc_type == 'digit': 13 | voc = list(string.digits) 14 | elif voc_type == 'lower': 15 | voc = list(string.digits + string.ascii_lowercase) 16 | elif voc_type == 'upper': 17 | voc = list(string.digits + string.ascii_letters) 18 | elif voc_type == 'all': 19 | voc = list(string.digits + string.ascii_letters + string.punctuation) 20 | elif voc_type == 'chinese': 21 | voc = list(open("al_chinese.txt", "r").readlines()[0].replace("\n", "")) 22 | else: 23 | raise KeyError('voc_type Error') 24 | 25 | # update the voc with specifical chars 26 | voc.append(EOS) 27 | voc.append(PADDING) 28 | voc.append(UNKNOWN) 29 | 30 | return voc 31 | 32 | 33 | ## param voc: the list of vocabulary 34 | def char2id(voc): 35 | return dict(zip(voc, range(len(voc)))) 36 | 37 | 38 | def id2char(voc): 39 | return dict(zip(range(len(voc)), voc)) 40 | 41 | 42 | def labels2strs(labels, id2char, char2id): 43 | # labels: batch_size x len_seq 44 | if labels.ndimension() == 1: 45 | labels = labels.unsqueeze(0) 46 | assert labels.dim() == 2 47 | labels = to_numpy(labels) 48 | strings = [] 49 | batch_size = labels.shape[0] 50 | 51 | for i in range(batch_size): 52 | label = labels[i] 53 | string = [] 54 | for l in label: 55 | if l == char2id['EOS']: 56 | break 57 | else: 58 | string.append(id2char[l]) 59 | string = ''.join(string) 60 | strings.append(string) 61 | 62 | return strings 63 | 64 | 65 | def to_numpy(tensor): 66 | if torch.is_tensor(tensor): 67 | return tensor.cpu().numpy() 68 | elif type(tensor).__module__ != 'numpy': 69 | raise ValueError("Cannot convert {} to numpy array" 70 | .format(type(tensor))) 71 | return tensor 72 | 73 | 74 | def to_torch(ndarray): 75 | if 
type(ndarray).__module__ == 'numpy': 76 | return torch.from_numpy(ndarray) 77 | elif not torch.is_tensor(ndarray): 78 | raise ValueError("Cannot convert {} to torch tensor" 79 | .format(type(ndarray))) 80 | return ndarray 81 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import editdistance 5 | import string 6 | import math 7 | from IPython import embed 8 | import torch 9 | import torch.nn.functional as F 10 | import sys 11 | sys.path.append('../') 12 | from utils import to_torch, to_numpy 13 | 14 | def get_string_abinet(v_res,abi_charset): 15 | # Input: the logits predicted by the text recognizer; output: a list of decoded strings 16 | # TODO only v_res is used here; a_res could additionally be fused in later 17 | # {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,'attn_scores': attn_scores, 'name': 'vision',' backbone_feature': features} 18 | logit = v_res['logits'] 19 | out = F.softmax(logit, dim=2) 20 | pt_text, pt_scores, pt_lengths = [], [], [] 21 | for o in out: 22 | text = abi_charset.get_text(o.argmax(dim=1), padding=False, trim=False) 23 | text = text.split(abi_charset.null_char)[0] # end at end-token (this effectively performs the trim as well) 24 | pt_text.append(text) 25 | pt_scores.append(o.max(dim=1)[0]) 26 | pt_lengths.append(min(len(text) + 1, abi_charset.max_length)) # one for end-token 27 | return pt_text, pt_scores, pt_lengths 28 | 29 | 30 | def _normalize_text(text): 31 | text = ''.join(filter(lambda x: x in (string.digits + string.ascii_letters), text)) 32 | return text.lower() 33 | 34 | 35 | def get_string_parseq(output_dict,tokenizer): 36 | logits = output_dict['logits'] 37 | probs = logits.softmax(-1) 38 | preds, probs = tokenizer.decode(probs) 39 | return preds 40 | 41 | def get_string_aster(output, target, dataset=None): 42 | # label_seq 43 | assert output.dim() == 2 and target.dim() == 2 44 | 45 | end_label = dataset.char2id[dataset.EOS] 46 | unknown_label = dataset.char2id[dataset.UNKNOWN] 47 | num_samples, max_len_labels = output.size() 48 | num_classes = len(dataset.char2id.keys()) 49 | assert num_samples == target.size(0) and max_len_labels == target.size(1) 50 | output = to_numpy(output) 51 | target = to_numpy(target) 52 | 53 | # list of char list 54 | pred_list, targ_list = [], [] 55 | for i in range(num_samples): 56 | pred_list_i = [] 57 | for j in range(max_len_labels): 58 | if output[i, j] != end_label: 59 | if output[i, j] != unknown_label: 60 | try: 61 | pred_list_i.append(dataset.id2char[output[i, j]]) 62 | except: 63 | embed(header='problem') 64 | else: 65 | break 66 | pred_list.append(pred_list_i) 67 | 68 | for i in range(num_samples): 69 | targ_list_i = [] 70 | for j in range(max_len_labels): 71 | if target[i, j] != end_label: 72 | if target[i, j] != unknown_label: 73 | targ_list_i.append(dataset.id2char[target[i, j]]) 74 | else: 75 | break 76 | targ_list.append(targ_list_i) 77 | 78 | # char list to string 79 | # if dataset.lowercase: 80 | if True: 81 | # pred_list = [''.join(pred).lower() for pred in pred_list] 82 | # targ_list = [''.join(targ).lower() for targ in targ_list] 83 | pred_list = [_normalize_text(pred) for pred in pred_list] 84 | targ_list = [_normalize_text(targ) for targ in targ_list] 85 | else: 86 | pred_list = [''.join(pred) for pred in pred_list] 87 | targ_list = [''.join(targ) for targ in targ_list] 88 | 89 | return pred_list, targ_list 90 | 91 | 92 | def get_string_crnn(outputs_, use_chinese=False, 
alphabet='-0123456789abcdefghijklmnopqrstuvwxyz'): 93 | outputs = outputs_.permute(1, 0, 2).contiguous() 94 | predict_result = [] 95 | 96 | if use_chinese: 97 | alphabet = open("al_chinese.txt", 'r').readlines()[0].replace("\n", "") 98 | 99 | for output in outputs: 100 | max_index = torch.max(output, 1)[1] 101 | 102 | out_str = "" 103 | last = "" 104 | for i in max_index: 105 | if alphabet[i] != last: 106 | if i != 0: 107 | out_str += alphabet[i] 108 | last = alphabet[i] 109 | else: 110 | last = "" 111 | 112 | predict_result.append(out_str) 113 | return predict_result 114 | 115 | 116 | def _lexicon_search(lexicon, word): 117 | edit_distances = [] 118 | for lex_word in lexicon: 119 | edit_distances.append(editdistance.eval(_normalize_text(lex_word), _normalize_text(word))) 120 | edit_distances = np.asarray(edit_distances, dtype=np.int) 121 | argmin = np.argmin(edit_distances) 122 | return lexicon[argmin] 123 | 124 | 125 | def Accuracy(output, target, dataset=None): 126 | pred_list, targ_list = get_string_aster(output, target, dataset) 127 | 128 | acc_list = [(pred == targ) for pred, targ in zip(pred_list, targ_list)] 129 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 130 | return accuracy 131 | 132 | 133 | def Accuracy_with_lexicon(output, target, dataset=None, file_names=None): 134 | pred_list, targ_list = get_string_aster(output, target, dataset) 135 | accuracys = [] 136 | 137 | # with no lexicon 138 | acc_list = [(pred == targ) for pred, targ in zip(pred_list, targ_list)] 139 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 140 | accuracys.append(accuracy) 141 | 142 | # lexicon50 143 | if len(file_names) == 0 or len(dataset.lexicons50[file_names[0]]) == 0: 144 | accuracys.append(0) 145 | else: 146 | refined_pred_list = [_lexicon_search(dataset.lexicons50[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 147 | acc_list = [(pred == targ) for pred, targ in zip(refined_pred_list, targ_list)] 148 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 149 | accuracys.append(accuracy) 150 | 151 | # lexicon1k 152 | if len(file_names) == 0 or len(dataset.lexicons1k[file_names[0]]) == 0: 153 | accuracys.append(0) 154 | else: 155 | refined_pred_list = [_lexicon_search(dataset.lexicons1k[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 156 | acc_list = [(pred == targ) for pred, targ in zip(refined_pred_list, targ_list)] 157 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 158 | accuracys.append(accuracy) 159 | 160 | # lexiconfull 161 | if len(file_names) == 0 or len(dataset.lexiconsfull[file_names[0]]) == 0: 162 | accuracys.append(0) 163 | else: 164 | refined_pred_list = [_lexicon_search(dataset.lexiconsfull[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 165 | acc_list = [(pred == targ) for pred, targ in zip(refined_pred_list, targ_list)] 166 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 167 | accuracys.append(accuracy) 168 | 169 | return accuracys 170 | 171 | 172 | def EditDistance(output, target, dataset=None): 173 | pred_list, targ_list = get_string_aster(output, target, dataset) 174 | 175 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(pred_list, targ_list)] 176 | eds = sum(ed_list) 177 | return eds 178 | 179 | 180 | def EditDistance_with_lexicon(output, target, dataset=None, file_names=None): 181 | pred_list, targ_list = get_string_aster(output, target, dataset) 182 | eds = [] 183 | 184 | # with no lexicon 185 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(pred_list, targ_list)] 186 | ed = 
sum(ed_list) 187 | eds.append(ed) 188 | 189 | # lexicon50 190 | if len(file_names) == 0 or len(dataset.lexicons50[file_names[0]]) == 0: 191 | eds.append(0) 192 | else: 193 | refined_pred_list = [_lexicon_search(dataset.lexicons50[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 194 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(refined_pred_list, targ_list)] 195 | ed = sum(ed_list) 196 | eds.append(ed) 197 | 198 | # lexicon1k 199 | if len(file_names) == 0 or len(dataset.lexicons1k[file_names[0]]) == 0: 200 | eds.append(0) 201 | else: 202 | refined_pred_list = [_lexicon_search(dataset.lexicons1k[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 203 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(refined_pred_list, targ_list)] 204 | ed = sum(ed_list) 205 | eds.append(ed) 206 | 207 | # lexiconfull 208 | if len(file_names) == 0 or len(dataset.lexiconsfull[file_names[0]]) == 0: 209 | eds.append(0) 210 | else: 211 | refined_pred_list = [_lexicon_search(dataset.lexiconsfull[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 212 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(refined_pred_list, targ_list)] 213 | ed = sum(ed_list) 214 | eds.append(ed) 215 | 216 | return eds 217 | 218 | 219 | def RecPostProcess(output, target, score, dataset=None): 220 | pred_list, targ_list = get_string_aster(output, target, dataset) 221 | max_len_labels = output.size(1) 222 | score_list = [] 223 | 224 | score = to_numpy(score) 225 | for i, pred in enumerate(pred_list): 226 | len_pred = len(pred) + 1 # eos should be included 227 | len_pred = min(max_len_labels, len_pred) # maybe the predicted string don't include a eos. 228 | score_i = score[i,:len_pred] 229 | score_i = math.exp(sum(map(math.log, score_i))) 230 | score_list.append(score_i) 231 | return pred_list, targ_list, score_list -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import collections 8 | import string 9 | from IPython import embed 10 | 11 | 12 | def str_filt(str_, voc_type): 13 | alpha_dict = { 14 | 'digit': string.digits, 15 | 'lower': string.digits + string.ascii_lowercase, 16 | 'upper': string.digits + string.ascii_letters, 17 | 'all': string.digits + string.ascii_letters + string.punctuation, 18 | 'chinese': open("al_chinese.txt", "r",encoding='UTF-8').readlines()[0].replace("\n", "") 19 | } 20 | if voc_type == 'lower': 21 | str_ = str_.lower() 22 | 23 | if voc_type == 'chinese': # Chinese character only 24 | new_str = "" 25 | for ch in str_: 26 | if '\u4e00' <= ch <= '\u9fa5' or ch in string.digits + string.ascii_letters: 27 | new_str += ch 28 | str_ = new_str 29 | for char in str_: 30 | if char not in alpha_dict[voc_type]: #voc_type 31 | str_ = str_.replace(char, '') 32 | return str_ 33 | 34 | 35 | class strLabelConverter(object): 36 | """Convert between str and label. 37 | 38 | NOTE: 39 | Insert `blank` to the alphabet for CTC. 40 | 41 | Args: 42 | alphabet (str): set of the possible characters. 43 | ignore_case (bool, default=True): whether or not to ignore all of the case. 
44 | """ 45 | 46 | def __init__(self, alphabet): 47 | self.alphabet = alphabet + '-' # for `-1` index 48 | 49 | self.dict = {} 50 | for i, char in enumerate(alphabet): 51 | # NOTE: 0 is reserved for 'blank' required by wrap_ctc 52 | self.dict[char] = i + 1 53 | 54 | def encode(self, text): 55 | """Support batch or single str. 56 | 57 | Args: 58 | text (str or list of str): texts to convert. 59 | 60 | Returns: 61 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 62 | torch.IntTensor [n]: length of each text. 63 | """ 64 | if isinstance(text, str): 65 | from IPython import embed 66 | # embed() 67 | text = [ 68 | self.dict[char] 69 | for char in text 70 | ] 71 | length = [len(text)] 72 | elif isinstance(text, collections.Iterable): 73 | length = [len(s) for s in text] 74 | text = ''.join(text) 75 | text, _ = self.encode(text) 76 | return (torch.IntTensor(text), torch.IntTensor(length)) 77 | 78 | def decode(self, t, length, raw=False): 79 | """Decode encoded texts back into strs. 80 | 81 | Args: 82 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 83 | torch.IntTensor [n]: length of each text. 84 | 85 | Raises: 86 | AssertionError: when the texts and its length does not match. 87 | 88 | Returns: 89 | text (str or list of str): texts to convert. 90 | """ 91 | if length.numel() == 1: 92 | length = length[0] 93 | assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length) 94 | if raw: 95 | return ''.join([self.alphabet[i - 1] for i in t]) 96 | else: 97 | char_list = [] 98 | for i in range(length): 99 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): 100 | char_list.append(self.alphabet[t[i] - 1]) 101 | return ''.join(char_list) 102 | else: 103 | # batch mode 104 | assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum()) 105 | texts = [] 106 | index = 0 107 | for i in range(length.numel()): 108 | l = length[i] 109 | texts.append( 110 | self.decode( 111 | t[index:index + l], torch.IntTensor([l]), raw=raw)) 112 | index += l 113 | return texts 114 | 115 | 116 | class averager(object): 117 | """Compute average for `torch.Variable` and `torch.Tensor`. 
""" 118 | 119 | def __init__(self): 120 | self.reset() 121 | 122 | def add(self, v): 123 | if isinstance(v, Variable): 124 | count = v.data.numel() 125 | v = v.data.sum() 126 | elif isinstance(v, torch.Tensor): 127 | count = v.numel() 128 | v = v.sum() 129 | 130 | self.n_count += count 131 | self.sum += v 132 | 133 | def reset(self): 134 | self.n_count = 0 135 | self.sum = 0 136 | 137 | def val(self): 138 | res = 0 139 | if self.n_count != 0: 140 | res = self.sum / float(self.n_count) 141 | return res 142 | 143 | 144 | def oneHot(v, v_length, nc): 145 | batchSize = v_length.size(0) 146 | maxLength = v_length.max() 147 | v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0) 148 | acc = 0 149 | for i in range(batchSize): 150 | length = v_length[i] 151 | label = v[acc:acc + length].view(-1, 1).long() 152 | v_onehot[i, :length].scatter_(1, label, 1.0) 153 | acc += length 154 | return v_onehot 155 | 156 | 157 | def loadData(v, data): 158 | # v.data.resize_(data.size()).copy_(data) 159 | v.resize_(data.size()).copy_(data) 160 | 161 | def prettyPrint(v): 162 | print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type())) 163 | print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0], 164 | v.mean().data[0])) 165 | 166 | 167 | def assureRatio(img): 168 | """Ensure imgH <= imgW.""" 169 | b, c, h, w = img.size() 170 | if h > w: 171 | main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None) 172 | img = main(img) 173 | return img 174 | 175 | 176 | if __name__=='__main__': 177 | converter = strLabelConverter(string.digits+string.ascii_lowercase) 178 | embed() -------------------------------------------------------------------------------- /utils/utils_crnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import collections 8 | 9 | 10 | class strLabelConverter(object): 11 | """Convert between str and label. 12 | 13 | NOTE: 14 | Insert `blank` to the alphabet for CTC. 15 | 16 | Args: 17 | alphabet (str): set of the possible characters. 18 | ignore_case (bool, default=True): whether or not to ignore all of the case. 19 | """ 20 | 21 | def __init__(self, alphabet, ignore_case=True): 22 | self._ignore_case = ignore_case 23 | if self._ignore_case: 24 | alphabet = alphabet.lower() 25 | self.alphabet = alphabet + '-' # for `-1` index 26 | 27 | self.dict = {} 28 | for i, char in enumerate(alphabet): 29 | # NOTE: 0 is reserved for 'blank' required by wrap_ctc 30 | self.dict[char] = i + 1 31 | 32 | def encode(self, text): 33 | """Support batch or single str. 34 | 35 | Args: 36 | text (str or list of str): texts to convert. 37 | 38 | Returns: 39 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 40 | torch.IntTensor [n]: length of each text. 41 | """ 42 | if isinstance(text, str): 43 | text = [ 44 | self.dict[char.lower() if self._ignore_case else char] 45 | for char in text 46 | ] 47 | length = [len(text)] 48 | elif isinstance(text, collections.Iterable): 49 | length = [len(s) for s in text] 50 | text = ''.join(text) 51 | text, _ = self.encode(text) 52 | return (torch.IntTensor(text), torch.IntTensor(length)) 53 | 54 | def decode(self, t, length, raw=False): 55 | """Decode encoded texts back into strs. 56 | 57 | Args: 58 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 59 | torch.IntTensor [n]: length of each text. 
60 | 61 | Raises: 62 | AssertionError: when the texts and its length does not match. 63 | 64 | Returns: 65 | text (str or list of str): texts to convert. 66 | """ 67 | if length.numel() == 1: 68 | length = length[0] 69 | assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length) 70 | if raw: 71 | return ''.join([self.alphabet[i - 1] for i in t]) 72 | else: 73 | char_list = [] 74 | for i in range(length): 75 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): 76 | char_list.append(self.alphabet[t[i] - 1]) 77 | return ''.join(char_list) 78 | else: 79 | # batch mode 80 | assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum()) 81 | texts = [] 82 | index = 0 83 | for i in range(length.numel()): 84 | l = length[i] 85 | texts.append( 86 | self.decode( 87 | t[index:index + l], torch.IntTensor([l]), raw=raw)) 88 | index += l 89 | return texts 90 | 91 | 92 | class averager(object): 93 | """Compute average for `torch.Variable` and `torch.Tensor`. """ 94 | 95 | def __init__(self): 96 | self.reset() 97 | 98 | def add(self, v): 99 | if isinstance(v, Variable): 100 | count = v.data.numel() 101 | v = v.data.sum() 102 | elif isinstance(v, torch.Tensor): 103 | count = v.numel() 104 | v = v.sum() 105 | 106 | self.n_count += count 107 | self.sum += v 108 | 109 | def reset(self): 110 | self.n_count = 0 111 | self.sum = 0 112 | 113 | def val(self): 114 | res = 0 115 | if self.n_count != 0: 116 | res = self.sum / float(self.n_count) 117 | return res 118 | 119 | 120 | def oneHot(v, v_length, nc): 121 | batchSize = v_length.size(0) 122 | maxLength = v_length.max() 123 | v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0) 124 | acc = 0 125 | for i in range(batchSize): 126 | length = v_length[i] 127 | label = v[acc:acc + length].view(-1, 1).long() 128 | v_onehot[i, :length].scatter_(1, label, 1.0) 129 | acc += length 130 | return v_onehot 131 | 132 | 133 | def loadData(v, data): 134 | v.data.resize_(data.size()).copy_(data) 135 | 136 | 137 | def prettyPrint(v): 138 | print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type())) 139 | print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0], 140 | v.mean().data[0])) 141 | 142 | 143 | def assureRatio(img): 144 | """Ensure imgH <= imgW.""" 145 | b, c, h, w = img.size() 146 | if h > w: 147 | main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None) 148 | img = main(img) 149 | return img 150 | -------------------------------------------------------------------------------- /utils/utils_moran.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import collections 5 | 6 | class strLabelConverterForAttention(object): 7 | """Convert between str and label. 8 | 9 | NOTE: 10 | Insert `EOS` to the alphabet for attention. 11 | 12 | Args: 13 | alphabet (str): set of the possible characters. 14 | ignore_case (bool, default=True): whether or not to ignore all of the case. 
15 | """ 16 | 17 | def __init__(self, alphabet, sep): 18 | self._scanned_list = False 19 | self._out_of_list = '' 20 | self._ignore_case = True 21 | self.sep = sep 22 | self.alphabet = alphabet.split(sep) 23 | 24 | self.dict = {} 25 | for i, item in enumerate(self.alphabet): 26 | self.dict[item] = i 27 | 28 | def scan(self, text): 29 | # print(text) 30 | text_tmp = text 31 | text = [] 32 | for i in range(len(text_tmp)): 33 | text_result = '' 34 | for j in range(len(text_tmp[i])): 35 | chara = text_tmp[i][j].lower() if self._ignore_case else text_tmp[i][j] 36 | if chara not in self.alphabet: 37 | if chara in self._out_of_list: 38 | continue 39 | else: 40 | self._out_of_list += chara 41 | file_out_of_list = open("out_of_list.txt", "a+") 42 | file_out_of_list.write(chara + "\n") 43 | file_out_of_list.close() 44 | print('" %s " is not in alphabet...' % chara) 45 | continue 46 | else: 47 | text_result += chara 48 | text.append(text_result) 49 | text_result = tuple(text) 50 | self._scanned_list = True 51 | return text_result 52 | 53 | def encode(self, text, scanned=True): 54 | """Support batch or single str. 55 | 56 | Args: 57 | text (str or list of str): texts to convert. 58 | 59 | Returns: 60 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 61 | torch.IntTensor [n]: length of each text. 62 | """ 63 | self._scanned_list = scanned 64 | if not self._scanned_list: 65 | text = self.scan(text) 66 | 67 | if isinstance(text, str): 68 | text = [ 69 | self.dict[char.lower() if self._ignore_case else char] 70 | for char in text 71 | ] 72 | length = [len(text)] 73 | elif isinstance(text, collections.Iterable): 74 | length = [len(s) for s in text] 75 | text = ''.join(text) 76 | text, _ = self.encode(text) 77 | return (torch.LongTensor(text), torch.LongTensor(length)) 78 | 79 | def decode(self, t, length): 80 | """Decode encoded texts back into strs. 81 | 82 | Args: 83 | torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts. 84 | torch.IntTensor [n]: length of each text. 85 | 86 | Raises: 87 | AssertionError: when the texts and its length does not match. 88 | 89 | Returns: 90 | text (str or list of str): texts to convert. 91 | """ 92 | if length.numel() == 1: 93 | length = length[0] 94 | assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length) 95 | return ''.join([self.alphabet[i] for i in t]) 96 | else: 97 | # batch mode 98 | assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum()) 99 | texts = [] 100 | index = 0 101 | for i in range(length.numel()): 102 | l = length[i] 103 | texts.append( 104 | self.decode( 105 | t[index:index + l], torch.LongTensor([l]))) 106 | index += l 107 | return texts 108 | 109 | class averager(object): 110 | """Compute average for `torch.Variable` and `torch.Tensor`. 
""" 111 | 112 | def __init__(self): 113 | self.reset() 114 | 115 | def add(self, v): 116 | if isinstance(v, Variable): 117 | count = v.data.numel() 118 | v = v.data.sum() 119 | elif isinstance(v, torch.Tensor): 120 | count = v.numel() 121 | v = v.sum() 122 | 123 | self.n_count += count 124 | self.sum += v 125 | 126 | def reset(self): 127 | self.n_count = 0 128 | self.sum = 0 129 | 130 | def val(self): 131 | res = 0 132 | if self.n_count != 0: 133 | res = self.sum / float(self.n_count) 134 | return res 135 | 136 | def loadData(v, data): 137 | major, _ = get_torch_version() 138 | 139 | if major >= 1: 140 | v.resize_(data.size()).copy_(data) 141 | else: 142 | v.data.resize_(data.size()).copy_(data) 143 | 144 | def get_torch_version(): 145 | """ 146 | Find pytorch version and return it as integers 147 | for major and minor versions 148 | """ 149 | torch_version = str(torch.__version__).split(".") 150 | return int(torch_version[0]), int(torch_version[1]) 151 | --------------------------------------------------------------------------------