├── IPACT_Data_sources.csv
├── IPFAQ_Data_sources.csv
├── LICENSE
├── README.md
├── SFT.py
├── crawlers
│   └── wipo
│       ├── scrapy.cfg
│       └── wipo
│           ├── __init__.py
│           ├── items.py
│           ├── middlewares.py
│           ├── pipelines.py
│           ├── settings.py
│           └── spiders
│               ├── __init__.py
│               └── wipo_int.py
├── inference.py
├── patent_pretrain.py
├── run_SFT.sh
├── run_inference.sh
├── run_pretrain.sh
└── src
    └── bnnt
        ├── __init__.py
        ├── data
        │   ├── __init__.py
        │   ├── data_utils.py
        │   ├── dev5K.json
        │   └── raw_datasets.py
        ├── ds_utils.py
        ├── model
        │   ├── __init__.py
        │   └── model_utils.py
        ├── module
        │   ├── __init__.py
        │   ├── lora.py
        │   └── test.ipynb
        └── utils.py
-------------------------------------------------------------------------------- /IPACT_Data_sources.csv: --------------------------------------------------------------------------------
1 | source,processor,language
2 | http://epaper.iprchn.com/zscqb/h5/html5/2023-04/21/content_27601_7600799.htm,minghuan,zh
3 | https://eduai.baidu.com/view/6f7c32b0d7d8d15abe23482fb4daa58da0111cfe,fuqiang,zh
4 | https://www.juxuewen.com/question/1159.html,minghuan,zh
5 | https://eduai.baidu.com/view/9c690b1d29f90242a8956bec0975f46527d3a7c9,fuqiang,zh
6 | https://wenku.baidu.com/view/daea1b562179168884868762caaedd3383c4b5d2.html?_wkts_=1683278279098,shiwen,zh
7 | http://www.dxh.gov.cn/hdjl/zxdc/zxdcxq/index.shtml?/personalCenter/answerSheet/answerSheet.html?metadataId=ff80808180b79f8a01830b603e8001f7&siteId=43,minghuan,zh
8 | http://guangxi.china.com.cn/2022-06/17/content_42002807.html,minghuan,zh
9 | http://www.educhenggong.com/Uploads/attached/file/20210324/1616572571695330.pdf,minghuan,zh
10 | https://www.gov.cn/guoqing/2021-10/29/content_5647633.htm,shiwen,zh
11 | https://www.gov.cn/zhengce/2020-12/26/content_5573623.htm,shiwen,zh
12 | https://www.gov.cn/gongbao/content/2000/content_60431.htm,shiwen,zh
13 | https://www.gov.cn/zhengce/2020-12/26/content_5574514.htm,shiwen,zh
14 | https://www.gov.cn/zhengce/2020-12/26/content_5573535.htm,shiwen,zh
15 | https://www.gov.cn/xinwen/2017-11/05/content_5237325.htm,shiwen,zh
16 | https://www.gov.cn/zhengce/2020-12/26/content_5574414.htm,shiwen,zh
17 | https://inside.nku.edu/content/dam/creative-thinking/docs/CT-Handouts-page/CT%20Handout%20Intellectual%20Property%20Quiz.pdf,minghuan,en
18 | https://www.wipo.int/ip-outreach/en/ipday/2017/ipday2017_quiz-copy.html,minghuan,en
19 | https://www.gastao.eu/ip-knowledge-test,minghuan,en
20 | https://www.proprofs.com/quiz-school/quizzes/intellectual-property-law-quiz,minghuan,en
21 | https://fr.surveymonkey.com/r/LW83BBV,minghuan,en
22 | https://www.riddle.com/view/57770?qzzr=1,minghuan,en
23 | https://about.lens.org/patent-knowledge-quiz/,minghuan,en
24 | https://www.examsegg.com/intellectual-property-rights-india-questions.html,fuqiang,en
25 | https://www.rkdewan.com/quizstart.php?qid=24,minghuan,en
26 | https://www.eduki.ch/en/quizz/intellectual-property-quiz,shiwen,en
27 | https://qpkendra.com/mcq/ipr-and-patenting-mcq-pg-1.html,fuqiang,en
28 | 
28 | https://openstax.org/books/introduction-intellectual-property/pages/chapter-1,fuqiang,en
29 | https://www.lexifiche.com/quiz-propriete-intellectuelle-breve,fuqiang,fr
30 | https://www.q-net.or.kr/cst003.do?id=cst00309&gSite=L&gId=51,fuqiang,ko
31 | https://www.agaroot.jp/benri/column/past-questions/,fuqiang,jp
32 | https://www.geo.de/wissen/quiz/wissenstest-erfindungen-i-30201276.html,minghuan,de
33 | https://www.geo.de/wissen/quiz/wissenstest-erfindungen-ii-30201270.html,minghuan,de
34 | https://www.eduki.ch/de/quizz/quiz-geistiges-eigentum,minghuan,de
35 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-1/,fuqiang,de
36 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-2/,fuqiang,de
37 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-3/,fuqiang,de
38 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-4/,fuqiang,de
39 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-5/,fuqiang,de
40 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-6/,fuqiang,de
41 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-7/,fuqiang,de
42 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-8/,fuqiang,de
43 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-9/,fuqiang,de
44 | https://www.fsgu-akademie.de/quiz/geistiges-eigentum-teil-10/,minghuan,de
45 | https://www.fsgu-akademie.de/quiz/design-schutzrecht-teil-1/,fuqiang,de
46 | https://www.fsgu-akademie.de/quiz/design-schutzrecht-teil-2/,fuqiang,de
47 | https://www.fsgu-akademie.de/quiz/design-schutzrecht-teil-3/,fuqiang,de
48 | https://www.fsgu-akademie.de/quiz/design-schutzrecht-teil-4/,fuqiang,de
49 | https://www.fsgu-akademie.de/quiz/handelsmarke-teil-1/,fuqiang,de
50 | https://www.fsgu-akademie.de/quiz/handelsmarke-teil-2/,fuqiang,de
51 | https://www.fsgu-akademie.de/quiz/handelsmarke-teil-3/,fuqiang,de
52 | https://www.fsgu-akademie.de/quiz/handelsmarke-teil-4/,fuqiang,de
53 | https://www.fsgu-akademie.de/quiz/patentrecht-deutschland-teil-1/,fuqiang,de
54 | https://www.fsgu-akademie.de/quiz/patentrecht-deutschland-teil-2/,fuqiang,de
55 | https://www.fsgu-akademie.de/quiz/patentrecht-deutschland-teil-3/,fuqiang,de
56 | https://www.fsgu-akademie.de/quiz/patentrecht-deutschland-teil-4/,fuqiang,de
57 | https://www.fsgu-akademie.de/quiz/urheberrecht-teil-1/,fuqiang,de
58 | https://www.fsgu-akademie.de/quiz/urheberrecht-teil-2/,fuqiang,de
59 | https://www.fsgu-akademie.de/quiz/urheberrecht-teil-3/,fuqiang,de
60 | https://www.fsgu-akademie.de/quiz/urheberrecht-teil-4/,fuqiang,de
61 | 
-------------------------------------------------------------------------------- /IPFAQ_Data_sources.csv: --------------------------------------------------------------------------------
1 | source,processor,language
2 | https://www.pinlicheng.com/question/zlwd,fuqiang,zh
3 | https://www.cnipa.gov.cn/jact/front/mailpublist.do?sysid=16,fuqiang,zh
4 | https://zscqj.cq.gov.cn/zwgk_232/zczxwdk/,fuqiang,zh
5 | http://scjg.hangzhou.gov.cn/col/col1229574855/index.html,fuqiang,zh
6 | https://www.marks-clerk.com/zh-hans/%E8%A7%82%E7%82%B9/%E8%A7%81%E8%A7%A3/%E9%97%AE%E7%AD%94-%E6%AC%A7%E6%B4%B2%E5%8D%95%E4%B8%80%E4%B8%93%E5%88%A9%E5%92%8C%E7%BB%9F%E4%B8%80%E4%B8%93%E5%88%A9%E6%B3%95%E9%99%A2/,fuqiang,zh
7 | https://www.texunip.com/article/wenda/,fuqiang,zh
8 | https://www.wipo.int/edocs/pubdocs/zh/wipo_pub_1056.pdf,minghuan,zh
9 | https://kyc.hsu.edu.cn/_upload/article/files/d1/2b/12be920a4b49b59f9913c44094c4/2d70474a-31e5-4fe8-ab09-f8bb8f08ee82.pdf,shiwen,zh
10 | https://kyy.bupt.edu.cn/__local/0/75/C4/7487C790ECD665806021DC36FD2_427D287F_832DB.pdf?e=.pdf,shiwen,zh
11 | http://www.gov.cn/xinwen/2021-04/25/5602104/files/9cfbfa3fed814e1f9d04e56959ed13fb.pdf,shiwen,zh
12 | http://www.ipo.dicp.ac.cn/info/1032/2680.htm,minghuan,zh
13 | https://hb.hainanu.edu.cn/zscqcg/info/1157/1551.htm,minghuan,zh
14 | http://www.evcrrc.com/news/233.html,minghuan,zh
15 | https://eduai.baidu.com/view/3d0ed9b566ce0508763231126edb6f1aff00712b,fuqiang,zh
16 | https://eduai.baidu.com/view/b2d3491e964bcf84b9d57ba0,fuqiang,zh
17 | https://eduai.baidu.com/view/e5f20c61f4ec4afe04a1b0717fd5360cba1a8d35,fuqiang,zh
18 | https://kycyc.yctei.cn/2017/0531/c1427a30002/page.htm,minghuan,zh
19 | https://eduai.baidu.com/view/dcb16222f48a6529647d27284b73f242336c3122,fuqiang,zh
20 | https://mp.weixin.qq.com/s?__biz=MzA3MTExMzIyMg==&mid=2652052792&idx=1&sn=e8e0c45d2cda379ad26f7521cf4c08cc&chksm=84d53251b3a2bb47f73472c250768d3a27586a7bc0b6a98f3fd402a8770d853e69195c7a059f&scene=27,minghuan,zh
21 | 
https://www.government.nl/topics/intellectual-property/question-and-answer,fuqiang,en 22 | https://www.wipo.int/edocs/pubdocs/en/wipo_pub_1056.pdf,minghuan,en 23 | https://www.wipo.int/patents/en/faq_patents.html,fuqiang,en 24 | https://www.researchgate.net/search/question?q=patent,fuqiang,en 25 | https://www.upcounsel.com/patent-questions-and-answers,fuqiang,en 26 | https://www.uspto.gov/help/patent-help#type-browse-faqs_3246,fuqiang,en 27 | https://www.wipo.int/classifications/ipc/en/faq/,minghuan,en 28 | https://www.wipo.int/edocs/pubdocs/en/wipo_pub_450_2020.pdf,shiwen,en 29 | https://www.wipo.int/ip-outreach/en/ipday/2023/faq-world-ip-day.html,minghuan,en 30 | https://www.uspto.gov/patents/basics/essentials#questions,shiwen,en 31 | https://about.lens.org/category/faq/,shiwen,en 32 | https://www.rkdewan.com/faqs.php,minghuan,en 33 | https://www.rkdewan.com/trademarkfaq.php,minghuan,en 34 | https://www.rkdewan.com/designfaq.php,minghuan,en 35 | https://www.rkdewan.com/copyrightfaq.php,minghuan,en 36 | https://www.wipo.int/edocs/pubdocs/es/wipo_pub_1056.pdf,minghuan,es 37 | https://www.sansokan.jp/akinai/faq/detail.san?H_FAQ_CL=0&H_FAQ_NO=1461,fuqiang,jp 38 | https://www.kjpaa.jp/beginner/qaindex,fuqiang,jp 39 | http://www.sakae-pat.jp/faq/index.html,minghuan,jp 40 | https://jp.aigipat.com/faq/andSoOn.html,minghuan,jp 41 | https://www.jfc.go.jp/n/finance/keiei/nouhau1.html,minghuan,jp 42 | https://narapatent.com/faq/,minghuan,jp 43 | https://airimaq.kyushu-u.ac.jp/patent/patent-forms-list/#,minghuan,jp 44 | https://www.ondatechno.com/jp/search_analysis/faq/,minghuan,jp 45 | http://s-c-ip.com/faq%E7%89%B9%E8%A8%B1/,minghuan,jp 46 | https://www.isokanet.com/tokususu/faq.html,minghuan,jp 47 | https://www.primeworks-ip.com/en/faq/,minghuan,jp 48 | https://shizuoka-ipc.gr.jp/faq/,minghuan,jp 49 | https://www.aichi-patent.or.jp/faq/,minghuan,jp 50 | https://ip-ups.com/faq.html,minghuan,jp 51 | https://avenirpat.com/sp1/faqsp.html,minghuan,jp 52 | https://www.widebandip.com/jp/mobile/knowledge2.php?type1=A&idno=71,minghuan,jp 53 | https://www.ideaintellectual.com/ja/faq/,minghuan,jp 54 | https://www.ryuka.com/jp/business/service/faq/,minghuan,jp 55 | https://www.nagata-patent.com/FAQ.html,minghuan,jp 56 | https://jpaa-hokkaido.jp/faq/,minghuan,jp 57 | https://www.omni-pat.com/faq,minghuan,jp 58 | https://www.j-platpat.inpit.go.jp/c0500,minghuan,jp 59 | https://www.dpma.de/service/kmu/geistiges_eigentum/index.html,minghuan,de 60 | https://www.dpma.de/dpma/veroeffentlichungen/jahresberichte/quiz70jahre/index.html,minghuan,de 61 | https://www.hu-berlin.de/en/research/transfer/patente_lizenzen/pl_pat_frag_html,minghuan,de 62 | https://www.dkfz.de/en/techtrans/DKFZ-DKTK/FAQs.html,minghuan,de 63 | https://www.bluepatent.com/de/faq,fuqiang,de 64 | https://www.patoffice.de/haufig-gestellte-fragen,minghuan,de 65 | https://www.maiwald.eu/en/upc/#faq,minghuan,de 66 | https://www.w-hs.de/forschung-und-kooperation/erfindungen-und-patente/faqs/,minghuan,de 67 | https://www.raffay-fleck.de/faq/patent/,minghuan,de 68 | https://www.ige.ch/de/etwas-schuetzen/patente/faq,minghuan,de 69 | https://www.patent-markenzentrum.tu-darmstadt.de/schutzrechte/faq/index.de.jsp,minghuan,de 70 | https://lbp-patent.de/information/faq/,minghuan,de 71 | https://aaa-patent.de/faq/,minghuan,de 72 | https://www.gulde.com/de/informationen/faq,minghuan,de 73 | https://www.cohausz-florack.de/karriere/faq/,minghuan,de 74 | https://patentepi.org/de/faq/,minghuan,de 75 | https://www.patente-erfolgreich-anmelden.de/faqs,minghuan,de 76 | 
https://de.maucherjenkins.com/uber-ip/einheitspatent-faq,minghuan,de 77 | https://www.fsgu-akademie.de/lexikon/geistiges-eigentum/,minghuan,de 78 | https://www.fsgu-akademie.de/lexikon/handelsmarke/,minghuan,de 79 | https://www.fsgu-akademie.de/lexikon/design-schutzrecht/,minghuan,de 80 | https://www.fsgu-akademie.de/lexikon/patentrecht-deutschland/,minghuan,de 81 | https://www.fsgu-akademie.de/lexikon/urheberrecht/,minghuan,de 82 | https://www.ieepi.org/institut/faq/,yuelin,fr 83 | https://www.wto.org/french/tratop_f/trips_f/tripfq_f.htm#WhatAre,yuelin,fr 84 | https://www.wipo.int/publications/en/details.jsp?id=4410,minghuan,fr 85 | https://www.wipo.int/patents/fr/faq_patents.html,yuelin,fr 86 | https://www.wipo.int/edocs/pubdocs/ru/wipo_pub_1056.pdf,minghuan,ru 87 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # 墨子(MoZi): An IP-oriented Multilingual Large Language Model
2 | ![image](https://github.com/AI-for-Science/MoZi/assets/56249874/0d0f8faa-4074-4748-b395-481517aea34e)
3 | 
4 | ## Paper
5 | [MoZIP: A Multilingual Benchmark to Evaluate Large Language Models in Intellectual Property](https://arxiv.org/abs/2402.16389) (Accepted by LREC-COLING 2024)
6 | ```
7 | @inproceedings{ni2024mozip,
8 |   title={MoZIP: A Multilingual Benchmark to Evaluate Large Language Models in Intellectual Property},
9 |   author={Ni, Shiwen and Tan, Minghuan and Bai, Yuelin and Niu, Fuqiang and Yang, Min and Zhang, Bowen and Xu, Ruifeng and Chen, Xiaojun and Li, Chengming and Hu, Xiping and others},
10 |   booktitle={Proceedings of the 2024 Joint International Conference on Computational Linguistics, Language Resources and Evaluation (LREC-COLING 2024)},
11 |   year={2024}
12 | }
13 | ```
14 | 
15 | ## Model
16 | 
17 | (**mozi3-7b**, fine-tuned from qwen2-7b) https://huggingface.co/Nozomi7/mozi3-7b
18 | 
19 | (**mozi2-6b**, fine-tuned from chatglm3-6b) https://huggingface.co/Nozomi7/mozi2-6b
20 | 
21 | (**mozi2-7b**, fine-tuned from baichuan2-7b) https://huggingface.co/Nozomi7/mozi2-7b
22 | 
23 | (**mozi-7b**, fine-tuned from bloomz-mt-7b) https://huggingface.co/BNNT/mozi-7b
24 | 
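A minimal usage sketch for the checkpoints above (it assumes a CUDA GPU and mirrors the `Human:`/`Assistant:` prompt format used by this repo's `inference.py`; the question and generation settings are illustrative). The ChatGLM-based mozi2-6b is instead loaded with `AutoModel.from_pretrained(..., trust_remote_code=True)`, as in `inference.py`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "BNNT/mozi-7b"  # bloomz-mt-7b based checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16).cuda().eval()

# inference.py wraps each instruction as "Human: ...\nAssistant: "
prompt = "Human: What is the difference between a trade name and a trademark?\nAssistant: "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=512, repetition_penalty=1.2)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```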
25 | ## MoZIP benchmark
26 | https://huggingface.co/datasets/BNNT/IPQuiz
27 | 
28 | https://huggingface.co/datasets/BNNT/IPQA
29 | 
30 | https://huggingface.co/datasets/BNNT/PatentMatch (a loading sketch is given at the end of this README)
31 | 
32 | ## SFT data
33 | General SFT: https://huggingface.co/datasets/BNNT/mozi_general_instructions_3m
34 | 
35 | IP-specific SFT: https://huggingface.co/datasets/BNNT/mozi_IP_instructions
36 | 
37 | ## Example
38 | #### Q: Do you know how to quote other people's works without infringing copyright?
39 | 
40 | **chatglm3-6b**: When quoting other people's works, we need to follow the principle of fair use. Observing the following points can, to some extent, avoid copyright infringement:
41 | 1. Obtain authorization: before using someone else's work, try to contact the original author and obtain authorization. Reach the author by email or social media and ask for their consent.
42 | 2. Credit the source: when using someone else's work, state the author's name, the title of the work, and the source. This not only respects the original author's labor but also complies with copyright law.
43 | 3. Transform or adapt: when quoting someone else's work, you may modify or adapt it so that it better fits your needs. This approach can reduce the risk of infringement.
44 | 4. Use quotation: if someone else's work is quoted for purposes such as academic research, commentary, or news reporting, the original author's authorization may not be required.
45 | In summary, to make fair use of others' works you need to follow the principles above and avoid infringing copyright as far as possible.
46 | 
47 | **mozi2-6b**: Quoting other people's works is a common academic or commercial practice, but certain legal and ethical rules must be observed in the process to avoid infringing copyright. Here are some suggestions:
48 | 1. Respect copyright: when quoting someone else's work, always respect the rights and intellectual property of the copyright holder. Do not use another person's work without permission, so as not to infringe their copyright.
49 | 2. Credit the source: when quoting, make sure to cite the source and enclose the quoted passage in quotation marks. This helps show your respect for copyright and reduces the risk of infringement.
50 | 3. Fair use: in some cases, using another person's work may be reasonable, for example for commentary, criticism, news reporting, or academic research. Such use must stay within reasonable bounds and must not harm the copyright holder's interests.
51 | 4. Paid use: if you need to use a work that has been published but you have not obtained the copyright holder's permission or authorization, you may need to consider paying for its use. In that case, contact the copyright holder and negotiate license terms.
52 | 5. Avoid infringement: when quoting, do not present another person's work as your own creation, so as to avoid infringement disputes. If you need to use someone else's work in your own creation, make sure the way you use it complies with copyright laws and regulations.
53 | In short, when quoting other people's works we should respect copyright, credit the source, use works fairly, pay for use where necessary, and avoid infringement. This protects our own legitimate rights and also helps promote intellectual property protection and the development of the market.
54 | 
55 | #### Q: What is the difference between a trade name and a trademark?
56 | 
57 | **chatglm3-6b**: Both trade names and trademarks are signs used to distinguish and identify an enterprise or product, but there are some differences between them.
58 | A trade name is the name an enterprise uses to distinguish its goods or services, usually consisting of characters, letters, numbers, or combinations thereof. A trade name is not necessarily unique; several enterprises may use the same one. It mainly serves as identification, enabling consumers to recognize the source of goods or services.
59 | A trademark is a specific sign, such as words, images, symbols, or designs, used to distinguish and identify an enterprise or product. A trademark is usually unique and may be used by only one enterprise, which helps consumers distinguish and choose products or services in the market. A trademark may also include the enterprise's name, logo, or slogan.
60 | The main difference between them lies in uniqueness. Several enterprises may use the same or similar trade names, whereas a trademark must be unique and used by a single enterprise. Moreover, a trade name mainly serves to identify, while a trademark also serves to distinguish and select products.
61 | 
62 | **mozi2-6b**:
63 | Trade names and trademarks differ in the following six respects:
64 | 1. Function and role: a trade name mainly distinguishes different enterprises, while a trademark mainly distinguishes different goods. Generally speaking, a company or enterprise can have only one trade name, but may register different trademarks for different goods.
65 | 2. Form: a trade name generally consists of two or more Chinese characters, while a trademark may combine numbers, graphics, pinyin, colors, and Latin letters. This also reflects the difference between the two in form of expression.
66 | 3. Scope and effect of use: in China, a trademark right is effective nationwide and is valid for ten years, renewable upon expiry. By contrast, a trade name right is subject to regional limits: the name must first be approved by the State Administration for Industry and Commerce and is then used within a certain region, but there is no limit on how long it may be used; once approved, it can be used indefinitely.
67 | 4. Transfer and licensing: a trade name right generally cannot be transferred on its own, whereas a trademark right can be transferred or licensed to others independently. In addition, a trademark right generally lasts ten years and can be renewed upon expiry, while a trade name right has no such limit.
68 | 5. Content of the rights: a trade name right generally includes prohibiting others from using an identical or similar trade name, while a trademark right includes prohibiting others from using an identical or similar trademark, as well as any sign identical or similar to the registered trademark.
69 | 6. Legal protection: both trade names and trademarks are protected by law, but the scope of protection and the governing statutes may differ. In China, trademarks are protected by the Trademark Law of the People's Republic of China, while trade names are protected by the Company Law of the People's Republic of China and the Provisions on the Administration of Enterprise Name Registration.
70 | Overall, trade names and trademarks differ markedly in function, form of expression, scope of use, transfer and licensing, content of rights, and legal protection. These differences reflect how the two diverge in commercial practice and in legal provisions.
71 | 
72 | 
73 | 
74 | 
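The MoZIP benchmark and SFT datasets above are hosted on the HuggingFace Hub; a minimal loading sketch using the standard `datasets` API (split and config names may need adjusting for each dataset):

```python
from datasets import load_dataset

ipquiz = load_dataset("BNNT/IPQuiz")            # multiple-choice IP questions
ipqa = load_dataset("BNNT/IPQA")                # free-form IP Q&A
patentmatch = load_dataset("BNNT/PatentMatch")  # patent-matching task
print(ipquiz)
```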
-------------------------------------------------------------------------------- /SFT.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Microsoft Corporation.
3 | # SPDX-License-Identifier: Apache-2.0
4 | 
5 | # DeepSpeed Team
6 | import argparse
7 | import os
8 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
9 | from datasets import disable_caching
10 | disable_caching()
11 | import math
12 | import sys
13 | from tqdm import tqdm
14 | import torch
15 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
16 | from torch.utils.data.distributed import DistributedSampler
17 | import transformers
18 | print("transformers.__version__ : ", transformers.__version__)  # tested with 4.29.0.dev0
19 | from transformers import (
20 |     AutoModelForCausalLM,
21 |     AutoTokenizer,
22 |     SchedulerType,
23 |     default_data_collator,
24 |     get_scheduler,
25 |     LlamaTokenizer,
26 | )
27 | 
28 | import deepspeed
29 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
30 | 
31 | sys.path.append(
32 |     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
33 | from utils.data.data_utils import create_prompt_dataset
34 | from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model
35 | from utils.ds_utils import get_train_ds_config
36 | from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters
37 | from utils.model.model_utils import create_hf_model
38 | 
39 | 
40 | def parse_args():
41 |     parser = argparse.ArgumentParser(
42 |         description=
43 |         "Finetune a transformers model on a causal language modeling task")
44 |     parser.add_argument('--data_path',
45 |                         nargs='*',
46 |                         default=[],
47 |                         help='Path to the training dataset. Accepted format: '
48 |                         '1) a single data path, 2) multiple datasets in the '
49 |                         'form: dataset1-path dataset2-path ...')
50 |     parser.add_argument('--data_split',
51 |                         type=str,
52 |                         default='10,0,0',
53 |                         help='Comma-separated list of proportions for training '
54 |                         'phase 1, 2, and 3 data. For example the split `2,4,4` '
55 |                         'will use 20% of data for phase 1, 40% for phase 2 '
56 |                         'and 40% for phase 3.')
57 |     parser.add_argument('--sft_only_data_path', nargs='*', default=[], help='Path to a dataset used only in the SFT phase.')
58 |     parser.add_argument('--eval_data_file', type=str, default=None)
59 | 
60 |     parser.add_argument(
61 |         '--data_output_path',
62 |         type=str,
63 |         default='output/data_files/',
64 |         help=
65 |         'Where to store the data-related files such as the shuffle index. This needs to be on local storage of a node (not on shared storage)'
66 |     )
67 |     parser.add_argument(
68 |         "--model_name_or_path",
69 |         type=str,
70 |         help=
71 |         "Path to pretrained model or model identifier from huggingface.co/models.",
72 |         required=True,
73 |     )
74 |     parser.add_argument(
75 |         "--per_device_train_batch_size",
76 |         type=int,
77 |         default=16,
78 |         help="Batch size (per device) for the training dataloader.",
79 |     )
80 |     parser.add_argument(
81 |         "--per_device_eval_batch_size",
82 |         type=int,
83 |         default=16,
84 |         help="Batch size (per device) for the evaluation dataloader.",
85 |     )
86 |     parser.add_argument(
87 |         "--max_seq_len",
88 |         type=int,
89 |         default=512,
90 |         help="The maximum sequence length.",
91 |     )
92 |     parser.add_argument(
93 |         "--learning_rate",
94 |         type=float,
95 |         default=1e-3,
96 |         help=
97 |         "Initial learning rate (after the potential warmup period) to use.",
98 |     )
99 |     parser.add_argument("--weight_decay",
100 |                         type=float,
101 |                         default=0.1,
102 |                         help="Weight decay to use.")
103 |     parser.add_argument("--num_train_epochs",
104 |                         type=int,
105 |                         default=1,
106 |                         help="Total number of training epochs to perform.")
107 |     parser.add_argument(
108 |         "--gradient_accumulation_steps",
109 |         type=int,
110 |         default=1,
111 |         help=
112 |         "Number of updates steps to accumulate before performing a backward/update pass.",
113 |     )
114 |     parser.add_argument(
115 |         "--lr_scheduler_type",
116 |         type=SchedulerType,
117 |         default="cosine",
118 |         help="The scheduler type to use.",
119 |         choices=[
120 |             "linear", "cosine", "cosine_with_restarts", "polynomial",
121 |             "constant", "constant_with_warmup"
122 |         ],
123 |     )
124 |     parser.add_argument(
125 |         "--num_warmup_steps",
126 |         type=int,
127 |         default=0,
128 |         help="Number of steps for the warmup in the lr scheduler.")
129 |     parser.add_argument("--output_dir",
130 |                         type=str,
131 |                         default=None,
132 |                         help="Where to store the model.")
133 |     parser.add_argument("--seed",
134 |                         type=int,
135 |                         default=1234,
136 |                         help="A seed for reproducible training.")
137 |     parser.add_argument("--local_rank",
138 |                         type=int,
139 |                         default=-1,
140 |                         help="local_rank for distributed training on GPUs")
141 |     parser.add_argument('--gradient_checkpointing',
142 |                         action='store_true',
143 |                         help='Enable HF gradient checkpointing for model.')
144 |     parser.add_argument('--save_steps',
145 |                         type=int,
146 |                         default=10,
147 |                         help='Save a checkpoint every this many steps.')
148 |     parser.add_argument('--evaluation_steps',
149 |                         type=int,
150 |                         default=10,
151 |                         help='Run evaluation every this many steps.')
152 |     # deepspeed features
153 |     parser.add_argument('--offload',
154 |                         action='store_true',
155 |                         help='Enable ZeRO Offload techniques.')
156 |     parser.add_argument(
157 |         '--zero_stage',
158 |         type=int,
159 |         default=0,
160 |         help='ZeRO optimization stage for the model.')
161 |     ## LoRA for efficient training setting
162 |     parser.add_argument("--lora_dim",
163 |                         type=int,
164 |                         default=0,
165 |                         help="If > 0, use LoRA for efficient training.")
166 |     parser.add_argument("--lora_alpha",
167 |                         type=int,
168 |                         default=0,
169 |                         help="LoRA alpha.")
170 |     parser.add_argument("--lora_droppout",
171 |                         type=float,
172 |                         default=0.,
173 |                         help="LoRA dropout probability.")
174 |     parser.add_argument("--lora_module_name",
175 |                         type=str,
176 |                         default="decoder.layers.",
177 |                         help="The scope of LoRA.")
178 |     parser.add_argument('--only_optimize_lora',
179 |                         action='store_true',
180 |                         help='Only optimize the LoRA parameters.')
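    # Note on the LoRA flags above: convert_linear_layer_to_lora (called in main())
    # swaps each Linear layer whose name matches --lora_module_name for a rank
    # --lora_dim adapter, and the adapter update is scaled by lora_alpha / lora_dim
    # (the "lora_scaling" value printed at startup).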
181 |     parser.add_argument("--show_loss_step", default=100, type=int, help="Log the training loss every this many steps.")
182 |     parser.add_argument("--max_new_tokens", default=1024, type=int, help="Max number of output tokens")
183 | 
184 |     parser = deepspeed.add_config_arguments(parser)
185 |     args = parser.parse_args()
186 | 
187 |     # Validate settings
188 |     if args.gradient_checkpointing and args.lora_dim > 0:
189 |         assert (
190 |             not args.only_optimize_lora
191 |         ), "--gradient_checkpointing and --only_optimize_lora cannot be enabled at the same time."
192 | 
193 |     return args
194 | 
195 | 
196 | def main():
197 |     args = parse_args()
198 | 
199 |     if args.local_rank == -1:
200 |         device = torch.device("cuda")
201 |     else:
202 |         torch.cuda.set_device(args.local_rank)
203 |         device = torch.device("cuda", args.local_rank)
204 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
205 |         # torch.distributed.init_process_group(backend='nccl')
206 | 
207 |     deepspeed.init_distributed()
208 | 
209 |     args.global_rank = torch.distributed.get_rank()
210 | 
211 |     ds_config = get_train_ds_config(offload=args.offload,
212 |                                     stage=args.zero_stage)
213 |     ds_config[
214 |         'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
215 |     ds_config[
216 |         'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size(
217 |         ) * args.gradient_accumulation_steps
218 | 
219 |     # If passed along, set the training seed now.
220 |     set_random_seed(args.seed)
221 | 
222 |     assert not args.offload, "zero-offload is not currently supported but coming soon!"
223 | 
224 |     torch.distributed.barrier()
225 | 
226 |     print("model_name_or_path : ", args.model_name_or_path)
227 |     if "llama" in args.model_name_or_path.lower():
228 |         tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)  # AutoTokenizer may raise "RecursionError: maximum recursion depth exceeded" for LLaMA
229 |         tokenizer.pad_token_id = 0  # the original LLaMA tokenizer has no pad token
230 |         # assert tokenizer.bos_token_id == 1 and tokenizer.eos_token_id == 2, (tokenizer.bos_token_id, tokenizer.eos_token_id)
231 |         tokenizer.bos_token_id = 1
232 |         tokenizer.eos_token_id = 2
233 |         # LlamaTokenizer behavior differs across transformers versions
234 |     else:
235 |         tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
236 | 
237 |     tokenizer.pad_token_id = 0  # for Bloom we also set tokenizer.pad_token_id to zero
238 |     tokenizer.padding_side = "left"
239 |     print("Setting tokenizer padding side to left")
240 |     print("tokenizer.bos_token_id: ", tokenizer.bos_token_id)
241 |     print("tokenizer.eos_token_id: ", tokenizer.eos_token_id)
242 | 
243 |     model = create_hf_model(AutoModelForCausalLM, args.model_name_or_path,
244 |                             tokenizer, ds_config)
245 | 
246 |     if args.lora_dim > 0:
247 |         lora_module_name = args.lora_module_name.split(",")
248 |         print("lora_module_name: ", lora_module_name)
249 |         print("lora_dim: {}, lora_alpha: {}, lora_scaling: {}, lora_dropout: {}".format(args.lora_dim, args.lora_alpha, args.lora_alpha/args.lora_dim, args.lora_droppout))
250 | 
251 |         model = convert_linear_layer_to_lora(model, lora_module_name=lora_module_name, lora_dim=args.lora_dim, lora_alpha=args.lora_alpha, lora_droppout=args.lora_droppout)
252 | 
253 |     if args.only_optimize_lora:
254 |         model = only_optimize_lora_parameters(model)
255 | 
256 |     # Prepare the data
257 |     train_phase = 1
258 |     print("sft_only_data_path : ", args.sft_only_data_path)
259 |     train_dataset, eval_dataset = create_prompt_dataset(
260 |         local_rank = args.local_rank,
261 |         sft_only_data_path = args.sft_only_data_path,
262 | 
eval_data_file = args.eval_data_file, 263 | data_split = args.data_split, 264 | output_path = args.data_output_path, 265 | train_phase = train_phase, 266 | seed = args.seed, 267 | tokenizer = tokenizer, 268 | max_seq_len = args.max_seq_len 269 | ) 270 | 271 | # DataLoaders creation: 272 | if args.local_rank == -1: 273 | train_sampler = RandomSampler(train_dataset) 274 | eval_sampler = SequentialSampler(eval_dataset) 275 | else: 276 | train_sampler = DistributedSampler(train_dataset) 277 | eval_sampler = DistributedSampler(eval_dataset) 278 | train_dataloader = DataLoader(train_dataset, 279 | collate_fn=default_data_collator, 280 | sampler=train_sampler, 281 | batch_size=args.per_device_train_batch_size) 282 | print("len(train_dataloader) = ", len(train_dataloader)) 283 | print("len(train_dataset) = ", len(train_dataset)) 284 | print("args.per_device_train_batch_size = ", args.per_device_train_batch_size) 285 | 286 | eval_dataloader = DataLoader(eval_dataset, 287 | collate_fn=default_data_collator, 288 | sampler=eval_sampler, 289 | batch_size=args.per_device_eval_batch_size) 290 | print("len(eval_dataloader) = ", len(eval_dataloader)) 291 | print("len(eval_dataset) = ", len(eval_dataset)) 292 | print("args.per_device_eval_batch_size = ", args.per_device_eval_batch_size) 293 | 294 | 295 | def evaluation(model, eval_dataloader): 296 | model.eval() 297 | losses = 0 298 | # output_texts = [] 299 | for step, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), unit="batch"): 300 | batch = to_device(batch, device) 301 | with torch.no_grad(): 302 | outputs = model(**batch) 303 | 304 | loss = outputs.loss 305 | losses += loss.float() 306 | 307 | losses = losses / (step + 1) 308 | model.train() 309 | try: 310 | perplexity = torch.exp(losses) 311 | except OverflowError: 312 | perplexity = float("inf") 313 | try: 314 | perplexity = get_all_reduce_mean(perplexity).item() 315 | except: 316 | pass 317 | # with open("./predictions.txt", "w") as f: 318 | # for pred_text in output_texts: 319 | # f.write(pred_text+"\n") 320 | 321 | return perplexity 322 | 323 | # Split weights in two groups, one with weight decay and the other not. 324 | optimizer_grouped_parameters = get_optimizer_grouped_parameters( 325 | model, args.weight_decay) 326 | 327 | AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam 328 | optimizer = AdamOptimizer(optimizer_grouped_parameters, 329 | lr=args.learning_rate, 330 | betas=(0.9, 0.95)) 331 | 332 | num_update_steps_per_epoch = math.ceil( 333 | len(train_dataloader) / args.gradient_accumulation_steps) 334 | lr_scheduler = get_scheduler( 335 | name=args.lr_scheduler_type, 336 | optimizer=optimizer, 337 | num_warmup_steps=args.num_warmup_steps, 338 | num_training_steps=args.num_train_epochs * num_update_steps_per_epoch, 339 | ) 340 | 341 | model, optimizer, _, lr_scheduler = deepspeed.initialize( 342 | model=model, 343 | optimizer=optimizer, 344 | args=args, 345 | config=ds_config, 346 | lr_scheduler=lr_scheduler, 347 | dist_init_required=True) 348 | 349 | if args.gradient_checkpointing: 350 | model.gradient_checkpointing_enable() 351 | 352 | # Train! 
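    # Note: after deepspeed.initialize() above, `model` is a DeepSpeed engine, so the
    # loop below calls model.backward(loss) and model.step() instead of the usual
    # loss.backward() / optimizer.step() / lr_scheduler.step(); gradient accumulation
    # and ZeRO partitioning are handled inside the engine according to ds_config.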
353 |     print_rank_0("***** Running training *****", args.global_rank)
354 |     print_rank_0(
355 |         f"***** Evaluating perplexity, Epoch {0}/{args.num_train_epochs} *****",
356 |         args.global_rank)
357 | 
358 |     training_step_losses = []  # currently unused
359 | 
360 |     for epoch in range(args.num_train_epochs):
361 |         print_rank_0(
362 |             f"Beginning of Epoch {epoch + 1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
363 |             args.global_rank)
364 |         model.train()
365 |         for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit="batch"):
366 |             batch = to_device(batch, device)
367 |             outputs = model(**batch, use_cache=False)
368 |             loss = outputs.loss
369 |             model.backward(loss)
370 |             model.step()
371 |             if (step + 1) % args.show_loss_step == 0: print_rank_0("Epoch: {}, step: {}, loss: {}".format(epoch + 1, step + 1, loss.item()), args.global_rank)  # log every --show_loss_step steps
372 |             if (step + 1) % args.save_steps == 0:
373 |                 save_path = os.path.join(args.output_dir, f"Epoch-{epoch + 1}-step-{step + 1}")
374 | 
375 |                 save_zero_three_model(model,
376 |                                       args.global_rank,
377 |                                       save_path,
378 |                                       zero_stage=args.zero_stage)
379 | 
380 |                 print_rank_0(
381 |                     f"Saving checkpoint... Steps {step + 1} Epoch {epoch + 1}/{args.num_train_epochs}",
382 |                     args.global_rank)
383 | 
384 |             if (step + 1) % args.evaluation_steps == 0:
385 |                 perplexity = evaluation(model, eval_dataloader)
386 | 
387 |                 # wandb.log({"perplexity": perplexity}, step=step)
388 | 
389 |                 print_rank_0(f"ppl: {perplexity}", args.global_rank)
390 |                 print_rank_0(
391 |                     f"***** Evaluating perplexity, Steps {step + 1}, Epoch {epoch + 1}/{args.num_train_epochs} *****",
392 |                     args.global_rank)
393 | 
394 |         # Evaluate perplexity on the validation set.
395 |         #perplexity = evaluation(model, eval_dataloader)
396 |         #print_rank_0(f"ppl: {perplexity}", args.global_rank)
397 |         #print_rank_0(
398 |         #    f"***** Evaluating perplexity, Steps {step + 1}, Epoch {epoch + 1}/{args.num_train_epochs} *****",
399 |         #    args.global_rank)
400 | 
401 |         # Save after each epoch.
402 | 403 | model.tput_timer.update_epoch_count() 404 | 405 | if args.output_dir is not None: 406 | print_rank_0('saving the final model ...', args.global_rank)#It will overwrite the last epoch model 407 | model = convert_lora_to_linear_layer(model) 408 | 409 | if args.global_rank == 0: 410 | save_hf_format(model, tokenizer, args) 411 | 412 | if args.zero_stage == 3: 413 | # For zero stage 3, each gpu only has a part of the model, so we need a special save function 414 | save_zero_three_model(model, 415 | args.global_rank, 416 | args.output_dir, 417 | zero_stage=args.zero_stage) 418 | 419 | 420 | if __name__ == "__main__": 421 | main() 422 | -------------------------------------------------------------------------------- /crawlers/wipo/scrapy.cfg: -------------------------------------------------------------------------------- 1 | # Automatically created by: scrapy startproject 2 | # 3 | # For more information about the [deploy] section see: 4 | # https://scrapyd.readthedocs.io/en/latest/deploy.html 5 | 6 | [settings] 7 | default = wipo.settings 8 | 9 | [deploy] 10 | #url = http://localhost:6800/ 11 | project = wipo 12 | -------------------------------------------------------------------------------- /crawlers/wipo/wipo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-for-Science/MoZi/517cb2ed7b7cab0d23803c964ae671f6df405e8b/crawlers/wipo/wipo/__init__.py -------------------------------------------------------------------------------- /crawlers/wipo/wipo/items.py: -------------------------------------------------------------------------------- 1 | # Define here the models for your scraped items 2 | # 3 | # See documentation in: 4 | # https://docs.scrapy.org/en/latest/topics/items.html 5 | 6 | import scrapy 7 | 8 | 9 | class WipoItem(scrapy.Item): 10 | # define the fields for your item here like: 11 | # name = scrapy.Field() 12 | rule = scrapy.Field() 13 | link_text = scrapy.Field() 14 | depth = scrapy.Field() 15 | retry_times = scrapy.Field() 16 | download_timeout = scrapy.Field() 17 | download_slot = scrapy.Field() 18 | download_latency = scrapy.Field() 19 | 20 | -------------------------------------------------------------------------------- /crawlers/wipo/wipo/middlewares.py: -------------------------------------------------------------------------------- 1 | # Define here the models for your spider middleware 2 | # 3 | # See documentation in: 4 | # https://docs.scrapy.org/en/latest/topics/spider-middleware.html 5 | 6 | from scrapy import signals 7 | 8 | # useful for handling different item types with a single interface 9 | from itemadapter import is_item, ItemAdapter 10 | 11 | 12 | class WipoSpiderMiddleware: 13 | # Not all methods need to be defined. If a method is not defined, 14 | # scrapy acts as if the spider middleware does not modify the 15 | # passed objects. 16 | 17 | @classmethod 18 | def from_crawler(cls, crawler): 19 | # This method is used by Scrapy to create your spiders. 20 | s = cls() 21 | crawler.signals.connect(s.spider_opened, signal=signals.spider_opened) 22 | return s 23 | 24 | def process_spider_input(self, response, spider): 25 | # Called for each response that goes through the spider 26 | # middleware and into the spider. 27 | 28 | # Should return None or raise an exception. 29 | return None 30 | 31 | def process_spider_output(self, response, result, spider): 32 | # Called with the results returned from the Spider, after 33 | # it has processed the response. 
34 | 35 | # Must return an iterable of Request, or item objects. 36 | for i in result: 37 | yield i 38 | 39 | def process_spider_exception(self, response, exception, spider): 40 | # Called when a spider or process_spider_input() method 41 | # (from other spider middleware) raises an exception. 42 | 43 | # Should return either None or an iterable of Request or item objects. 44 | pass 45 | 46 | def process_start_requests(self, start_requests, spider): 47 | # Called with the start requests of the spider, and works 48 | # similarly to the process_spider_output() method, except 49 | # that it doesn’t have a response associated. 50 | 51 | # Must return only requests (not items). 52 | for r in start_requests: 53 | yield r 54 | 55 | def spider_opened(self, spider): 56 | spider.logger.info("Spider opened: %s" % spider.name) 57 | 58 | 59 | class WipoDownloaderMiddleware: 60 | # Not all methods need to be defined. If a method is not defined, 61 | # scrapy acts as if the downloader middleware does not modify the 62 | # passed objects. 63 | 64 | @classmethod 65 | def from_crawler(cls, crawler): 66 | # This method is used by Scrapy to create your spiders. 67 | s = cls() 68 | crawler.signals.connect(s.spider_opened, signal=signals.spider_opened) 69 | return s 70 | 71 | def process_request(self, request, spider): 72 | # Called for each request that goes through the downloader 73 | # middleware. 74 | 75 | # Must either: 76 | # - return None: continue processing this request 77 | # - or return a Response object 78 | # - or return a Request object 79 | # - or raise IgnoreRequest: process_exception() methods of 80 | # installed downloader middleware will be called 81 | return None 82 | 83 | def process_response(self, request, response, spider): 84 | # Called with the response returned from the downloader. 85 | 86 | # Must either; 87 | # - return a Response object 88 | # - return a Request object 89 | # - or raise IgnoreRequest 90 | return response 91 | 92 | def process_exception(self, request, exception, spider): 93 | # Called when a download handler or a process_request() 94 | # (from other downloader middleware) raises an exception. 95 | 96 | # Must either: 97 | # - return None: continue processing this exception 98 | # - return a Response object: stops process_exception() chain 99 | # - return a Request object: stops process_exception() chain 100 | pass 101 | 102 | def spider_opened(self, spider): 103 | spider.logger.info("Spider opened: %s" % spider.name) 104 | -------------------------------------------------------------------------------- /crawlers/wipo/wipo/pipelines.py: -------------------------------------------------------------------------------- 1 | # Define your item pipelines here 2 | # 3 | # Don't forget to add your pipeline to the ITEM_PIPELINES setting 4 | # See: https://docs.scrapy.org/en/latest/topics/item-pipeline.html 5 | 6 | 7 | # useful for handling different item types with a single interface 8 | from itemadapter import ItemAdapter 9 | 10 | 11 | class WipoPipeline: 12 | def process_item(self, item, spider): 13 | return item 14 | -------------------------------------------------------------------------------- /crawlers/wipo/wipo/settings.py: -------------------------------------------------------------------------------- 1 | # Scrapy settings for wipo project 2 | # 3 | # For simplicity, this file contains only settings considered important or 4 | # commonly used. 
You can find more settings consulting the documentation: 5 | # 6 | # https://docs.scrapy.org/en/latest/topics/settings.html 7 | # https://docs.scrapy.org/en/latest/topics/downloader-middleware.html 8 | # https://docs.scrapy.org/en/latest/topics/spider-middleware.html 9 | 10 | BOT_NAME = "wipo" 11 | 12 | SPIDER_MODULES = ["wipo.spiders"] 13 | NEWSPIDER_MODULE = "wipo.spiders" 14 | 15 | 16 | # Crawl responsibly by identifying yourself (and your website) on the user-agent 17 | #USER_AGENT = "wipo (+http://www.yourdomain.com)" 18 | 19 | # Obey robots.txt rules 20 | ROBOTSTXT_OBEY = True 21 | 22 | # Configure maximum concurrent requests performed by Scrapy (default: 16) 23 | #CONCURRENT_REQUESTS = 32 24 | 25 | # Configure a delay for requests for the same website (default: 0) 26 | # See https://docs.scrapy.org/en/latest/topics/settings.html#download-delay 27 | # See also autothrottle settings and docs 28 | DOWNLOAD_DELAY = 3 29 | # The download delay setting will honor only one of: 30 | #CONCURRENT_REQUESTS_PER_DOMAIN = 16 31 | CONCURRENT_REQUESTS_PER_IP = 16 32 | 33 | # Disable cookies (enabled by default) 34 | #COOKIES_ENABLED = False 35 | 36 | # Disable Telnet Console (enabled by default) 37 | #TELNETCONSOLE_ENABLED = False 38 | 39 | # Override the default request headers: 40 | #DEFAULT_REQUEST_HEADERS = { 41 | # "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", 42 | # "Accept-Language": "en", 43 | #} 44 | 45 | # Enable or disable spider middlewares 46 | # See https://docs.scrapy.org/en/latest/topics/spider-middleware.html 47 | #SPIDER_MIDDLEWARES = { 48 | # "wipo.middlewares.WipoSpiderMiddleware": 543, 49 | #} 50 | 51 | # Enable or disable downloader middlewares 52 | # See https://docs.scrapy.org/en/latest/topics/downloader-middleware.html 53 | #DOWNLOADER_MIDDLEWARES = { 54 | # "wipo.middlewares.WipoDownloaderMiddleware": 543, 55 | #} 56 | 57 | # Enable or disable extensions 58 | # See https://docs.scrapy.org/en/latest/topics/extensions.html 59 | #EXTENSIONS = { 60 | # "scrapy.extensions.telnet.TelnetConsole": None, 61 | #} 62 | 63 | # Configure item pipelines 64 | # See https://docs.scrapy.org/en/latest/topics/item-pipeline.html 65 | #ITEM_PIPELINES = { 66 | # "wipo.pipelines.WipoPipeline": 300, 67 | #} 68 | 69 | # Enable and configure the AutoThrottle extension (disabled by default) 70 | # See https://docs.scrapy.org/en/latest/topics/autothrottle.html 71 | #AUTOTHROTTLE_ENABLED = True 72 | # The initial download delay 73 | #AUTOTHROTTLE_START_DELAY = 5 74 | # The maximum download delay to be set in case of high latencies 75 | #AUTOTHROTTLE_MAX_DELAY = 60 76 | # The average number of requests Scrapy should be sending in parallel to 77 | # each remote server 78 | #AUTOTHROTTLE_TARGET_CONCURRENCY = 1.0 79 | # Enable showing throttling stats for every response received: 80 | #AUTOTHROTTLE_DEBUG = False 81 | 82 | # Enable and configure HTTP caching (disabled by default) 83 | # See https://docs.scrapy.org/en/latest/topics/downloader-middleware.html#httpcache-middleware-settings 84 | #HTTPCACHE_ENABLED = True 85 | #HTTPCACHE_EXPIRATION_SECS = 0 86 | #HTTPCACHE_DIR = "httpcache" 87 | #HTTPCACHE_IGNORE_HTTP_CODES = [] 88 | #HTTPCACHE_STORAGE = "scrapy.extensions.httpcache.FilesystemCacheStorage" 89 | 90 | # Set settings whose default value is deprecated to a future-proof value 91 | REQUEST_FINGERPRINTER_IMPLEMENTATION = "2.7" 92 | TWISTED_REACTOR = "twisted.internet.asyncioreactor.AsyncioSelectorReactor" 93 | FEED_EXPORT_ENCODING = "utf-8" 94 | 
-------------------------------------------------------------------------------- /crawlers/wipo/wipo/spiders/__init__.py: --------------------------------------------------------------------------------
1 | # This package will contain the spiders of your Scrapy project
2 | #
3 | # Please refer to the documentation for information on how to create and manage
4 | # your spiders.
5 | 
-------------------------------------------------------------------------------- /crawlers/wipo/wipo/spiders/wipo_int.py: --------------------------------------------------------------------------------
1 | import os
2 | import scrapy
3 | from wipo.items import WipoItem
4 | from scrapy.spiders import CrawlSpider, Rule
5 | from scrapy.linkextractors import LinkExtractor
6 | 
7 | DATA_DIR = 'data'
8 | 
9 | class WipoIntSpider(CrawlSpider):
10 |     name = "wipo_int"
11 |     allowed_domains = ["www.wipo.int"]
12 | 
13 |     main_site = "https://www.wipo.int/treaties"
14 |     start_urls = [
15 |         main_site,
16 |     ]
17 |     rules = (
18 |         # Extract links matching 'category.php' (but not matching 'subsection.php')
19 |         # and follow links from them (since no callback means follow=True by default).
20 |         # Rule(LinkExtractor(allow=('category\.php',), deny=('subsection\.php',))),
21 | 
22 |         # # Extract links matching 'item.php' and parse them with the spider's method parse_item
23 |         # Rule(LinkExtractor(allow=('chinese\/*\/*\.htm',)), callback='parse_item', follow=True),
24 |         Rule(LinkExtractor(allow=('/treaties/.*',)), callback='parse_item', follow=True),
25 |         Rule(LinkExtractor(allow=('\.pdf',)), callback='save_pdf', follow=True),
26 |         # Rule(LinkExtractor(allow=('\d+\.index\.htm',)), callback='parse_item'),
27 |     )
28 | 
29 |     def save_pdf(self, response):
30 |         # path = response.url.split('/')[-1]
31 |         self.logger.info('Hi, this is a PDF page! %s', response.url)
32 |         pdf_file = os.path.join(DATA_DIR, response.url.replace(self.main_site, ''))
33 |         self.logger.info('Saving PDF %s', pdf_file)
34 |         os.makedirs(os.path.dirname(pdf_file), exist_ok=True)
35 |         with open(pdf_file, 'wb') as f:
36 |             f.write(response.body)
37 | 
38 |     def parse_item(self, response):
%s', response.url) 40 | html_file = DATA_DIR + response.url.replace(self.main_site, '') 41 | if not html_file.endswith('.html'): 42 | html_file = html_file + ".html" 43 | 44 | self.logger.info('Saving to dir %s', DATA_DIR) 45 | self.logger.info('Saving to html %s', html_file) 46 | content = response.body.decode() 47 | # try: 48 | # content = response.body.decode("gb2312").replace('gb2312', 'utf8') 49 | # except UnicodeDecodeError: 50 | # self.logger.info('Unicode error for {}'.format(response.url)) 51 | # content = response.body.decode('gb18030').replace('gb2312', 'utf8') 52 | 53 | os.makedirs(os.path.dirname(html_file), exist_ok=True) 54 | with open(html_file, 'w', encoding='utf8') as f: 55 | f.write(content) 56 | item = WipoItem() 57 | for k in response.meta: 58 | item[k] = response.meta[k] 59 | # item['id'] = response.xpath('//td[@id="item_id"]/text()').re(r'ID: (\d+)') 60 | # item['name'] = response.xpath('//td[@id="item_name"]/text()').get() 61 | # item['description'] = response.xpath('//td[@id="item_description"]/text()').get() 62 | # item['link_text'] = response.meta['link_text'] 63 | 64 | return item 65 | 66 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, AutoConfig, AutoModel 4 | import argparse 5 | from tqdm import tqdm 6 | import os 7 | import sys 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--model_name_or_path', required=True, type=str) 13 | parser.add_argument('--foundation_model', required=True, type=str)  # required, so a default would never be used 14 | parser.add_argument('--test_file', required=True, type=str) 15 | parser.add_argument('--predictions_file', default='./predictions.json', type=str) 16 | args = parser.parse_args() 17 | 18 | print("test_file: " + args.test_file) 19 | print("model_name_or_path: " + args.model_name_or_path) 20 | print("foundation_model: " + args.foundation_model) 21 | 22 | max_new_tokens = 2048 23 | generation_config = dict( 24 | temperature=0.001, 25 | top_k=30, 26 | top_p=0.85, 27 | do_sample=True, 28 | num_beams=1, 29 | repetition_penalty=1.2, 30 | max_new_tokens=max_new_tokens 31 | ) 32 | 33 | dev_batch_size = 8 34 | 35 | def read_data(filename): 36 | res = [] 37 | with open(filename, 'r', encoding='utf-8') as f: 38 | lines = f.readlines() 39 | for line in lines: 40 | res.append(json.loads(line.strip())) 41 | return res 42 | 43 | 44 | input_items = read_data(args.test_file)  # input examples 45 | output_items = [] 46 | 47 | def write_data(filename, examples): 48 | with open(filename, 'w', encoding='utf-8') as f: 49 | for example in examples: 50 | f.write(json.dumps(example, ensure_ascii=False) + "\n") 51 | 52 | print("predictions will be written to {}".format(args.predictions_file)) 53 | 54 | 55 | if __name__ == '__main__': 56 | load_type = torch.float16 57 | if torch.cuda.is_available(): 58 | device = torch.device(0) 59 | else: 60 | device = torch.device('cpu') 61 | 62 | 63 | 64 | if "llama" in args.foundation_model: 65 | tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path) 66 | tokenizer.pad_token_id = 0 67 | tokenizer.padding_side = "left" 68 | model_config = AutoConfig.from_pretrained(args.model_name_or_path) 69 | model = AutoModelForCausalLM.from_pretrained( 70 | args.model_name_or_path, 71 | torch_dtype=load_type, 72 | config=model_config, 73 | 
ignore_mismatched_sizes=True 74 | ) 75 | model.to(device) 76 | elif "bloom" in args.foundation_model: 77 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 78 | tokenizer.pad_token_id = 0 79 | tokenizer.padding_side = "left" 80 | model_config = AutoConfig.from_pretrained(args.model_name_or_path) 81 | model = AutoModelForCausalLM.from_pretrained( 82 | args.model_name_or_path, 83 | torch_dtype=load_type, 84 | config=model_config, 85 | ignore_mismatched_sizes=True 86 | ) 87 | model.to(device) 88 | elif "chatglm" in args.foundation_model: 89 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) 90 | model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True).half().cuda() 91 | 92 | 93 | model.eval() 94 | print("Load model successfully") 95 | 96 | index = 0 97 | 98 | for i in tqdm(range(0, len(input_items), dev_batch_size), total=len(input_items)//dev_batch_size, unit="item"): 99 | batch_input_items = input_items[i:i+dev_batch_size] 100 | batch_input_text = ["Human: "+input_item['instruction']+"\nAssistant: " for input_item in batch_input_items] 101 | # batch_input_text = ["Please give the correct option for the following question "+input_item['instruction'] for input_item in batch_input_items] 102 | batch_inputs = tokenizer(batch_input_text, max_length=max_new_tokens, padding=True, truncation=True,return_tensors="pt") #add_special_tokens=False ? 103 | batch_generation_output = model.generate( 104 | input_ids = batch_inputs["input_ids"].to(device), 105 | attention_mask = batch_inputs['attention_mask'].to(device), 106 | eos_token_id=tokenizer.eos_token_id, 107 | pad_token_id=tokenizer.pad_token_id, 108 | **generation_config 109 | ) 110 | 111 | batch_generate_text = tokenizer.batch_decode(batch_generation_output,skip_special_tokens=True) 112 | 113 | for generate_text, input_item in zip(batch_generate_text, batch_input_items): 114 | output_items.append({"generate_text":generate_text}) 115 | if index%1 == 0: 116 | print(generate_text) 117 | index += 1 118 | 119 | write_data(args.predictions_file, output_items) -------------------------------------------------------------------------------- /patent_pretrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Microsoft Corporation. 
3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | import argparse 7 | import json 8 | import os 9 | 10 | import deepspeed 11 | import jsonlines 12 | import math 13 | import torch 14 | import wandb 15 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam 16 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 17 | from torch.utils.data.distributed import DistributedSampler 18 | from tqdm import tqdm 19 | from transformers import ( 20 | AutoModelForCausalLM, 21 | AutoTokenizer, 22 | SchedulerType, 23 | default_data_collator, 24 | get_scheduler, 25 | LlamaTokenizer 26 | ) 27 | 28 | from bnnt.data.data_utils import get_shuffle_idx 29 | from bnnt.data.raw_datasets import BNNTDataset 30 | from bnnt.ds_utils import get_train_ds_config 31 | from bnnt.model.model_utils import create_hf_model 32 | from bnnt.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters 33 | from bnnt.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, \ 34 | get_optimizer_grouped_parameters, save_zero_three_model 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser( 39 | description= 40 | "Finetune a transformers model on a causal language modeling task") 41 | parser.add_argument('--data_path', 42 | nargs='*', 43 | default=['Dahoas/rm-static'], 44 | help='Path to the training dataset. Accepted format: ' 45 | '1) a single data path, 2) multiple datasets in the ' 46 | 'form: dataset1-path dataset2-path ...') 47 | parser.add_argument('--data_split', 48 | type=str, 49 | default='6,2,2', 50 | help='Comma-separated list of proportions for training ' 51 | 'phase 1, 2, and 3 data. For example the split `6,2,2` ' 52 | 'will use 60%% of data for phase 1, 20%% for phase 2 ' 53 | 'and 20%% for phase 3.') 54 | parser.add_argument( 55 | '--sft_only_data_path', 56 | nargs='*', 57 | default=[], 58 | help='Path to the dataset used only in the SFT phase.') 59 | parser.add_argument( 60 | '--data_output_path', 61 | type=str, 62 | default='/tmp/data_files/', 63 | help= 64 | 'Where to store the data-related files such as shuffle index. 
This needs to be on a local storage of a node (not on a shared storage)' 65 | ) 66 | parser.add_argument( 67 | "--model_name_or_path", 68 | type=str, 69 | help= 70 | "Path to pretrained model or model identifier from huggingface.co/models.", 71 | required=True, 72 | ) 73 | parser.add_argument( 74 | "--per_device_train_batch_size", 75 | type=int, 76 | default=16, 77 | help="Batch size (per device) for the training dataloader.", 78 | ) 79 | parser.add_argument( 80 | "--per_device_eval_batch_size", 81 | type=int, 82 | default=16, 83 | help="Batch size (per device) for the evaluation dataloader.", 84 | ) 85 | parser.add_argument( 86 | "--max_seq_len", 87 | type=int, 88 | default=512, 89 | help="The maximum sequence length.", 90 | ) 91 | parser.add_argument( 92 | "--learning_rate", 93 | type=float, 94 | default=1e-3, 95 | help= 96 | "Initial learning rate (after the potential warmup period) to use.", 97 | ) 98 | parser.add_argument("--weight_decay", 99 | type=float, 100 | default=0.1, 101 | help="Weight decay to use.") 102 | parser.add_argument("--num_train_epochs", 103 | type=int, 104 | default=1, 105 | help="Total number of training epochs to perform.") 106 | parser.add_argument( 107 | "--gradient_accumulation_steps", 108 | type=int, 109 | default=1, 110 | help= 111 | "Number of update steps to accumulate before performing a backward/update pass.", 112 | ) 113 | parser.add_argument( 114 | "--lr_scheduler_type", 115 | type=SchedulerType, 116 | default="cosine", 117 | help="The scheduler type to use.", 118 | choices=[ 119 | "linear", "cosine", "cosine_with_restarts", "polynomial", 120 | "constant", "constant_with_warmup" 121 | ], 122 | ) 123 | parser.add_argument( 124 | "--num_warmup_steps", 125 | type=int, 126 | default=0, 127 | help="Number of steps for the warmup in the lr scheduler.") 128 | parser.add_argument("--output_dir", 129 | type=str, 130 | default=None, 131 | help="Where to store the model.") 132 | parser.add_argument("--seed", 133 | type=int, 134 | default=1234, 135 | help="A seed for reproducible training.") 136 | parser.add_argument("--local_rank", 137 | type=int, 138 | default=-1, 139 | help="Local rank for distributed training on GPUs.") 140 | parser.add_argument("--distributed_port", 141 | type=int, 142 | default=-1, 143 | help="Distributed port for distributed training on GPUs.") 144 | parser.add_argument('--gradient_checkpointing', 145 | action='store_true', 146 | help='Enable HF gradient checkpointing for the model.') 147 | # deepspeed features 148 | parser.add_argument('--offload', 149 | action='store_true', 150 | help='Enable ZeRO Offload techniques.') 151 | parser.add_argument('--save_steps', 152 | type=int, 153 | default=10, 154 | help='Number of steps between checkpoint saves.') 155 | parser.add_argument('--evaluation_steps', 156 | type=int, 157 | default=10, 158 | help='Number of steps between evaluations.') 159 | parser.add_argument( 160 | '--zero_stage', 161 | type=int, 162 | default=0, 163 | help='ZeRO optimization stage for the model.') 164 | ## LoRA settings for efficient training 165 | parser.add_argument("--lora_dim", 166 | type=int, 167 | default=0, 168 | help="If > 0, use LoRA for efficient training.") 169 | parser.add_argument("--lora_module_name", 170 | type=str, 171 | default="decoder.layers.", 172 | help="The scope of LoRA.") 173 | parser.add_argument('--only_optimize_lora', 174 | action='store_true', 175 | help='Only optimize the LoRA parameters.') 176 | parser = deepspeed.add_config_arguments(parser) 177 | args = parser.parse_args() 178 | 179 | # Validate settings 
180 | if args.gradient_checkpointing and args.lora_dim > 0: 181 | assert ( 182 | not args.only_optimize_lora 183 | ), "--gradient_checkpointing and --only_optimize_lora cannot be enabled at the same time." 184 | 185 | return args 186 | 187 | 188 | def main(): 189 | args = parse_args() 190 | 191 | # world_size = torch.distributed.get_world_size() 192 | PROJECT_NAME = f"bloomz-7b1-gpu-8-bs-{args.per_device_train_batch_size}-grad-accum-{args.gradient_accumulation_steps}-lr-{args.learning_rate}-maxlen-{args.max_seq_len}-ZeRO-{args.zero_stage}" 193 | 194 | wandb.init( 195 | # set the wandb project where this run will be logged 196 | entity="bnnt", 197 | project=PROJECT_NAME, 198 | 199 | # track hyperparameters and run metadata 200 | config={ 201 | "learning_rate": args.learning_rate, 202 | "dataset": "patent-full", 203 | "Epochs": args.num_train_epochs, 204 | } 205 | ) 206 | 207 | if args.local_rank == -1: 208 | device = torch.device("cuda") 209 | else: 210 | torch.cuda.set_device(args.local_rank) 211 | device = torch.device("cuda", args.local_rank) 212 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs 213 | # torch.distributed.init_process_group(backend='nccl') 214 | deepspeed.init_distributed(distributed_port=args.distributed_port) 215 | 216 | args.global_rank = torch.distributed.get_rank() 217 | 218 | ds_config = get_train_ds_config(offload=args.offload, 219 | stage=args.zero_stage) 220 | ds_config[ 221 | 'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size 222 | ds_config[ 223 | 'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size( 224 | ) * args.gradient_accumulation_steps 225 | 226 | # If passed along, set the training seed now. 227 | set_random_seed(args.seed) 228 | 229 | assert not args.offload, "zero-offload is not currently supported but coming soon!" 
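# NOTE: the effective optimizer batch size configured above is
# per_device_train_batch_size * world_size * gradient_accumulation_steps
# (e.g. with run_pretrain.sh's per-device batch 1 and grad-accum 8 on
# 8 GPUs, 1 * 8 * 8 = 64 sequences per optimizer step). Also, although
# --offload would select DeepSpeedCPUAdam further down, ZeRO-Offload is
# explicitly ruled out by the assert above.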
230 | 231 | torch.distributed.barrier() 232 | 233 | print("model_name_or_path : ", args.model_name_or_path) 234 | if "llama" in args.model_name_or_path: 235 | tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path, cache_dir='/data6/.cache/huggingface/hub/') 236 | assert tokenizer.eos_token_id == 2 237 | assert tokenizer.bos_token_id == 1 238 | args.lora_module_name = [ 239 | "q_proj", 240 | "k_proj", 241 | "v_proj", 242 | "down_proj", 243 | "gate_proj", 244 | "up_proj" 245 | ] 246 | else: 247 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 248 | 249 | # tokenizer.pad_token = tokenizer.eos_token 250 | tokenizer.pad_token_id = 0 251 | tokenizer.padding_side = "left" 252 | print("Making tokenizer padding side to left") 253 | 254 | model = create_hf_model(AutoModelForCausalLM, args.model_name_or_path, 255 | tokenizer, ds_config) 256 | 257 | if args.lora_dim > 0: 258 | model = convert_linear_layer_to_lora(model, args.lora_module_name, 259 | args.lora_dim) 260 | # model = convert_LLaMA_to_lora(model, args.lora_module_name) 261 | if args.only_optimize_lora: 262 | model = only_optimize_lora_parameters(model) 263 | 264 | # Prepare the data 265 | train_phase = 1 266 | # print("sft_only_data_path : ", args.sft_only_data_path) 267 | # train_dataset, eval_dataset = create_prompt_dataset( 268 | # args.local_rank, 269 | # args.data_path, 270 | # args.data_split, 271 | # args.data_output_path, 272 | # train_phase, 273 | # args.seed, 274 | # tokenizer, 275 | # args.max_seq_len, 276 | # sft_only_data_path=args.sft_only_data_path) 277 | 278 | data_path = f"/data6/xjy2023/xl/new_shard_{args.local_rank}.jsonl" 279 | print("data_path : ", data_path) 280 | if data_path.endswith("json"): 281 | with open(data_path) as f: 282 | raw_dataset = json.load(f) 283 | shuffle_idx = get_shuffle_idx(args.seed, len(raw_dataset[:1000])) 284 | 285 | eval_idx = list(shuffle_idx)[:500] 286 | print(eval_idx) 287 | 288 | print("Loading dataset") 289 | pbar = tqdm(total=len(raw_dataset), position=args.local_rank) 290 | train_dataset, eval_dataset = [], [] 291 | # for i, tmp_data in tqdm(enumerate(raw_dataset), total=len(raw_dataset), position=args.local_rank): 292 | tmp_idx = 0 293 | while raw_dataset: 294 | tmp_data = raw_dataset.pop() 295 | # tokenize the text 296 | chosen_sentence = f"标题:{tmp_data['title']}。摘要:{tmp_data['summary']}专利公开号:{tmp_data['publicNo']}。权利要求:{tmp_data['powerRequirements']}说明书:{tmp_data['instructions']}" 297 | chosen_token = tokenizer(chosen_sentence, 298 | max_length=args.max_seq_len, 299 | padding="max_length", 300 | truncation=True) 301 | 302 | chosen_token["input_ids"][-1] = tokenizer.eos_token_id 303 | 304 | chosen_token["labels"] = torch.LongTensor( 305 | [-100] + [-100 if tokenizer.pad_token_id == idx else idx for idx in chosen_token["input_ids"]][1:]) 306 | 307 | chosen_token["input_ids"] = torch.LongTensor(chosen_token["input_ids"]).squeeze(0) 308 | chosen_token["attention_mask"] = torch.LongTensor(chosen_token["attention_mask"]).squeeze(0) 309 | chosen_token["labels"] = torch.LongTensor(chosen_token["labels"]).squeeze(0) 310 | 311 | if tmp_idx in eval_idx: 312 | eval_dataset.append(chosen_token) 313 | else: 314 | train_dataset.append(chosen_token) 315 | tmp_idx += 1 316 | pbar.update(1) 317 | elif data_path.endswith("jsonl"): 318 | # raw_dataset = [] 319 | # with open(data_path,'r',encoding='utf-8') as f: 320 | # for _ in tqdm(range(3500000)): 321 | # line = f.readline() 322 | # line_dict = ast.literal_eval(line) 323 | # raw_dataset.append(line_dict) 324 | 
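# Streaming path: read at most total_size records from this rank's shard
# (new_shard_{local_rank}.jsonl); the first 500 records become the eval set
# and the remainder the training set (see the BNNTDataset split below).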
total_size = 3500000 325 | # shuffle_idx = get_shuffle_idx(args.seed, total_size) 326 | # print("raw dataset: ", total_size) 327 | # eval_idx = list(shuffle_idx)[:500] 328 | # print(eval_idx) 329 | raw_data = [] 330 | with jsonlines.open(data_path) as f: 331 | for i, tmp_data in tqdm(enumerate(f), 332 | desc=data_path, 333 | position=args.local_rank, 334 | total=total_size): 335 | if i < total_size: 336 | raw_data.append(tmp_data) 337 | else: 338 | break 339 | 340 | train_dataset = BNNTDataset(raw_data[500:], tokenizer, args) 341 | eval_dataset = BNNTDataset(raw_data[:500], tokenizer, args) 342 | 343 | # DataLoaders creation: 344 | if args.local_rank == -1: 345 | train_sampler = RandomSampler(train_dataset) 346 | eval_sampler = SequentialSampler(eval_dataset) 347 | else: 348 | train_sampler = DistributedSampler(train_dataset) 349 | eval_sampler = DistributedSampler(eval_dataset) 350 | train_dataloader = DataLoader(train_dataset, 351 | collate_fn=default_data_collator, 352 | sampler=train_sampler, 353 | batch_size=args.per_device_train_batch_size) 354 | eval_dataloader = DataLoader(eval_dataset, 355 | collate_fn=default_data_collator, 356 | sampler=eval_sampler, 357 | batch_size=args.per_device_eval_batch_size) 358 | 359 | def evaluation(model, eval_dataloader): 360 | model.eval() 361 | losses = 0 362 | for step, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), unit="batch"): 363 | batch = to_device(batch, device) 364 | with torch.no_grad(): 365 | outputs = model(**batch) 366 | 367 | loss = outputs.loss 368 | losses += loss.float() 369 | losses = losses / (step + 1) 370 | model.train() 371 | try: 372 | perplexity = torch.exp(losses) 373 | except OverflowError: 374 | perplexity = float("inf") 375 | try: 376 | perplexity = get_all_reduce_mean(perplexity).item() 377 | except: 378 | pass 379 | return perplexity 380 | 381 | # Split weights in two groups, one with weight decay and the other not. 382 | optimizer_grouped_parameters = get_optimizer_grouped_parameters( 383 | model, args.weight_decay) 384 | 385 | AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam 386 | optimizer = AdamOptimizer(optimizer_grouped_parameters, 387 | lr=args.learning_rate, 388 | betas=(0.9, 0.95)) 389 | 390 | num_update_steps_per_epoch = math.ceil( 391 | len(train_dataloader) / args.gradient_accumulation_steps) 392 | lr_scheduler = get_scheduler( 393 | name=args.lr_scheduler_type, 394 | optimizer=optimizer, 395 | num_warmup_steps=args.num_warmup_steps, 396 | num_training_steps=args.num_train_epochs * num_update_steps_per_epoch, 397 | ) 398 | 399 | model, optimizer, _, lr_scheduler = deepspeed.initialize( 400 | model=model, 401 | optimizer=optimizer, 402 | args=args, 403 | config=ds_config, 404 | lr_scheduler=lr_scheduler, 405 | dist_init_required=True) 406 | 407 | if args.gradient_checkpointing: 408 | model.gradient_checkpointing_enable() 409 | 410 | # Train! 
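# The DeepSpeed engine handles gradient accumulation internally:
# model.backward()/model.step() below only apply an optimizer update every
# gradient_accumulation_steps micro-batches. An evaluation pass runs first
# to record the starting perplexity before any weight updates.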
411 | print_rank_0("***** Running training *****", args.global_rank) 412 | print_rank_0( 413 | f"***** Evaluating perplexity, Epoch {0}/{args.num_train_epochs} *****", 414 | args.global_rank) 415 | perplexity = evaluation(model, eval_dataloader) 416 | print_rank_0(f"ppl: {perplexity}", args.global_rank) 417 | 418 | for epoch in range(args.num_train_epochs): 419 | print_rank_0( 420 | f"Beginning of Epoch {epoch + 1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}", 421 | args.global_rank) 422 | model.train() 423 | for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit="batch"): 424 | batch = to_device(batch, device) 425 | outputs = model(**batch, use_cache=False) 426 | loss = outputs.loss 427 | model.backward(loss) 428 | model.step() 429 | print("Epoch: {}, step: {}, loss: {}".format(epoch, step, loss.item())) 430 | wandb.log({"loss": loss.item()}, step=step) 431 | 432 | if (step + 1) % args.save_steps == 0: 433 | save_path = os.path.join(args.output_dir, f"Epoch-{epoch + 1}-step-{step + 1}") 434 | 435 | save_zero_three_model(model, 436 | args.global_rank, 437 | save_path, 438 | zero_stage=args.zero_stage) 439 | 440 | print_rank_0( 441 | f"Saving checkpoint... Steps {step + 1} Epoch {epoch + 1}/{args.num_train_epochs}", 442 | args.global_rank) 443 | 444 | if (step + 1) % args.evaluation_steps == 0: 445 | perplexity = evaluation(model, eval_dataloader) 446 | 447 | wandb.log({"perplexity": perplexity}, step=step) 448 | 449 | print_rank_0(f"ppl: {perplexity}", args.global_rank) 450 | print_rank_0( 451 | f"***** Evaluating perplexity, Steps {step + 1}, Epoch {epoch + 1}/{args.num_train_epochs} *****", 452 | args.global_rank) 453 | 454 | # Evaluate perplexity on the validation set. 455 | perplexity = evaluation(model, eval_dataloader) 456 | print_rank_0(f"ppl: {perplexity}", args.global_rank) 457 | print_rank_0( 458 | f"***** Evaluating perplexity, Steps {step + 1}, Epoch {epoch + 1}/{args.num_train_epochs} *****", 459 | args.global_rank) 460 | 461 | # End of epoch (checkpoints are only written every save_steps inside the loop above). 462 | 463 | model.tput_timer.update_epoch_count() 464 | 465 | if args.output_dir is not None: 466 | print_rank_0('saving the final model ...', args.global_rank) 467 | # model = convert_lora_to_linear_layer(model) 468 | 469 | if args.global_rank == 0: 470 | save_hf_format(model, tokenizer, args) 471 | 472 | if args.zero_stage == 3: 473 | # For ZeRO stage 3, each GPU only has a part of the model, so we need a special save function 474 | save_zero_three_model(model, 475 | args.global_rank, 476 | args.output_dir, 477 | zero_stage=args.zero_stage) 478 | 479 | 480 | if __name__ == "__main__": 481 | main() 482 | -------------------------------------------------------------------------------- /run_SFT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 
3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | export NCCL_P2P_DISABLE="1" 6 | export NCCL_IB_DISABLE="1" 7 | # DeepSpeed Team 8 | UUID=$(uuidgen) 9 | echo "${UUID}" 10 | RUN_FILE=$(readlink -f "$0") 11 | WORK_DIR=$(dirname "$RUN_FILE") 12 | echo "${WORK_DIR}" 13 | 14 | # DeepSpeed Team 15 | OUTPUT=$1 16 | if [ "$OUTPUT" == "" ]; then 17 | OUTPUT=/home//output/${UUID} 18 | fi 19 | 20 | mkdir -p $OUTPUT 21 | #bigscience/bloomz-1b7 22 | 23 | 24 | mkdir -p "${OUTPUT}"/logs 25 | log_file="${OUTPUT}"/logs/train.txt 26 | exec &> >(tee -a "$log_file") 27 | 28 | # CUDA_VISIBLE_DEVICES=4 29 | # /nfs/data6/patent_sft/instruction_general_3026150_conversations.json 30 | # /data6/instruction_patent_20k_conversations 31 | deepspeed SFT.py \ 32 | --sft_only_data_path /BELLE/instruction_ip.json \ 33 | --model_name_or_path /mozi-7b-3m-40k \ 34 | --per_device_train_batch_size 1 \ 35 | --per_device_eval_batch_size 1 \ 36 | --max_seq_len 2048 \ 37 | --learning_rate 5e-6 \ 38 | --weight_decay 0.0001 \ 39 | --num_train_epochs 3 \ 40 | --gradient_accumulation_steps 4 \ 41 | --lr_scheduler_type cosine \ 42 | --num_warmup_steps 100 \ 43 | --seed 1234 \ 44 | --zero_stage 3 \ 45 | --gradient_checkpointing \ 46 | --save_steps 2000 \ 47 | --evaluation_steps 500 \ 48 | --output_dir $OUTPUT 49 | # &> $OUTPUT/training.log 50 | -------------------------------------------------------------------------------- /run_inference.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4,5,6,7 python inference.py \ 2 | --model_name_or_path ../mozi-7b-3m-40k \ 3 | --foundation_model bloom \ 4 | --test_file IPQA-test-5.json \ 5 | --predictions_file ./mozi-7b--predictions-ipqa-5.json 6 | 7 | 8 | # /data6/.cache/huggingface/hub/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 9 | # /data6/model/BELLE-7B-2M 10 | # CUDA_VISIBLE_DEVICES=7 11 | # /data6/.cache/huggingface/hub/models--bigscience--bloomz-7b1-mt/snapshots/13e9b1a39fe86c8024fe15667d063aa8a3e32460 12 | # /data6/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/35ca52301fbedee885b0838da5d15b7b47faa37c 13 | # /data6/mozi-7b-3m-40k 14 | -------------------------------------------------------------------------------- /run_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 
3 | # SPDX-License-Identifier: Apache-2.0 4 | UUID=$(uuidgen) 5 | echo "${UUID}" 6 | RUN_FILE=$(readlink -f "$0") 7 | WORK_DIR=$(dirname "$RUN_FILE") 8 | echo "${WORK_DIR}" 9 | 10 | # DeepSpeed Team 11 | OUTPUT=$1 12 | if [ "$OUTPUT" == "" ]; then 13 | OUTPUT=/data6/output/${UUID} 14 | fi 15 | 16 | mkdir -p $OUTPUT 17 | #bigscience/bloomz-1b7 18 | 19 | DISTRIBUTED_PORT=25002 20 | 21 | mkdir -p "${OUTPUT}"/logs 22 | log_file="${OUTPUT}"/logs/train.txt 23 | exec &> >(tee -a "$log_file") 24 | 25 | PYTHONPATH="${WORK_DIR}"/src deepspeed --master_port 25003 patent_pretrain.py \ 26 | --sft_only_data_path belleMath.json \ 27 | --model_name_or_path /data6/.cache/huggingface/hub/models--bigscience--bloomz-7b1-mt/snapshots/13e9b1a39fe86c8024fe15667d063aa8a3e32460/ \ 28 | --per_device_train_batch_size 1 \ 29 | --per_device_eval_batch_size 2 \ 30 | --max_seq_len 2048 \ 31 | --learning_rate 5e-6 \ 32 | --weight_decay 0.0001 \ 33 | --num_train_epochs 1 \ 34 | --gradient_accumulation_steps 8 \ 35 | --lr_scheduler_type cosine \ 36 | --num_warmup_steps 1000 \ 37 | --save_steps 2500 \ 38 | --evaluation_steps 1000 \ 39 | --zero_stage 3 \ 40 | --seed 1234 \ 41 | --gradient_checkpointing \ 42 | --distributed_port $DISTRIBUTED_PORT \ 43 | --output_dir $OUTPUT 44 | # &> $OUTPUT/training.log 45 | 46 | # --deepspeed_config ds_config.json \ -------------------------------------------------------------------------------- /src/bnnt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-for-Science/MoZi/517cb2ed7b7cab0d23803c964ae671f6df405e8b/src/bnnt/__init__.py -------------------------------------------------------------------------------- /src/bnnt/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-for-Science/MoZi/517cb2ed7b7cab0d23803c964ae671f6df405e8b/src/bnnt/data/__init__.py -------------------------------------------------------------------------------- /src/bnnt/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | """ 6 | Part of the code was adapted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/data/dataset_utils.py 7 | """ 8 | import torch 9 | from torch.utils.data import Dataset, Subset 10 | from torch.nn.utils.rnn import pad_sequence 11 | import torch.nn.functional as F 12 | from datasets import load_dataset 13 | import numpy as np 14 | import os 15 | from itertools import chain 16 | from . 
import raw_datasets 17 | 18 | 19 | def get_raw_dataset(dataset_name, output_path, seed, local_rank): 20 | if type(dataset_name)==list: 21 | dataset_name = dataset_name[0] 22 | print("dataset_name : ", dataset_name) 23 | if "belle" in dataset_name.lower(): 24 | return raw_datasets.BelleOpenSoucreDataset(output_path, seed, local_rank, data_file=dataset_name) 25 | 26 | elif dataset_name == "Dahoas/rm-static": 27 | return raw_datasets.DahoasRmstaticDataset(output_path, seed, 28 | local_rank) 29 | elif dataset_name == "Dahoas/full-hh-rlhf": 30 | return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed, 31 | local_rank) 32 | elif dataset_name == "Dahoas/synthetic-instruct-gptj-pairwise": 33 | return raw_datasets.DahoasSyntheticinstructgptjpairwiseDataset( 34 | output_path, seed, local_rank) 35 | elif dataset_name == "yitingxie/rlhf-reward-datasets": 36 | return raw_datasets.YitingxieRlhfrewarddatasetsDataset( 37 | output_path, seed, local_rank) 38 | elif dataset_name == "openai/webgpt_comparisons": 39 | return raw_datasets.OpenaiWebgptcomparisonsDataset( 40 | output_path, seed, local_rank) 41 | elif dataset_name == "stanfordnlp/SHP": 42 | return raw_datasets.StanfordnlpSHPDataset(output_path, seed, 43 | local_rank) 44 | elif dataset_name == "wangrui6/Zhihu-KOL": 45 | return raw_datasets.Wangrui6ZhihuKOLDataset(output_path, seed, 46 | local_rank) 47 | elif dataset_name == "Cohere/miracl-zh-queries-22-12": 48 | return raw_datasets.CohereMiraclzhqueries2212Dataset( 49 | output_path, seed, local_rank) 50 | elif dataset_name == "Hello-SimpleAI/HC3-Chinese": 51 | return raw_datasets.HelloSimpleAIHC3ChineseDataset( 52 | output_path, seed, local_rank) 53 | elif dataset_name == "mkqa-Chinese": 54 | return raw_datasets.MkqaChineseDataset(output_path, seed, local_rank) 55 | elif dataset_name == "mkqa-Japanese": 56 | return raw_datasets.MkqaJapaneseDataset(output_path, seed, local_rank) 57 | elif dataset_name == "Cohere/miracl-ja-queries-22-12": 58 | return raw_datasets.CohereMiracljaqueries2212Dataset( 59 | output_path, seed, local_rank) 60 | elif dataset_name == "lmqg/qg_jaquad": 61 | return raw_datasets.LmqgQgjaquadDataset(output_path, seed, local_rank) 62 | elif dataset_name == "lmqg/qag_jaquad": 63 | return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank) 64 | else: 65 | raise RuntimeError( 66 | f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py." 
67 | ) 68 | 69 | 70 | def get_shuffle_idx(seed, size): 71 | np_rng = np.random.RandomState(seed=seed) 72 | dtype_ = np.uint32 73 | if size >= (np.iinfo(np.uint32).max - 1): 74 | dtype_ = np.int64 75 | shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) 76 | np_rng.shuffle(shuffle_idx) 77 | return shuffle_idx 78 | 79 | 80 | def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed, 81 | split_name, data_split, split_index, 82 | data_size): 83 | index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy" 84 | if not os.path.isfile(index_file_name) and local_rank <= 0: 85 | splits = [float(s) for s in data_split.split(',')] 86 | splits_sum = sum(splits) 87 | splits = [split / splits_sum for split in splits] 88 | splits_index = [0] 89 | for index, split in enumerate(splits): 90 | splits_index.append(splits_index[index] + 91 | int(round(split * float(data_size)))) 92 | diff = splits_index[-1] - data_size 93 | for index in range(1, len(splits_index)): 94 | splits_index[index] -= diff 95 | assert splits_index[-1] == data_size 96 | 97 | shuffle_idx = get_shuffle_idx(seed, data_size) 98 | for split_i in range(len(splits)): 99 | shuffle_idx_split_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_i}.npy" 100 | shuffle_idx_split = shuffle_idx[ 101 | splits_index[split_i]:splits_index[split_i + 1]] 102 | np.save(shuffle_idx_split_file_name, 103 | shuffle_idx_split, 104 | allow_pickle=True) 105 | torch.distributed.barrier() 106 | index = np.load(index_file_name, allow_pickle=True) 107 | return index.tolist() 108 | 109 | 110 | class PromptDataset(Dataset): 111 | 112 | def __init__(self, prompt_dataset, chosen_dataset, reject_dataset, 113 | pad_token_id, train_phase) -> None: 114 | super().__init__() 115 | self.prompt_dataset = prompt_dataset 116 | self.chosen_dataset = chosen_dataset 117 | self.reject_dataset = reject_dataset 118 | self.pad_token_id = pad_token_id 119 | self.train_phase = train_phase 120 | 121 | def __len__(self): 122 | length = len(self.chosen_dataset) 123 | if self.train_phase == 3: 124 | length = len(self.prompt_dataset) 125 | return length 126 | 127 | def __getitem__(self, idx): 128 | if self.train_phase == 1: 129 | return { 130 | "input_ids": self.chosen_dataset[idx]["input_ids"], 131 | "attention_mask": self.chosen_dataset[idx]["attention_mask"], 132 | "labels": self.chosen_dataset[idx]["labels"] 133 | } 134 | elif self.train_phase == 2: 135 | return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \ 136 | self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"] 137 | elif self.train_phase == 3: 138 | return self.prompt_dataset[idx]["input_ids"],self.prompt_dataset[idx]["attention_mask"], \ 139 | self.pad_token_id 140 | 141 | 142 | 143 | def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer, 144 | end_of_conversation_token, max_seq_len): 145 | prompt_dataset = [] 146 | chosen_dataset = [] 147 | reject_dataset = [] 148 | assert tokenizer.padding_side == "left" 149 | if train_phase == 1: 150 | for i, tmp_data in enumerate(current_dataset): 151 | # tokenize the text 152 | prompt_text = raw_dataset.get_prompt(tmp_data) 153 | tokenized_prompt_text = tokenizer(prompt_text, truncation=True,max_length=max_seq_len,padding=False,return_tensors=None) 154 | user_prompt_len = len(tokenized_prompt_text["input_ids"]) 155 | 156 | chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data) # the accept 
response 157 | 158 | chosen_sentence += end_of_conversation_token 159 | chosen_token = tokenizer(chosen_sentence, 160 | max_length=max_seq_len, 161 | padding="max_length", 162 | truncation=True) 163 | 164 | if chosen_token["input_ids"][-1] != tokenizer.eos_token_id:  # make sure tokenizer.padding_side is left 165 | chosen_token["input_ids"].append(tokenizer.eos_token_id) 166 | chosen_token["attention_mask"].append(1) 167 | 168 | chosen_token["labels"] = torch.LongTensor([-100] * user_prompt_len + chosen_token["input_ids"][user_prompt_len:]) 169 | 170 | chosen_token["input_ids"] = torch.LongTensor(chosen_token["input_ids"]).squeeze(0) 171 | chosen_token["attention_mask"] = torch.LongTensor(chosen_token["attention_mask"]).squeeze(0) 172 | chosen_token["labels"] = torch.LongTensor(chosen_token["labels"]).squeeze(0) 173 | 174 | chosen_dataset.append(chosen_token) 175 | 176 | elif train_phase == 2: 177 | for i, tmp_data in enumerate(current_dataset): 178 | # tokenize the text 179 | chosen_sentence = raw_dataset.get_prompt_and_chosen( 180 | tmp_data)  # the accepted response 181 | reject_sentence = raw_dataset.get_prompt_and_rejected( 182 | tmp_data)  # the rejected response 183 | if chosen_sentence is not None and reject_sentence is not None: 184 | chosen_sentence += end_of_conversation_token  # the accepted response 185 | reject_sentence += end_of_conversation_token 186 | chosen_token = tokenizer(chosen_sentence, 187 | max_length=max_seq_len, 188 | padding="max_length", 189 | truncation=True, 190 | return_tensors="pt") 191 | reject_token = tokenizer(reject_sentence, 192 | max_length=max_seq_len, 193 | padding="max_length", 194 | truncation=True, 195 | return_tensors="pt") 196 | chosen_token["input_ids"] = chosen_token["input_ids"] 197 | chosen_token["attention_mask"] = chosen_token["attention_mask"] 198 | chosen_dataset.append(chosen_token) 199 | 200 | reject_token["input_ids"] = reject_token["input_ids"] 201 | reject_token["attention_mask"] = reject_token["attention_mask"] 202 | reject_dataset.append(reject_token) 203 | 204 | elif train_phase == 3: 205 | for i, tmp_data in enumerate(current_dataset): 206 | # tokenize the text 207 | prompt = raw_dataset.get_prompt(tmp_data) 208 | if prompt is not None: 209 | prompt_token = tokenizer(prompt, return_tensors="pt") 210 | prompt_token["input_ids"] = prompt_token["input_ids"] 211 | prompt_token["attention_mask"] = prompt_token["attention_mask"] 212 | for key_word in ["input_ids", "attention_mask"]: 213 | length = prompt_token[key_word].size()[-1] 214 | if length > max_seq_len: 215 | y = prompt_token[key_word].squeeze(0)[length - 216 | (max_seq_len - 217 | 1):].flip(0) 218 | else: 219 | y = prompt_token[key_word].squeeze(0).flip(0) 220 | prompt_token[key_word] = y 221 | prompt_dataset.append(prompt_token) 222 | return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset, 223 | tokenizer.pad_token_id, train_phase) 224 | 225 | 226 | def create_dataset(local_rank, dataset_name, data_split, output_path, 227 | train_phase, seed, tokenizer, end_of_conversation_token, 228 | max_seq_len): 229 | # dataset_name can be the file path 230 | print("dataset_name: ", dataset_name) 231 | raw_dataset = get_raw_dataset(dataset_name, output_path, seed, local_rank) 232 | train_dataset = raw_dataset.get_train_data() 233 | train_index = get_raw_dataset_split_index(local_rank, output_path, 234 | raw_dataset.dataset_name_clean, 235 | seed, "train", data_split, 236 | train_phase - 1, 237 | len(train_dataset)) 238 | train_dataset = Subset(train_dataset, train_index) 239 | 
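# create_dataset_split (above) tokenizes the selected subset according to
# train_phase: phase 1 masks the prompt tokens to -100 so loss is computed
# only on the response, phase 2 keeps chosen/rejected pairs, and phase 3
# stores reversed prompts for the left-padding collator.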
train_dataset = create_dataset_split(train_dataset, raw_dataset, 240 | train_phase, tokenizer, 241 | end_of_conversation_token, 242 | max_seq_len) 243 | 244 | eval_dataset = raw_dataset.get_eval_data() 245 | eval_index = get_raw_dataset_split_index(local_rank, output_path, 246 | raw_dataset.dataset_name_clean, 247 | seed, "eval", 248 | data_split, train_phase - 1, 249 | len(eval_dataset)) 250 | eval_dataset = Subset(eval_dataset, eval_index) 251 | eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase, 252 | tokenizer, end_of_conversation_token, 253 | max_seq_len) 254 | 255 | # for item in train_dataset: 256 | # print(item) 257 | return train_dataset, eval_dataset 258 | 259 | 260 | 261 | def create_prompt_dataset(local_rank, 262 | data_path, 263 | data_split, 264 | output_path, 265 | train_phase, 266 | seed, 267 | tokenizer, 268 | max_seq_len, 269 | end_of_conversation_token="<|endoftext|>", 270 | sft_only_data_path=[]): 271 | """ 272 | Creates the prompt dataset 273 | """ 274 | os.makedirs(output_path, exist_ok=True) 275 | print("data_path: ", data_path)#['/nfs/a100-80G-18/xunxianghui/gitrepositories/training_datasets/belle/belle_extra_5k.dev.json'] 276 | 277 | fname = "_".join(data_path) 278 | sft_cache_key = "_".join(sft_only_data_path) 279 | tokenizer_name = tokenizer.init_kwargs["name_or_path"].replace("/", "_") 280 | fname = f"{fname}_split{data_split}_phase{train_phase}_seed{seed}_tokenizer{tokenizer_name}_seqlen{max_seq_len}_sft{sft_cache_key}" 281 | fname = "_".join(fname.split("/")) 282 | fname = str(hash(fname)) # hash the file name to avoid too long file name 283 | train_fname = f"{output_path}/traindata_{fname}.pt" 284 | eval_fname = f"{output_path}/evaldata_{fname}.pt" 285 | print("fname = " + fname) 286 | print("sft_cache_key = " +sft_cache_key) 287 | print("tokenizer_name "+tokenizer_name) 288 | 289 | cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname) 290 | buf_create_cache = torch.ByteTensor([not cache_found]).cuda() 291 | torch.distributed.all_reduce(buf_create_cache) 292 | train_dataset, eval_dataset = create_dataset( 293 | local_rank, sft_only_data_path, data_split, output_path, train_phase, 294 | seed, tokenizer, end_of_conversation_token, max_seq_len) 295 | if local_rank <= 0: 296 | torch.save(train_dataset, train_fname) 297 | torch.save(eval_dataset, eval_fname) 298 | return train_dataset, eval_dataset 299 | 300 | 301 | # # Skip creating cache if we found it on all the nodes. 302 | # if buf_create_cache.item() == 0: 303 | # return torch.load(train_fname), torch.load(eval_fname) 304 | # else: 305 | # if len(data_path) == 1: # Single dataset. 306 | # train_dataset, eval_dataset = create_dataset( 307 | # local_rank, data_path[0], data_split, output_path, train_phase, 308 | # seed, tokenizer, end_of_conversation_token, max_seq_len) 309 | # else: # Blending datasets. 
310 | # train_datasets = [] 311 | # eval_datasets = [] 312 | # train_size = 0 313 | # eval_size = 0 314 | # for d_path in data_path: 315 | # train_dataset, eval_dataset = create_dataset( 316 | # local_rank, d_path, data_split, output_path, train_phase, 317 | # seed, tokenizer, end_of_conversation_token, max_seq_len) 318 | # train_datasets.append(train_dataset) 319 | # eval_datasets.append(eval_dataset) 320 | # train_size += len(train_dataset) 321 | # eval_size += len(eval_dataset) 322 | # train_dataset = ConcatDataset(train_datasets) 323 | # shuffle_idx = get_shuffle_idx(seed, train_size) 324 | # train_dataset = Subset(train_dataset, shuffle_idx.tolist()) 325 | # eval_dataset = ConcatDataset(eval_datasets) 326 | # shuffle_idx = get_shuffle_idx(seed, eval_size) 327 | # eval_dataset = Subset(eval_dataset, shuffle_idx.tolist()) 328 | 329 | # # Append the SFT-only dataset if it exists, and current phase is 1 (SFT). 330 | # if train_phase == 1 and sft_only_data_path: 331 | # sft_train_datasets = [] 332 | # sft_eval_datasets = [] 333 | # sft_train_size = 0 334 | # sft_eval_size = 0 335 | # for sft_path in sft_only_data_path: 336 | # sft_train_dataset, sft_eval_dataset = create_dataset( 337 | # local_rank, 338 | # sft_path, 339 | # "10,0,0", 340 | # output_path, 341 | # train_phase, 342 | # seed, 343 | # tokenizer, 344 | # end_of_conversation_token, 345 | # max_seq_len, 346 | # ) 347 | # sft_train_datasets.append(sft_train_dataset) 348 | # sft_eval_datasets.append(sft_eval_dataset) 349 | # sft_train_size += len(sft_train_dataset) 350 | # sft_eval_size += len(sft_eval_dataset) 351 | # if sft_train_datasets: # Check if sft_train_datasets is not empty 352 | # sft_train_dataset = ConcatDataset(sft_train_datasets) 353 | # train_dataset = ConcatDataset( 354 | # [train_dataset, sft_train_dataset]) 355 | # shuffle_idx = get_shuffle_idx(seed, len(train_dataset)) 356 | # train_dataset = Subset(train_dataset, shuffle_idx.tolist()) 357 | # if sft_eval_datasets: # Check if sft_eval_datasets is not empty 358 | # sft_eval_dataset = ConcatDataset(sft_eval_datasets) 359 | # eval_dataset = ConcatDataset([eval_dataset, sft_eval_dataset]) 360 | # shuffle_idx = get_shuffle_idx(seed, len(eval_dataset)) 361 | # eval_dataset = Subset(eval_dataset, shuffle_idx.tolist()) 362 | 363 | # if local_rank <= 0: 364 | # torch.save(train_dataset, train_fname) 365 | # torch.save(eval_dataset, eval_fname) 366 | # return train_dataset, eval_dataset 367 | 368 | 369 | class DataCollatorReward: 370 | 371 | def __call__(self, data): 372 | batch = {} 373 | batch["input_ids"] = torch.cat([f[0] 374 | for f in data] + [f[2] for f in data], 375 | dim=0) 376 | batch["attention_mask"] = torch.cat([f[1] for f in data] + 377 | [f[3] for f in data], 378 | dim=0) 379 | return batch 380 | 381 | 382 | class DataCollatorRLHF: 383 | 384 | def __init__(self, max_token_len, inference_tp_size): 385 | self.max_token_len = max_token_len 386 | self.inference_tp_size = inference_tp_size 387 | 388 | def __call__(self, data): 389 | batch = {} 390 | pad_token_id = data[-1][-1] 391 | 392 | prompt = pad_sequence([f[0] for f in data], 393 | padding_value=pad_token_id, 394 | batch_first=True) 395 | prompt_mask = pad_sequence([f[1] for f in data], 396 | padding_value=0, 397 | batch_first=True) 398 | 399 | ### make sure the final output is a sequence of length max_token_len 
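# The phase-3 prompts were stored reversed (.flip(0) in create_dataset_split);
# pad_sequence right-pads the reversed sequences, so the final .flip(1) at the
# end of __call__ restores token order with the padding moved to the front
# (i.e., left-padded prompts).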
400 | length = prompt.size()[-1] 401 | pad_length = self.max_token_len - length 402 | if pad_length > 0: 403 | batch["prompt"] = F.pad(prompt, 404 | pad=(pad_length, 0), 405 | mode='constant', 406 | value=pad_token_id) 407 | batch["prompt_att_mask"] = F.pad(prompt_mask, 408 | pad=(pad_length, 0), 409 | mode='constant', 410 | value=0) 411 | else: 412 | batch["prompt"] = prompt 413 | batch["prompt_att_mask"] = prompt_mask 414 | batch["prompt"] = batch["prompt"].flip(1) 415 | batch["prompt_att_mask"] = batch["prompt_att_mask"].flip(1) 416 | return batch 417 | 418 | 419 | def get_unsupervised_data(args, tokenizer): 420 | unsupervised_raw_datasets = load_dataset( 421 | args.unsupervised_dataset_name, args.unsupervised_dataset_config_name) 422 | column_names = unsupervised_raw_datasets["train"].column_names 423 | text_column_name = "text" if "text" in column_names else column_names[0] 424 | 425 | def tokenize_function(examples): 426 | return tokenizer(examples[text_column_name]) 427 | 428 | tokenized_datasets = unsupervised_raw_datasets.map( 429 | tokenize_function, 430 | batched=True, 431 | num_proc=args.preprocessing_num_workers, 432 | remove_columns=column_names, 433 | load_from_cache_file=True, 434 | desc="Running tokenizer on dataset", 435 | ) 436 | 437 | block_size = args.max_prompt_seq_len + args.max_answer_seq_len 438 | 439 | def group_texts(examples): 440 | # Concatenate all texts. 441 | concatenated_examples = { 442 | k: list(chain(*examples[k])) 443 | for k in examples.keys() 444 | } 445 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 446 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 447 | # customize this part to your needs. 448 | if total_length >= block_size: 449 | total_length = (total_length // block_size) * block_size 450 | # Split by chunks of max_len. 
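# Packing example: if block_size were 512, a 1300-token concatenation would
# yield two 512-token blocks and drop the trailing 276 tokens (1300 - 2 * 512).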
451 | result = { 452 | k: 453 | [t[i:i + block_size] for i in range(0, total_length, block_size)] 454 | for k, t in concatenated_examples.items() 455 | } 456 | result["labels"] = result["input_ids"].copy() 457 | return result 458 | 459 | lm_datasets = tokenized_datasets.map( 460 | group_texts, 461 | batched=True, 462 | num_proc=args.preprocessing_num_workers, 463 | load_from_cache_file=True, 464 | desc=f"Grouping texts in chunks of {block_size}", 465 | ) 466 | 467 | train_dataset = lm_datasets["train"] 468 | 469 | return train_dataset 470 | 471 | 472 | class MiniDataset: 473 | 474 | def __init__(self, max_size, small_batch_size): 475 | self.dataset = [] 476 | self.max_size = max_size 477 | self.small_batch_size = small_batch_size 478 | 479 | def seperate(self): 480 | small_dataset = [] 481 | for large_batch in self.dataset: 482 | if type(large_batch) == list or type(large_batch) == tuple: 483 | large_size = len(large_batch[0]) 484 | elif type(large_batch) == dict: 485 | large_size = len(large_batch[list(large_batch.keys())[0]]) 486 | else: 487 | large_size = len(large_batch) 488 | for i in range(0, large_size, self.small_batch_size): 489 | if type(large_batch) == list or type(large_batch) == tuple: 490 | small_dataset.append( 491 | [x[i:i + self.small_batch_size] for x in large_batch]) 492 | elif type(large_batch) == dict: 493 | small_dataset.append({ 494 | k: v[i:i + self.small_batch_size] 495 | for k, v in large_batch.items() 496 | }) 497 | else: 498 | small_dataset.append(large_batch[i:i + 499 | self.small_batch_size]) 500 | self.free() 501 | 502 | return small_dataset 503 | 504 | def add(self, data): 505 | if len(self.dataset) < self.max_size: 506 | self.dataset.append(data) 507 | if len(self.dataset) == self.max_size: 508 | return self.seperate() 509 | else: 510 | return None 511 | else: 512 | raise ValueError( 513 | "The dataset is full but we did not stop it. There is a bug in the code." 514 | ) 515 | 516 | def free(self): 517 | self.dataset = [] 518 | -------------------------------------------------------------------------------- /src/bnnt/data/raw_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import re 5 | 6 | import torch 7 | # DeepSpeed Team 8 | from datasets import load_dataset 9 | from torch.utils.data import Subset, Dataset 10 | 11 | 12 | # The template prompt dataset class that all new dataset porting needs to 13 | # follow in order to have a unified API and unified data format. 14 | class PromptRawDataset(object): 15 | 16 | def __init__(self, output_path, seed, local_rank): 17 | self.output_path = output_path 18 | self.seed = seed 19 | self.local_rank = local_rank 20 | 21 | def get_train_data(self): 22 | return 23 | 24 | def get_eval_data(self): 25 | return 26 | 27 | # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:" 28 | def get_prompt(self, sample): 29 | return 30 | 31 | # The chosen response should be in the format of: " " + actual_response_sentence 32 | def get_chosen(self, sample): 33 | return 34 | 35 | # The rejected response should be in the format of: " " + actual_response_sentence 36 | # If the dataset does not have rejected response, return None 37 | def get_rejected(self, sample): 38 | return 39 | 40 | def get_prompt_and_chosen(self, sample): 41 | return 42 | 43 | def get_prompt_and_rejected(self, sample): 44 | return 45 | 46 | 47 | # English dataset. 
https://huggingface.co/datasets/Dahoas/rm-static 48 | class DahoasRmstaticDataset(PromptRawDataset): 49 | 50 | def __init__(self, output_path, seed, local_rank): 51 | super().__init__(output_path, seed, local_rank) 52 | self.dataset_name = "Dahoas/rm-static" 53 | self.dataset_name_clean = "Dahoas_rm_static" 54 | self.raw_datasets = load_dataset("Dahoas/rm-static") 55 | 56 | def get_train_data(self): 57 | return self.raw_datasets["train"] 58 | 59 | def get_eval_data(self): 60 | return self.raw_datasets["test"] 61 | 62 | def get_prompt(self, sample): 63 | return sample['prompt'] 64 | 65 | def get_chosen(self, sample): 66 | return sample['chosen'] 67 | 68 | def get_rejected(self, sample): 69 | return sample['rejected'] 70 | 71 | def get_prompt_and_chosen(self, sample): 72 | return sample['prompt'] + sample['chosen'] 73 | 74 | def get_prompt_and_rejected(self, sample): 75 | return sample['prompt'] + sample['rejected'] 76 | 77 | 78 | # Belleschool_math0.25K 79 | class BelleOpenSoucreDataset(PromptRawDataset): 80 | 81 | def __init__(self, output_path, seed, local_rank, data_file): 82 | eval_data_file = "utils/data/dev5K.json" 83 | super().__init__(output_path, seed, local_rank) 84 | self.dataset_name = "BelleOpenSoucre" 85 | self.dataset_name_clean = "BelleOpenSoucre" 86 | print("data_file = ", data_file) 87 | self.raw_datasets = load_dataset("json", data_files=data_file) 88 | self.dev_raw_datasets = load_dataset("json", data_files=eval_data_file) 89 | print(self.raw_datasets["train"]) 90 | 91 | def get_train_data(self): 92 | return self.raw_datasets["train"] 93 | 94 | def get_eval_data(self): 95 | return self.dev_raw_datasets["train"] 96 | 97 | def get_prompt(self, sample): 98 | return "Human: " + sample['instruction'] + sample['input'] + "\n Assistant: " 99 | 100 | def get_chosen(self, sample): 101 | return "Human: " + sample['instruction'] + sample['input'] + "\n Assistant: " 102 | 103 | def get_prompt_and_chosen(self, sample): 104 | return "Human: " + sample['instruction'] + sample['input'] + "\n Assistant: " + sample['output'] 105 | 106 | 107 | # English dataset 108 | class DahoasFullhhrlhfDataset(PromptRawDataset): 109 | 110 | def __init__(self, output_path, seed, local_rank): 111 | super().__init__(output_path, seed, local_rank) 112 | self.dataset_name = "Dahoas/full-hh-rlhf" 113 | self.dataset_name_clean = "Dahoas_full_hh_rlhf" 114 | self.raw_datasets = load_dataset("Dahoas/full-hh-rlhf") 115 | 116 | def get_train_data(self): 117 | return self.raw_datasets["train"] 118 | 119 | def get_eval_data(self): 120 | return self.raw_datasets["test"] 121 | 122 | def get_prompt(self, sample): 123 | return sample['prompt'] 124 | 125 | def get_chosen(self, sample): 126 | return sample['chosen'] 127 | 128 | def get_rejected(self, sample): 129 | return sample['rejected'] 130 | 131 | def get_prompt_and_chosen(self, sample): 132 | return sample['prompt'] + sample['chosen'] 133 | 134 | def get_prompt_and_rejected(self, sample): 135 | return sample['prompt'] + sample['rejected'] 136 | 137 | 138 | # English dataset 139 | class DahoasSyntheticinstructgptjpairwiseDataset(PromptRawDataset): 140 | 141 | def __init__(self, output_path, seed, local_rank): 142 | super().__init__(output_path, seed, local_rank) 143 | self.dataset_name = "Dahoas/synthetic-instruct-gptj-pairwise" 144 | self.dataset_name_clean = "Dahoas_synthetic_instruct_gptj_pairwise" 145 | self.raw_datasets = load_dataset( 146 | "Dahoas/synthetic-instruct-gptj-pairwise") 147 | 148 | def get_train_data(self): 149 | from .data_utils import 
get_raw_dataset_split_index 150 | dataset = self.raw_datasets["train"] 151 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 152 | self.dataset_name_clean, 153 | self.seed, "train_eval", "9,1", 0, 154 | len(dataset)) 155 | dataset = Subset(dataset, index) 156 | return dataset 157 | 158 | def get_eval_data(self): 159 | from .data_utils import get_raw_dataset_split_index 160 | dataset = self.raw_datasets["train"] 161 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 162 | self.dataset_name_clean, 163 | self.seed, "train_eval", "9,1", 1, 164 | len(dataset)) 165 | dataset = Subset(dataset, index) 166 | return dataset 167 | 168 | def get_prompt(self, sample): 169 | return " Human: " + sample['prompt'] + " Assistant:" 170 | 171 | def get_chosen(self, sample): 172 | return " " + sample['chosen'] 173 | 174 | def get_rejected(self, sample): 175 | return " " + sample['rejected'] 176 | 177 | def get_prompt_and_chosen(self, sample): 178 | return " Human: " + sample['prompt'] + " Assistant: " + sample['chosen'] 179 | 180 | def get_prompt_and_rejected(self, sample): 181 | return " Human: " + sample['prompt'] + " Assistant: " + sample[ 182 | 'rejected'] 183 | 184 | 185 | # English dataset 186 | class YitingxieRlhfrewarddatasetsDataset(PromptRawDataset): 187 | 188 | def __init__(self, output_path, seed, local_rank): 189 | super().__init__(output_path, seed, local_rank) 190 | self.dataset_name = "yitingxie/rlhf-reward-datasets" 191 | self.dataset_name_clean = "yitingxie_rlhf_reward_datasets" 192 | self.raw_datasets = load_dataset("yitingxie/rlhf-reward-datasets") 193 | 194 | def get_train_data(self): 195 | return self.raw_datasets["train"] 196 | 197 | def get_eval_data(self): 198 | return self.raw_datasets["test"] 199 | 200 | def get_prompt(self, sample): 201 | return sample['prompt'] + "Assistant:" 202 | 203 | def get_chosen(self, sample): 204 | return sample['chosen'].split("Assistant:")[-1] 205 | 206 | def get_rejected(self, sample): 207 | return sample['rejected'].split("Assistant:")[-1] 208 | 209 | def get_prompt_and_chosen(self, sample): 210 | return sample['prompt'] + sample['chosen'] 211 | 212 | def get_prompt_and_rejected(self, sample): 213 | return sample['prompt'] + sample['rejected'] 214 | 215 | 216 | # English dataset 217 | class OpenaiWebgptcomparisonsDataset(PromptRawDataset): 218 | 219 | def __init__(self, output_path, seed, local_rank): 220 | super().__init__(output_path, seed, local_rank) 221 | self.dataset_name = "openai/webgpt_comparisons" 222 | self.dataset_name_clean = "openai_webgpt_comparisons" 223 | self.raw_datasets = load_dataset("openai/webgpt_comparisons") 224 | 225 | def get_train_data(self): 226 | from .data_utils import get_raw_dataset_split_index 227 | dataset = self.raw_datasets["train"] 228 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 229 | self.dataset_name_clean, 230 | self.seed, "train_eval", "9,1", 0, 231 | len(dataset)) 232 | dataset = Subset(dataset, index) 233 | return dataset 234 | 235 | def get_eval_data(self): 236 | from .data_utils import get_raw_dataset_split_index 237 | dataset = self.raw_datasets["train"] 238 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 239 | self.dataset_name_clean, 240 | self.seed, "train_eval", "9,1", 1, 241 | len(dataset)) 242 | dataset = Subset(dataset, index) 243 | return dataset 244 | 245 | def get_prompt(self, sample): 246 | return " Human: " + sample['question']['full_text'] + " Assistant:" 247 | 248 | def get_chosen(self, sample): 
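# webgpt_comparisons stores two scored answers per question; the one with
# the higher score is treated as "chosen" and the other as "rejected".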
249 | if float(sample['score_0']) >= float(sample['score_1']): 250 | response = sample['answer_0'] 251 | else: 252 | response = sample['answer_1'] 253 | # This data has citation square brackets and numbers (e.g., "[1]"). 254 | # Right now we are not doing browser-assisted finetuning, thus we 255 | # remove these citations to avoid confusing the model. 256 | response = re.sub(r" [\(\[].*?[\)\]]", "", response) 257 | response = re.sub(r"[\(\[].*?[\)\]]", "", response) 258 | return " " + response 259 | 260 | def get_rejected(self, sample): 261 | if float(sample['score_0']) < float(sample['score_1']): 262 | response = sample['answer_0'] 263 | else: 264 | response = sample['answer_1'] 265 | response = re.sub(r" [\(\[].*?[\)\]]", "", response) 266 | response = re.sub(r"[\(\[].*?[\)\]]", "", response) 267 | return " " + response 268 | 269 | def get_prompt_and_chosen(self, sample): 270 | if float(sample['score_0']) >= float(sample['score_1']): 271 | response = sample['answer_0'] 272 | else: 273 | response = sample['answer_1'] 274 | response = re.sub(r" [\(\[].*?[\)\]]", "", response) 275 | response = re.sub(r"[\(\[].*?[\)\]]", "", response) 276 | return " Human: " + sample['question'][ 277 | 'full_text'] + " Assistant: " + response 278 | 279 | def get_prompt_and_rejected(self, sample): 280 | if float(sample['score_0']) < float(sample['score_1']): 281 | response = sample['answer_0'] 282 | else: 283 | response = sample['answer_1'] 284 | response = re.sub(r" [\(\[].*?[\)\]]", "", response) 285 | response = re.sub(r"[\(\[].*?[\)\]]", "", response) 286 | return " Human: " + sample['question'][ 287 | 'full_text'] + " Assistant: " + response 288 | 289 | 290 | # English dataset 291 | class StanfordnlpSHPDataset(PromptRawDataset): 292 | 293 | def __init__(self, output_path, seed, local_rank): 294 | super().__init__(output_path, seed, local_rank) 295 | self.dataset_name = "stanfordnlp/SHP" 296 | self.dataset_name_clean = "stanfordnlp_SHP" 297 | self.raw_datasets = load_dataset("stanfordnlp/SHP") 298 | 299 | def get_train_data(self): 300 | return self.raw_datasets["train"] 301 | 302 | def get_eval_data(self): 303 | return self.raw_datasets["validation"] 304 | 305 | def get_prompt(self, sample): 306 | return " Human: " + sample['history'] + " Assistant:" 307 | 308 | def get_chosen(self, sample): 309 | if int(sample["labels"]) == 1: 310 | response = sample["human_ref_A"] 311 | else: 312 | response = sample["human_ref_B"] 313 | return " " + response 314 | 315 | def get_rejected(self, sample): 316 | if int(sample["labels"]) == 1: 317 | response = sample["human_ref_B"] 318 | else: 319 | response = sample["human_ref_A"] 320 | return " " + response 321 | 322 | def get_prompt_and_chosen(self, sample): 323 | if int(sample["labels"]) == 1: 324 | response = sample["human_ref_A"] 325 | else: 326 | response = sample["human_ref_B"] 327 | return " Human: " + sample['history'] + " Assistant: " + response 328 | 329 | def get_prompt_and_rejected(self, sample): 330 | if int(sample["labels"]) == 1: 331 | response = sample["human_ref_B"] 332 | else: 333 | response = sample["human_ref_A"] 334 | return " Human: " + sample['history'] + " Assistant: " + response 335 | 336 | 337 | # Chinese dataset 338 | class Wangrui6ZhihuKOLDataset(PromptRawDataset): 339 | 340 | def __init__(self, output_path, seed, local_rank): 341 | super().__init__(output_path, seed, local_rank) 342 | self.dataset_name = "wangrui6/Zhihu-KOL" 343 | self.dataset_name_clean = "wangrui6_Zhihu_KOL" 344 | self.raw_datasets = load_dataset("wangrui6/Zhihu-KOL") 345 | 
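# Zhihu-KOL ships only a "train" split, so get_raw_dataset_split_index below
# carves a deterministic 90/10 ("9,1") train/eval split out of it, cached as
# an .npy index file under output_path.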
346 | def get_train_data(self): 347 | from .data_utils import get_raw_dataset_split_index 348 | dataset = self.raw_datasets["train"] 349 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 350 | self.dataset_name_clean, 351 | self.seed, "train_eval", "9,1", 0, 352 | len(dataset)) 353 | dataset = Subset(dataset, index) 354 | return dataset 355 | 356 | def get_eval_data(self): 357 | from .data_utils import get_raw_dataset_split_index 358 | dataset = self.raw_datasets["train"] 359 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 360 | self.dataset_name_clean, 361 | self.seed, "train_eval", "9,1", 1, 362 | len(dataset)) 363 | dataset = Subset(dataset, index) 364 | return dataset 365 | 366 | def get_prompt(self, sample): 367 | if sample['INSTRUCTION'] is not None: 368 | return " Human: " + sample['INSTRUCTION'] + " Assistant:" 369 | return None 370 | 371 | def get_chosen(self, sample): 372 | if sample['RESPONSE'] is not None: 373 | return " " + sample['RESPONSE'] 374 | return None 375 | 376 | def get_rejected(self, sample): 377 | print( 378 | f"Warning: dataset {self.dataset_name} does not include rejected response." 379 | ) 380 | return None 381 | 382 | def get_prompt_and_chosen(self, sample): 383 | if sample['INSTRUCTION'] is not None and sample['RESPONSE'] is not None: 384 | return " Human: " + sample[ 385 | 'INSTRUCTION'] + " Assistant: " + sample['RESPONSE'] 386 | return None 387 | 388 | def get_prompt_and_rejected(self, sample): 389 | print( 390 | f"Warning: dataset {self.dataset_name} does not include rejected response." 391 | ) 392 | return None 393 | 394 | 395 | # Chinese dataset 396 | class CohereMiraclzhqueries2212Dataset(PromptRawDataset): 397 | 398 | def __init__(self, output_path, seed, local_rank): 399 | super().__init__(output_path, seed, local_rank) 400 | self.dataset_name = "Cohere/miracl-zh-queries-22-12" 401 | self.dataset_name_clean = "Cohere_miracl_zh_queries_22_12" 402 | self.raw_datasets = load_dataset("Cohere/miracl-zh-queries-22-12") 403 | 404 | def get_train_data(self): 405 | return self.raw_datasets["train"] 406 | 407 | def get_eval_data(self): 408 | return self.raw_datasets["dev"] 409 | 410 | def get_prompt(self, sample): 411 | return " Human: " + sample['query'] + " Assistant:" 412 | 413 | def get_chosen(self, sample): 414 | return " " + sample['positive_passages'][0]['text'] 415 | 416 | def get_rejected(self, sample): 417 | return " " + sample['negative_passages'][0]['text'] 418 | 419 | def get_prompt_and_chosen(self, sample): 420 | return " Human: " + sample['query'] + " Assistant: " + sample[ 421 | 'positive_passages'][0]['text'] 422 | 423 | def get_prompt_and_rejected(self, sample): 424 | return " Human: " + sample['query'] + " Assistant: " + sample[ 425 | 'negative_passages'][0]['text'] 426 | 427 | 428 | # Chinese dataset 429 | class HelloSimpleAIHC3ChineseDataset(PromptRawDataset): 430 | 431 | def __init__(self, output_path, seed, local_rank): 432 | super().__init__(output_path, seed, local_rank) 433 | self.dataset_name = "Hello-SimpleAI/HC3-Chinese" 434 | self.dataset_name_clean = "Hello_SimpleAI_HC3_Chinese" 435 | self.raw_datasets = load_dataset("Hello-SimpleAI/HC3-Chinese", "all") 436 | 437 | def get_train_data(self): 438 | from .data_utils import get_raw_dataset_split_index 439 | dataset = self.raw_datasets["train"] 440 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 441 | self.dataset_name_clean, 442 | self.seed, "train_eval", "9,1", 0, 443 | len(dataset)) 444 | dataset = 
Subset(dataset, index) 445 | return dataset 446 | 447 | def get_eval_data(self): 448 | from .data_utils import get_raw_dataset_split_index 449 | dataset = self.raw_datasets["train"] 450 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 451 | self.dataset_name_clean, 452 | self.seed, "train_eval", "9,1", 1, 453 | len(dataset)) 454 | dataset = Subset(dataset, index) 455 | return dataset 456 | 457 | def get_prompt(self, sample): 458 | if sample['question'] is not None: 459 | return " Human: " + sample['question'] + " Assistant:" 460 | return None 461 | 462 | def get_chosen(self, sample): 463 | if sample['human_answers'][0] is not None: 464 | return " " + sample['human_answers'][0] 465 | return None 466 | 467 | def get_rejected(self, sample): 468 | print( 469 | f"Warning: dataset {self.dataset_name} does not include rejected response." 470 | ) 471 | return None 472 | 473 | def get_prompt_and_chosen(self, sample): 474 | if sample['question'] is not None and sample['human_answers'][ 475 | 0] is not None: 476 | return " Human: " + sample['question'] + " Assistant: " + sample[ 477 | 'human_answers'][0] 478 | return None 479 | 480 | def get_prompt_and_rejected(self, sample): 481 | print( 482 | f"Warning: dataset {self.dataset_name} does not include rejected response." 483 | ) 484 | return None 485 | 486 | 487 | # Chinese dataset 488 | class MkqaChineseDataset(PromptRawDataset): 489 | 490 | def __init__(self, output_path, seed, local_rank): 491 | super().__init__(output_path, seed, local_rank) 492 | self.dataset_name = "mkqa-Chinese" 493 | self.dataset_name_clean = "mkqa" 494 | self.raw_datasets = load_dataset("mkqa") 495 | 496 | def get_train_data(self): 497 | from .data_utils import get_raw_dataset_split_index 498 | dataset = self.raw_datasets["train"] 499 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 500 | self.dataset_name_clean, 501 | self.seed, "train_eval", "9,1", 0, 502 | len(dataset)) 503 | dataset = Subset(dataset, index) 504 | return dataset 505 | 506 | def get_eval_data(self): 507 | from .data_utils import get_raw_dataset_split_index 508 | dataset = self.raw_datasets["train"] 509 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 510 | self.dataset_name_clean, 511 | self.seed, "train_eval", "9,1", 1, 512 | len(dataset)) 513 | dataset = Subset(dataset, index) 514 | return dataset 515 | 516 | def get_prompt(self, sample): 517 | if sample['queries']['zh_cn'] is not None: 518 | return " Human: " + sample['queries']['zh_cn'] + " Assistant:" 519 | return None 520 | 521 | def get_chosen(self, sample): 522 | if sample['answers']['zh_cn'][0]['text'] is not None: 523 | return " " + sample['answers']['zh_cn'][0]['text'] 524 | return None 525 | 526 | def get_rejected(self, sample): 527 | print( 528 | f"Warning: dataset {self.dataset_name} does not include rejected response." 529 | ) 530 | return None 531 | 532 | def get_prompt_and_chosen(self, sample): 533 | if sample['queries']['zh_cn'] is not None and sample['answers'][ 534 | 'zh_cn'][0]['text'] is not None: 535 | return " Human: " + sample['queries'][ 536 | 'zh_cn'] + " Assistant: " + sample['answers']['zh_cn'][0][ 537 | 'text'] 538 | return None 539 | 540 | def get_prompt_and_rejected(self, sample): 541 | print( 542 | f"Warning: dataset {self.dataset_name} does not include rejected response." 
543 | ) 544 | return None 545 | 546 | 547 | # Japanese dataset 548 | class MkqaJapaneseDataset(PromptRawDataset): 549 | 550 | def __init__(self, output_path, seed, local_rank): 551 | super().__init__(output_path, seed, local_rank) 552 | self.dataset_name = "mkqa-Japanese" 553 | self.dataset_name_clean = "mkqa" 554 | self.raw_datasets = load_dataset("mkqa") 555 | 556 | def get_train_data(self): 557 | from .data_utils import get_raw_dataset_split_index 558 | dataset = self.raw_datasets["train"] 559 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 560 | self.dataset_name_clean, 561 | self.seed, "train_eval", "9,1", 0, 562 | len(dataset)) 563 | dataset = Subset(dataset, index) 564 | return dataset 565 | 566 | def get_eval_data(self): 567 | from .data_utils import get_raw_dataset_split_index 568 | dataset = self.raw_datasets["train"] 569 | index = get_raw_dataset_split_index(self.local_rank, self.output_path, 570 | self.dataset_name_clean, 571 | self.seed, "train_eval", "9,1", 1, 572 | len(dataset)) 573 | dataset = Subset(dataset, index) 574 | return dataset 575 | 576 | def get_prompt(self, sample): 577 | if sample['queries']['ja'] is not None: 578 | return " Human: " + sample['queries']['ja'] + " Assistant:" 579 | return None 580 | 581 | def get_chosen(self, sample): 582 | if sample['answers']['ja'][0]['text'] is not None: 583 | return " " + sample['answers']['ja'][0]['text'] 584 | return None 585 | 586 | def get_rejected(self, sample): 587 | print( 588 | f"Warning: dataset {self.dataset_name} does not include rejected response." 589 | ) 590 | return None 591 | 592 | def get_prompt_and_chosen(self, sample): 593 | if sample['queries']['ja'] is not None and sample['answers']['ja'][0][ 594 | 'text'] is not None: 595 | return " Human: " + sample['queries'][ 596 | 'ja'] + " Assistant: " + sample['answers']['ja'][0]['text'] 597 | return None 598 | 599 | def get_prompt_and_rejected(self, sample): 600 | print( 601 | f"Warning: dataset {self.dataset_name} does not include rejected response." 
602 | ) 603 | return None 604 | 605 | 606 | # Japanese dataset 607 | class CohereMiracljaqueries2212Dataset(PromptRawDataset): 608 | 609 | def __init__(self, output_path, seed, local_rank): 610 | super().__init__(output_path, seed, local_rank) 611 | self.dataset_name = "Cohere/miracl-ja-queries-22-12" 612 | self.dataset_name_clean = "Cohere_miracl_ja_queries_22_12" 613 | self.raw_datasets = load_dataset("Cohere/miracl-ja-queries-22-12") 614 | 615 | def get_train_data(self): 616 | return self.raw_datasets["train"] 617 | 618 | def get_eval_data(self): 619 | return self.raw_datasets["dev"] 620 | 621 | def get_prompt(self, sample): 622 | return " Human: " + sample['query'] + " Assistant:" 623 | 624 | def get_chosen(self, sample): 625 | return " " + sample['positive_passages'][0]['text'] 626 | 627 | def get_rejected(self, sample): 628 | return " " + sample['negative_passages'][0]['text'] 629 | 630 | def get_prompt_and_chosen(self, sample): 631 | return " Human: " + sample['query'] + " Assistant: " + sample[ 632 | 'positive_passages'][0]['text'] 633 | 634 | def get_prompt_and_rejected(self, sample): 635 | return " Human: " + sample['query'] + " Assistant: " + sample[ 636 | 'negative_passages'][0]['text'] 637 | 638 | 639 | # Japanese dataset 640 | class LmqgQgjaquadDataset(PromptRawDataset): 641 | 642 | def __init__(self, output_path, seed, local_rank): 643 | super().__init__(output_path, seed, local_rank) 644 | self.dataset_name = "lmqg/qg_jaquad" 645 | self.dataset_name_clean = "lmqg_qg_jaquad" 646 | self.raw_datasets = load_dataset("lmqg/qg_jaquad") 647 | 648 | def get_train_data(self): 649 | return self.raw_datasets["train"] 650 | 651 | def get_eval_data(self): 652 | return self.raw_datasets["validation"] 653 | 654 | def get_prompt(self, sample): 655 | return " Human: " + sample['question'] + " Assistant:" 656 | 657 | def get_chosen(self, sample): 658 | return " " + sample['sentence'] 659 | 660 | def get_rejected(self, sample): 661 | print( 662 | f"Warning: dataset {self.dataset_name} does not include rejected response." 663 | ) 664 | return None 665 | 666 | def get_prompt_and_chosen(self, sample): 667 | return " Human: " + sample['question'] + " Assistant: " + sample[ 668 | 'sentence'] 669 | 670 | def get_prompt_and_rejected(self, sample): 671 | print( 672 | f"Warning: dataset {self.dataset_name} does not include rejected response." 673 | ) 674 | return None 675 | 676 | 677 | # Japanese dataset 678 | class LmqgQagjaquadDataset(PromptRawDataset): 679 | 680 | def __init__(self, output_path, seed, local_rank): 681 | super().__init__(output_path, seed, local_rank) 682 | self.dataset_name = "lmqg/qag_jaquad" 683 | self.dataset_name_clean = "lmqg_qag_jaquad" 684 | self.raw_datasets = load_dataset("lmqg/qag_jaquad") 685 | 686 | def get_train_data(self): 687 | return self.raw_datasets["train"] 688 | 689 | def get_eval_data(self): 690 | return self.raw_datasets["validation"] 691 | 692 | def get_prompt(self, sample): 693 | return " Human: " + sample['questions'][0] + " Assistant:" 694 | 695 | def get_chosen(self, sample): 696 | return " " + sample['paragraph'] 697 | 698 | def get_rejected(self, sample): 699 | print( 700 | f"Warning: dataset {self.dataset_name} does not include rejected response." 
701 | ) 702 | return None 703 | 704 | def get_prompt_and_chosen(self, sample): 705 | return " Human: " + sample['questions'][0] + " Assistant: " + sample[ 706 | 'paragraph'] 707 | 708 | def get_prompt_and_rejected(self, sample): 709 | print( 710 | f"Warning: dataset {self.dataset_name} does not include rejected response." 711 | ) 712 | return None 713 | 714 | 715 | class BNNTDataset(Dataset): 716 | 717 | def __init__(self, data, tokenizer, args): 718 | self.data = data 719 | self.tokenizer = tokenizer 720 | self.args = args 721 | 722 | def __getitem__(self, idx): 723 | tokenizer = self.tokenizer 724 | args = self.args 725 | tmp_data = self.data[idx] 726 | chosen_sentence = f"标题:{tmp_data['title']}。摘要:{tmp_data['summary']}专利公开号:{tmp_data['publicNo']}。权利要求:{tmp_data['powerRequirements']}说明书:{tmp_data['instructions']}" 727 | chosen_token = tokenizer(chosen_sentence, 728 | max_length=args.max_seq_len, 729 | padding="max_length", 730 | truncation=True) 731 | 732 | chosen_token["input_ids"][-1] = tokenizer.eos_token_id 733 | 734 | chosen_token["labels"] = torch.LongTensor( 735 | [-100] + [-100 if tokenizer.pad_token_id == j else j for j in chosen_token["input_ids"]][1:]) 736 | 737 | chosen_token["input_ids"] = torch.LongTensor(chosen_token["input_ids"]).squeeze(0) 738 | chosen_token["attention_mask"] = torch.LongTensor(chosen_token["attention_mask"]).squeeze(0) 739 | chosen_token["labels"] = torch.LongTensor(chosen_token["labels"]).squeeze(0) 740 | return chosen_token 741 | 742 | def __len__(self): 743 | return len(self.data) 744 | -------------------------------------------------------------------------------- /src/bnnt/ds_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
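# Note: DeepSpeed checks that train_batch_size equals
# train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size,
# and with both constants below fixed it infers gradient_accumulation_steps
# = 32 / (4 * world_size); launches whose GPU count does not divide 8 need
# these constants adjusted.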
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | GLOBAL_BATCH_SIZE = 32 6 | MICRO_BATCH_SIZE = 4 7 | 8 | 9 | def get_train_ds_config(offload, 10 | stage=2, 11 | enable_hybrid_engine=False, 12 | inference_tp_size=1, 13 | release_inference_cache=False, 14 | pin_parameters=True, 15 | tp_gather_partition_size=8): 16 | 17 | device = "cpu" if offload else "none" 18 | zero_opt_dict = { 19 | "stage": stage, 20 | "offload_param": { 21 | "device": device 22 | }, 23 | "offload_optimizer": { 24 | "device": device 25 | }, 26 | "stage3_param_persistence_threshold": 1e4, 27 | "stage3_max_live_parameters": 3e7, 28 | "stage3_prefetch_bucket_size": 3e7, 29 | "memory_efficient_linear": False 30 | } 31 | return { 32 | "train_batch_size": GLOBAL_BATCH_SIZE, 33 | "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, 34 | "steps_per_print": 10, 35 | "zero_optimization": zero_opt_dict, 36 | "fp16": { 37 | "enabled": True, 38 | "loss_scale_window": 100 39 | }, 40 | "gradient_clipping": 1.0, 41 | "prescale_gradients": False, 42 | "wall_clock_breakdown": False, 43 | "hybrid_engine": { 44 | "enabled": enable_hybrid_engine, 45 | "inference_tp_size": inference_tp_size, 46 | "release_inference_cache": release_inference_cache, 47 | "pin_parameters": pin_parameters, 48 | "tp_gather_partition_size": tp_gather_partition_size, 49 | } 50 | } 51 | 52 | 53 | def get_eval_ds_config(offload, stage=0): 54 | device = "cpu" if offload else "none" 55 | zero_opt_dict = { 56 | "stage": stage, 57 | "stage3_param_persistence_threshold": 1e4, 58 | "offload_param": { 59 | "device": device 60 | }, 61 | "memory_efficient_linear": False 62 | } 63 | return { 64 | "train_batch_size": GLOBAL_BATCH_SIZE, 65 | "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, 66 | "steps_per_print": 10, 67 | "zero_optimization": zero_opt_dict, 68 | "fp16": { 69 | "enabled": True 70 | }, 71 | "gradient_clipping": 1.0, 72 | "prescale_gradients": False, 73 | "wall_clock_breakdown": False 74 | } 75 | -------------------------------------------------------------------------------- /src/bnnt/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-for-Science/MoZi/517cb2ed7b7cab0d23803c964ae671f6df405e8b/src/bnnt/model/__init__.py -------------------------------------------------------------------------------- /src/bnnt/model/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
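# Note: for ZeRO stage 3, the HfDeepSpeedConfig created in create_hf_model
# below must exist before from_pretrained is called and stay alive while it
# runs; transformers detects the live config object and loads checkpoint
# weights directly into partitioned parameters instead of materializing the
# full model on every rank. Also note that create_hf_model sets pad_token_id
# equal to eos_token_id, so any downstream label masking that drops pad
# positions will mask genuine EOS tokens as well.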
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | import os 6 | import math 7 | import torch 8 | from transformers import ( 9 | AutoConfig, 10 | AutoModel, 11 | ) 12 | 13 | from transformers.deepspeed import HfDeepSpeedConfig 14 | 15 | # from .reward_model import RewardModel 16 | 17 | 18 | def create_hf_model(model_class, 19 | model_name_or_path, 20 | tokenizer, 21 | ds_config=None, 22 | rlhf_training=False): 23 | model_config = AutoConfig.from_pretrained(model_name_or_path,trust_remote_code=True) 24 | model_config.dropout = 0.0 25 | # Note: dschf is defined in function scope to avoid global effects 26 | # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration 27 | if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: 28 | dschf = HfDeepSpeedConfig(ds_config) 29 | else: 30 | dschf = None 31 | if rlhf_training: 32 | # the weight loading is handled by create critic model 33 | model = model_class.from_config(model_config,trust_remote_code=True) 34 | else: 35 | model = AutoModel.from_pretrained( 36 | model_name_or_path, 37 | from_tf=bool(".ckpt" in model_name_or_path), 38 | config=model_config,trust_remote_code=True) 39 | 40 | model.config.end_token_id = tokenizer.eos_token_id 41 | model.config.pad_token_id = model.config.eos_token_id 42 | model.resize_token_embeddings(int( 43 | 8 * 44 | math.ceil(len(tokenizer) / 8.0))) # make the vocab size multiple of 8 45 | 46 | return model 47 | 48 | 49 | # def create_critic_model(model_name_or_path, 50 | # tokenizer, 51 | # ds_config, 52 | # num_padding_at_beginning=0, 53 | # rlhf_training=False): 54 | # # OPT model family always put a padding token at the beginning of the sequence, 55 | # # we did not see this in other models but not sure if it is a general rule 56 | # critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer, 57 | # ds_config, rlhf_training) 58 | # critic_model = RewardModel( 59 | # critic_model, 60 | # tokenizer, 61 | # num_padding_at_beginning=num_padding_at_beginning) 62 | 63 | # if rlhf_training: 64 | # # critic model needs to load the weight here 65 | # model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin') 66 | # assert os.path.exists( 67 | # model_ckpt_path 68 | # ), f"Cannot find model checkpoint at {model_ckpt_path}" 69 | # critic_model.load_state_dict( 70 | # torch.load(model_ckpt_path, map_location='cpu')) 71 | 72 | # return critic_model 73 | -------------------------------------------------------------------------------- /src/bnnt/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-for-Science/MoZi/517cb2ed7b7cab0d23803c964ae671f6df405e8b/src/bnnt/module/__init__.py -------------------------------------------------------------------------------- /src/bnnt/module/lora.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
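# LoRA background for the module below: each wrapped linear layer keeps its
# frozen weight W and learns a low-rank update, computing
#   y = x @ W.T + bias + (dropout(x) @ A @ B) * (lora_scaling / lora_dim)
# where A = lora_right_weight (in_features x lora_dim) is Kaiming-initialized
# and B = lora_left_weight (lora_dim x out_features) starts at zero, so
# training begins exactly at the frozen baseline; fuse_lora_weight() folds
# the product back into W so fused inference needs no extra matmuls.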
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | import math 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from deepspeed.compression.helper import recursive_getattr, recursive_setattr 10 | import deepspeed 11 | 12 | 13 | class LinearLayer_LoRA(nn.Module): 14 | # a simple implementation of LoRA 15 | # for now it only supports Linear layers 16 | def __init__(self, 17 | weight, 18 | lora_dim=0, 19 | lora_scaling=1, 20 | lora_droppout=0, 21 | bias=None): 22 | super(LinearLayer_LoRA, self).__init__() 23 | self.weight = weight 24 | self.bias = bias 25 | 26 | if lora_dim <= 0: 27 | raise ValueError( 28 | "You are trying to use LoRA, whose reduced dim must be a positive integer" 29 | ) 30 | 31 | try: 32 | # for zero stage 3 33 | rows, columns = weight.ds_shape 34 | except AttributeError: 35 | rows, columns = weight.shape 36 | self.lora_right_weight = nn.Parameter(torch.zeros( 37 | columns, 38 | lora_dim)) # stored transposed so we do not need to transpose in forward 39 | self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows)) 40 | self.lora_scaling = lora_scaling / lora_dim 41 | 42 | if lora_droppout > 0: 43 | self.lora_dropout = nn.Dropout(lora_droppout) 44 | else: 45 | self.lora_dropout = nn.Identity() 46 | 47 | self.reset_parameters() 48 | # disable gradients for the original weight 49 | self.weight.requires_grad = False 50 | # whether LoRA is currently fused into the original weight 51 | self.fuse_lora = False 52 | 53 | def eval(self): 54 | self.lora_dropout.eval() 55 | 56 | # self.fuse_lora_weight() 57 | 58 | def train(self, mode=True): 59 | self.lora_dropout.train(mode) 60 | # self.unfuse_lora_weight() 61 | 62 | def reset_parameters(self): 63 | nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5)) 64 | nn.init.zeros_(self.lora_left_weight) 65 | 66 | def fuse_lora_weight(self): 67 | if not self.fuse_lora: 68 | self.weight.data += self.lora_scaling * torch.matmul( 69 | self.lora_left_weight.t(), self.lora_right_weight.t()) 70 | self.fuse_lora = True 71 | 72 | def unfuse_lora_weight(self): 73 | if self.fuse_lora: 74 | self.weight.data -= self.lora_scaling * torch.matmul( 75 | self.lora_left_weight.t(), self.lora_right_weight.t()) 76 | self.fuse_lora = False 77 | 78 | def forward(self, input): 79 | if self.fuse_lora: 80 | return F.linear(input, self.weight, self.bias) 81 | else: 82 | return F.linear( 83 | input, self.weight, 84 | self.bias) + (self.lora_dropout(input) @ self.lora_right_weight 85 | @ self.lora_left_weight) * self.lora_scaling 86 | 87 | 88 | # convert the linear layer to LoRA 89 | def convert_linear_layer_to_lora(model, 90 | part_module_name, 91 | lora_dim=0, 92 | lora_scaling=1, 93 | lora_droppout=0): 94 | def set_params(param_modules, name): 95 | for each_ in param_modules: 96 | if each_ in name: 97 | return True 98 | return False 99 | 100 | if isinstance(part_module_name, str): 101 | part_module_name = [part_module_name] 102 | 103 | repalce_name = [] 104 | for name, module in model.named_modules(): 105 | if isinstance(module, nn.Linear) and set_params(part_module_name, name): 106 | repalce_name.append(name) 107 | 108 | print("repalce_name : ", repalce_name) 109 | for name in repalce_name: 110 | module = recursive_getattr(model, name) 111 | tmp = LinearLayer_LoRA( 112 | module.weight, lora_dim, lora_scaling, lora_droppout, 113 | module.bias).to(module.weight.device).to(module.weight.dtype) 114 | recursive_setattr(model, name, tmp) 115 | return model 116 | 117 | 118 | def convert_LLaMA_to_lora(model, 119 | part_module_name, 120 | lora_dim=8, 121 | lora_scaling=16, 
122 | lora_droppout=0.05): 123 | ''' 124 | model.layers.29.mlp 125 | model.layers.29.mlp.gate_proj 126 | model.layers.29.mlp.down_proj 127 | model.layers.29.mlp.up_proj 128 | model.layers.29.mlp.act_fn 129 | model.layers.29.input_layernorm 130 | model.layers.29.post_attention_layernorm 131 | model.layers.30 132 | model.layers.30.self_attn 133 | model.layers.30.self_attn.q_proj 134 | model.layers.30.self_attn.k_proj 135 | model.layers.30.self_attn.v_proj 136 | model.layers.30.self_attn.o_proj 137 | model.layers.30.self_attn.rotary_emb 138 | ''' 139 | def set_params(param_modules, name): 140 | for each_ in param_modules: 141 | if each_ in name: 142 | return True 143 | return False 144 | 145 | part_module_name = [ 146 | "q_proj", 147 | "k_proj", 148 | "v_proj", 149 | "down_proj", 150 | "gate_proj", 151 | "up_proj" 152 | ] 153 | repalce_name = [] 154 | for name, module in model.named_modules(): 155 | if isinstance(module, nn.Linear) and set_params(part_module_name, name): 156 | repalce_name.append(name) 157 | 158 | print("repalce_name : ", repalce_name) 159 | for name in repalce_name: 160 | module = recursive_getattr(model, name) 161 | tmp = LinearLayer_LoRA( 162 | module.weight, lora_dim, lora_scaling, lora_droppout, 163 | module.bias).to(module.weight.device).to(module.weight.dtype) 164 | recursive_setattr(model, name, tmp) 165 | return model 166 | 167 | 168 | def _z3_params_to_fetch(param_list): 169 | return [ 170 | p for p in param_list 171 | if hasattr(p, 'ds_id') and p.ds_status == deepspeed.runtime.zero. 172 | partition_parameters.ZeroParamStatus.NOT_AVAILABLE 173 | ] 174 | 175 | 176 | # convert the LoRA layer to linear layer 177 | def convert_lora_to_linear_layer(model): 178 | repalce_name = [] 179 | for name, module in model.named_modules(): 180 | if isinstance(module, LinearLayer_LoRA): 181 | repalce_name.append(name) 182 | for name in repalce_name: 183 | module = recursive_getattr(model, name) 184 | zero_stage_3 = hasattr(module.weight, 'ds_id') 185 | with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([ 186 | module.weight, module.bias, module.lora_left_weight, 187 | module.lora_right_weight 188 | ]), 189 | modifier_rank=0, 190 | enabled=zero_stage_3): 191 | module.fuse_lora_weight() 192 | return model 193 | 194 | 195 | def only_optimize_lora_parameters(model): 196 | # turn off the gradient of all the parameters except the LoRA parameters 197 | for name, param in model.named_parameters(): 198 | if "lora_right_weight" in name or "lora_left_weight" in name: 199 | param.requires_grad = True 200 | else: 201 | param.requires_grad = False 202 | return model 203 | -------------------------------------------------------------------------------- /src/bnnt/module/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "ename": "ImportError", 10 | "evalue": "cannot import name 'LlamaForCausalLM' from 'transformers' (/home/jiyunjie001/anaconda3/lib/python3.9/site-packages/transformers/__init__.py)", 11 | "output_type": "error", 12 | "traceback": [ 13 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 14 | "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", 15 | "\u001b[0;32m/tmp/ipykernel_70206/1328397760.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtransformers\u001b[0m 
\u001b[0;32mimport\u001b[0m \u001b[0mLlamaForCausalLM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLlamaTokenizer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 16 | "\u001b[0;31mImportError\u001b[0m: cannot import name 'LlamaForCausalLM' from 'transformers' (/home/jiyunjie001/anaconda3/lib/python3.9/site-packages/transformers/__init__.py)" 17 | ] 18 | } 19 | ], 20 | "source": [ 21 | "from transformers import LlamaForCausalLM, LlamaTokenizer" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [] 30 | } 31 | ], 32 | "metadata": { 33 | "kernelspec": { 34 | "display_name": "base", 35 | "language": "python", 36 | "name": "python3" 37 | }, 38 | "language_info": { 39 | "codemirror_mode": { 40 | "name": "ipython", 41 | "version": 3 42 | }, 43 | "file_extension": ".py", 44 | "mimetype": "text/x-python", 45 | "name": "python", 46 | "nbconvert_exporter": "python", 47 | "pygments_lexer": "ipython3", 48 | "version": "3.9.13" 49 | }, 50 | "orig_nbformat": 4, 51 | "vscode": { 52 | "interpreter": { 53 | "hash": "20634b102e5ce792587e2ff69e667181c6bf9c98b3c846a57d51ac3559caad44" 54 | } 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 2 59 | } 60 | -------------------------------------------------------------------------------- /src/bnnt/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | import os 6 | import torch 7 | import random 8 | import numpy as np 9 | from transformers import set_seed 10 | import deepspeed 11 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 12 | 13 | 14 | def print_rank_0(msg, rank=0): 15 | if rank <= 0: 16 | print(msg) 17 | 18 | 19 | def to_device(batch, device): 20 | output = {} 21 | for k, v in batch.items(): 22 | try: 23 | output[k] = v.to(device) 24 | except: 25 | output[k] = v 26 | return output 27 | 28 | 29 | class MovingAverage: 30 | 31 | def __init__(self): 32 | self.count = 0 33 | self.total = 0 34 | self.mean = 0 35 | 36 | def update(self, num): 37 | self.total += num 38 | self.count += 1 39 | self.mean = self.total / self.count 40 | 41 | return self.mean 42 | 43 | 44 | def save_hf_format(model, tokenizer, args, sub_folder=""): 45 | # used to save huggingface format, so we can use it for hf.from_pretrained 46 | model_to_save = model.module if hasattr(model, 'module') else model 47 | CONFIG_NAME = "config.json" 48 | WEIGHTS_NAME = "pytorch_model.bin" 49 | output_dir = os.path.join(args.output_dir, sub_folder) 50 | if not os.path.exists(output_dir): 51 | os.makedirs(output_dir) 52 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 53 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 54 | save_dict = model_to_save.state_dict() 55 | for key in list(save_dict.keys()): 56 | if "lora" in key: 57 | del save_dict[key] 58 | torch.save(save_dict, output_model_file) 59 | model_to_save.config.to_json_file(output_config_file) 60 | tokenizer.save_vocabulary(output_dir) 61 | 62 | 63 | def set_random_seed(seed): 64 | if seed is not None: 65 | set_seed(seed) 66 | random.seed(seed) 67 | np.random.seed(seed) 68 | torch.manual_seed(seed) 69 | torch.cuda.manual_seed_all(seed) 70 | 71 | 72 | def get_all_reduce_mean(tensor): 73 | torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) 74 | tensor = tensor / torch.distributed.get_world_size() 75 | return tensor 76 | 77 | 78 
| def get_optimizer_grouped_parameters(model, 79 | weight_decay, 80 | no_decay_name_list=[ 81 | "bias", "LayerNorm.weight" 82 | ]): 83 | optimizer_grouped_parameters = [ 84 | { 85 | "params": [ 86 | p for n, p in model.named_parameters() 87 | if (not any(nd in n 88 | for nd in no_decay_name_list) and p.requires_grad) 89 | ], 90 | "weight_decay": 91 | weight_decay, 92 | }, 93 | { 94 | "params": [ 95 | p for n, p in model.named_parameters() 96 | if (any(nd in n 97 | for nd in no_decay_name_list) and p.requires_grad) 98 | ], 99 | "weight_decay": 100 | 0.0, 101 | }, 102 | ] 103 | return optimizer_grouped_parameters 104 | 105 | 106 | def _z3_params_to_fetch(param_list): 107 | return [ 108 | p for p in param_list 109 | if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE 110 | ] 111 | 112 | 113 | def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0): 114 | zero_stage_3 = (zero_stage == 3) 115 | with torch.no_grad(): 116 | for param, param_ema in zip(model.parameters(), 117 | model_ema.parameters()): 118 | # TODO: use prefiltering for efficiency 119 | params_to_fetch = _z3_params_to_fetch([param, param_ema 120 | ]) if zero_stage_3 else [] 121 | should_gather_param = len(params_to_fetch) > 0 122 | with deepspeed.zero.GatheredParameters( 123 | params_to_fetch, enabled=should_gather_param): 124 | data = param.data 125 | if device is not None: 126 | data = data.to(device) 127 | param_ema.data.copy_(torch.lerp(data, param_ema.data, beta)) 128 | 129 | 130 | def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0): 131 | zero_stage_3 = (zero_stage == 3) 132 | os.makedirs(save_dir, exist_ok=True) 133 | WEIGHTS_NAME = "pytorch_model.bin" 134 | output_model_file = os.path.join(save_dir, WEIGHTS_NAME) 135 | 136 | model_to_save = model_ema.module if hasattr(model_ema, 137 | 'module') else model_ema 138 | if not zero_stage_3: 139 | if global_rank == 0: 140 | torch.save(model_to_save.state_dict(), output_model_file) 141 | else: 142 | output_state_dict = {} 143 | for k, v in model_to_save.named_parameters(): 144 | 145 | if hasattr(v, 'ds_id'): 146 | with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([v 147 | ]), 148 | enabled=zero_stage_3): 149 | v_p = v.data.cpu() 150 | else: 151 | v_p = v.cpu() 152 | if global_rank == 0 and "lora" not in k: 153 | output_state_dict[k] = v_p 154 | if global_rank == 0: 155 | torch.save(output_state_dict, output_model_file) 156 | del output_state_dict 157 | --------------------------------------------------------------------------------