├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── csi300.pkl
│   ├── daily_stock
│   │   ├── 2022-11-01.csv
│   │   ├── 2022-11-02.csv
│   │   ├── 2022-11-03.csv
│   │   ├── 2022-11-04.csv
│   │   ├── 2022-11-07.csv
│   │   ├── 2022-11-08.csv
│   │   ├── 2022-11-09.csv
│   │   ├── 2022-11-10.csv
│   │   ├── 2022-11-11.csv
│   │   ├── 2022-11-14.csv
│   │   ├── 2022-11-15.csv
│   │   ├── 2022-11-16.csv
│   │   ├── 2022-11-17.csv
│   │   ├── 2022-11-18.csv
│   │   ├── 2022-11-21.csv
│   │   ├── 2022-11-22.csv
│   │   ├── 2022-11-23.csv
│   │   ├── 2022-11-24.csv
│   │   ├── 2022-11-25.csv
│   │   ├── 2022-11-28.csv
│   │   ├── 2022-11-29.csv
│   │   ├── 2022-11-30.csv
│   │   ├── 2022-12-01.csv
│   │   ├── 2022-12-02.csv
│   │   ├── 2022-12-05.csv
│   │   ├── 2022-12-06.csv
│   │   ├── 2022-12-07.csv
│   │   ├── 2022-12-08.csv
│   │   ├── 2022-12-09.csv
│   │   ├── 2022-12-12.csv
│   │   ├── 2022-12-13.csv
│   │   ├── 2022-12-14.csv
│   │   ├── 2022-12-15.csv
│   │   ├── 2022-12-16.csv
│   │   ├── 2022-12-19.csv
│   │   ├── 2022-12-20.csv
│   │   ├── 2022-12-21.csv
│   │   ├── 2022-12-22.csv
│   │   ├── 2022-12-23.csv
│   │   ├── 2022-12-26.csv
│   │   ├── 2022-12-27.csv
│   │   ├── 2022-12-28.csv
│   │   ├── 2022-12-29.csv
│   │   └── 2022-12-30.csv
│   ├── data_train_predict
│   │   ├── 2022-11-01.pkl
│   │   ├── 2022-11-02.pkl
│   │   ├── 2022-11-03.pkl
│   │   ├── 2022-11-04.pkl
│   │   ├── 2022-11-07.pkl
│   │   ├── 2022-11-08.pkl
│   │   ├── 2022-11-09.pkl
│   │   ├── 2022-11-10.pkl
│   │   ├── 2022-11-11.pkl
│   │   ├── 2022-11-14.pkl
│   │   ├── 2022-11-15.pkl
│   │   ├── 2022-11-16.pkl
│   │   ├── 2022-11-17.pkl
│   │   ├── 2022-11-18.pkl
│   │   ├── 2022-11-21.pkl
│   │   ├── 2022-11-22.pkl
│   │   ├── 2022-11-23.pkl
│   │   ├── 2022-11-24.pkl
│   │   ├── 2022-11-25.pkl
│   │   ├── 2022-11-28.pkl
│   │   ├── 2022-11-29.pkl
│   │   ├── 2022-11-30.pkl
│   │   ├── 2022-12-01.pkl
│   │   ├── 2022-12-02.pkl
│   │   ├── 2022-12-05.pkl
│   │   ├── 2022-12-06.pkl
│   │   ├── 2022-12-07.pkl
│   │   ├── 2022-12-08.pkl
│   │   ├── 2022-12-09.pkl
│   │   ├── 2022-12-12.pkl
│   │   ├── 2022-12-13.pkl
│   │   ├── 2022-12-14.pkl
│   │   ├── 2022-12-15.pkl
│   │   ├── 2022-12-16.pkl
│   │   ├── 2022-12-19.pkl
│   │   ├── 2022-12-20.pkl
│   │   ├── 2022-12-21.pkl
│   │   ├── 2022-12-22.pkl
│   │   ├── 2022-12-23.pkl
│   │   ├── 2022-12-26.pkl
│   │   ├── 2022-12-27.pkl
│   │   ├── 2022-12-28.pkl
│   │   ├── 2022-12-29.pkl
│   │   └── 2022-12-30.pkl
│   ├── model_saved
│   │   └── 2022-12-29_epoch_60.dat
│   ├── prediction
│   │   └── pred.csv
│   └── relation
│       ├── 2022-11-30.csv
│       └── 2022-12-30.csv
├── data_loader.py
├── main.py
├── main.sh
├── model
│   └── Thgnn.py
├── requirements.txt
├── trainer
│   └── trainer.py
└── utils
    ├── generate_data.py
    └── generate_relation.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Temporal and Heterogeneous Graph Neural Network for Financial Time Series Prediction (THGNN)
2 | ## 1. Prepare your training data
3 | The input to your model is a pkl file that includes the stock symbol `code`, the time `dt`, and the volume and price features. Then, you need to use `generate_relation.py` to generate daily stock relationships and `generate_data.py` to generate the final input data for the model. You can adjust the features used in building the stock relationship and generating the final input by changing `feature_cols`. The `relation` directory stores the relations between stocks. The `daily_stock` directory contains stocks that are trained each day. The `data_train_predict` directory stores the final inputs fed to the model each day. The `prediction` directory stores the prediction result of the validation set. The `model_saved` directory stores the trained model.
4 |
5 | ## 2. Train your model
6 | * Before training, make sure to change the parameters in class `Args` and function `main`.
7 |
8 | ``` config
9 | adj_threshold = 0.1 # the threshold of the relations between stocks
10 | max_epochs = 60 # the number of training epochs
11 | epochs_eval = 10 # the number of training epochs per evaluation or test interval
12 | epochs_save_by = 60 # the number of training epochs before a model is saved
13 | lr = 0.0002 # learning rate of the model
14 | gamma = 0.3 # gamma
15 | hidden_dim = 128 # hidden_dim
16 | num_heads = 8 # num_heads
17 | out_features = 32 # out_features
18 | model_name = "StockHeteGAT" # The main model name in model/thgnn.py
19 | dropout = 0.1 # dropout
20 | batch_size = 1 # batch_size
21 | loss_fcn = mse_loss # loss function
22 | epochs_save_by = 60 # the number of training epochs of the saved model
23 | data_start = 20 # index of training start date
24 | data_middle = 39 # index of evaluation or test start date/ index of training end date
25 | data_end = data_middle+4 # index of evaluation or test end date
26 | pre_data = '2021-12-29' # save the last date of the training
27 | ```
28 |
29 | * Install required packages
30 |
31 | ``` shell
32 | pip install -r requirements.txt  # installs the required packages at their specific versions
33 | ```
34 |
35 | * Training
36 |
37 | ``` shell
38 | sh train.sh
39 | ```
40 | ## 3. Citing
41 |
42 | * If you find **THGNN** is useful for your research, please consider citing the following papers:
43 |
44 | ``` latex
45 | @inproceedings{Xiang2022Temporal,
46 | author = {Xiang, Sheng and Cheng, Dawei and Shang, Chencheng and Zhang, Ying and Liang, Yuqi},
47 | title = {Temporal and Heterogeneous Graph Neural Network for Financial Time Series Prediction},
48 | year = {2022},
49 | isbn = {9781450392365},
50 | publisher = {Association for Computing Machinery},
51 | address = {New York, NY, USA},
52 | url = {https://doi.org/10.1145/3511808.3557089},
53 | doi = {10.1145/3511808.3557089},
54 | booktitle = {Proceedings of the 31st ACM International Conference on Information & Knowledge Management},
55 | pages = {3584–3593},
56 | numpages = {10},
57 | location = {Atlanta, GA, USA},
58 | series = {CIKM '22}
59 | }
60 | ```
--------------------------------------------------------------------------------
/data/csi300.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/csi300.pkl
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-01.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-02.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-03.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-04.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-07.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-08.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-09.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-10.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-11.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-14.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-15.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-16.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-17.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-18.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-21.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-22.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-23.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-24.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-25.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-28.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-11-28
3 | 000002.SZ,2022-11-28
4 | 000063.SZ,2022-11-28
5 | 000069.SZ,2022-11-28
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-29.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-11-29
3 | 000002.SZ,2022-11-29
4 | 000063.SZ,2022-11-29
5 | 000069.SZ,2022-11-29
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-11-30.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-11-30
3 | 000002.SZ,2022-11-30
4 | 000063.SZ,2022-11-30
5 | 000069.SZ,2022-11-30
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-01.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-01
3 | 000002.SZ,2022-12-01
4 | 000063.SZ,2022-12-01
5 | 000069.SZ,2022-12-01
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-02.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-02
3 | 000002.SZ,2022-12-02
4 | 000063.SZ,2022-12-02
5 | 000069.SZ,2022-12-02
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-05.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-05
3 | 000002.SZ,2022-12-05
4 | 000063.SZ,2022-12-05
5 | 000069.SZ,2022-12-05
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-06.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-06
3 | 000002.SZ,2022-12-06
4 | 000063.SZ,2022-12-06
5 | 000069.SZ,2022-12-06
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-07.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-07
3 | 000002.SZ,2022-12-07
4 | 000063.SZ,2022-12-07
5 | 000069.SZ,2022-12-07
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-08.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-08
3 | 000002.SZ,2022-12-08
4 | 000063.SZ,2022-12-08
5 | 000069.SZ,2022-12-08
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-09.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-09
3 | 000002.SZ,2022-12-09
4 | 000063.SZ,2022-12-09
5 | 000069.SZ,2022-12-09
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-12.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-12
3 | 000002.SZ,2022-12-12
4 | 000063.SZ,2022-12-12
5 | 000069.SZ,2022-12-12
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-13.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-13
3 | 000002.SZ,2022-12-13
4 | 000063.SZ,2022-12-13
5 | 000069.SZ,2022-12-13
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-14.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-14
3 | 000002.SZ,2022-12-14
4 | 000063.SZ,2022-12-14
5 | 000069.SZ,2022-12-14
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-15.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-15
3 | 000002.SZ,2022-12-15
4 | 000063.SZ,2022-12-15
5 | 000069.SZ,2022-12-15
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-16.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-16
3 | 000002.SZ,2022-12-16
4 | 000063.SZ,2022-12-16
5 | 000069.SZ,2022-12-16
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-19.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-19
3 | 000002.SZ,2022-12-19
4 | 000063.SZ,2022-12-19
5 | 000069.SZ,2022-12-19
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-20.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-20
3 | 000002.SZ,2022-12-20
4 | 000063.SZ,2022-12-20
5 | 000069.SZ,2022-12-20
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-21.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-21
3 | 000002.SZ,2022-12-21
4 | 000063.SZ,2022-12-21
5 | 000069.SZ,2022-12-21
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-22.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-22
3 | 000002.SZ,2022-12-22
4 | 000063.SZ,2022-12-22
5 | 000069.SZ,2022-12-22
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-23.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-23
3 | 000002.SZ,2022-12-23
4 | 000063.SZ,2022-12-23
5 | 000069.SZ,2022-12-23
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-26.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-26
3 | 000002.SZ,2022-12-26
4 | 000063.SZ,2022-12-26
5 | 000069.SZ,2022-12-26
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-27.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-27
3 | 000002.SZ,2022-12-27
4 | 000063.SZ,2022-12-27
5 | 000069.SZ,2022-12-27
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-28.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-28
3 | 000002.SZ,2022-12-28
4 | 000063.SZ,2022-12-28
5 | 000069.SZ,2022-12-28
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-29.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-29
3 | 000002.SZ,2022-12-29
4 | 000063.SZ,2022-12-29
5 | 000069.SZ,2022-12-29
6 |
--------------------------------------------------------------------------------
/data/daily_stock/2022-12-30.csv:
--------------------------------------------------------------------------------
1 | code,dt
2 | 000001.SZ,2022-12-30
3 | 000002.SZ,2022-12-30
4 | 000063.SZ,2022-12-30
5 | 000069.SZ,2022-12-30
6 |
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-01.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-01.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-02.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-02.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-03.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-03.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-04.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-04.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-07.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-07.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-08.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-08.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-09.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-09.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-10.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-10.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-11.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-11.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-14.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-14.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-15.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-15.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-16.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-16.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-17.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-17.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-18.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-18.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-21.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-21.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-22.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-22.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-23.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-23.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-24.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-24.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-25.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-25.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-28.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-28.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-29.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-29.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-11-30.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-11-30.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-01.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-01.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-02.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-02.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-05.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-05.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-06.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-06.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-07.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-07.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-08.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-08.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-09.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-09.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-12.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-12.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-13.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-13.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-14.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-14.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-15.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-15.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-16.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-16.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-19.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-19.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-20.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-20.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-21.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-21.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-22.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-22.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-23.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-23.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-26.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-26.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-27.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-27.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-28.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-28.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-29.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-29.pkl
--------------------------------------------------------------------------------
/data/data_train_predict/2022-12-30.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/data_train_predict/2022-12-30.pkl
--------------------------------------------------------------------------------
/data/model_saved/2022-12-29_epoch_60.dat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TongjiFinLab/THGNN/eefbfd3f38fdc110deeaed8cdf39402ea7ba9c9d/data/model_saved/2022-12-29_epoch_60.dat
--------------------------------------------------------------------------------
/data/prediction/pred.csv:
--------------------------------------------------------------------------------
1 | ,code,dt,score
2 | 0,000001.SZ,2022-12-26,0.26168692111968994
3 | 1,000002.SZ,2022-12-26,0.25601014494895935
4 | 2,000063.SZ,2022-12-26,0.6623930335044861
5 | 3,000069.SZ,2022-12-26,0.26233309507369995
6 | 0,000001.SZ,2022-12-27,0.2617263197898865
7 | 1,000002.SZ,2022-12-27,0.2553652226924896
8 | 2,000063.SZ,2022-12-27,0.6622966527938843
9 | 3,000069.SZ,2022-12-27,0.2624991834163666
10 | 0,000001.SZ,2022-12-28,0.2617460787296295
11 | 1,000002.SZ,2022-12-28,0.6396672129631042
12 | 2,000063.SZ,2022-12-28,0.6622033715248108
13 | 3,000069.SZ,2022-12-28,0.26258984208106995
14 | 0,000001.SZ,2022-12-29,0.26160314679145813
15 | 1,000002.SZ,2022-12-29,0.648451030254364
16 | 2,000063.SZ,2022-12-29,0.6622130870819092
17 | 3,000069.SZ,2022-12-29,0.2627296447753906
18 |
--------------------------------------------------------------------------------
/data/relation/2022-11-30.csv:
--------------------------------------------------------------------------------
1 | ,000001.SZ,000002.SZ,000063.SZ,000069.SZ
2 | 000001.SZ,1.0,0.9183336353345618,0.489495780528676,0.8544688680937519
3 | 000002.SZ,0.9183336353345618,1.0,0.3911211266892482,0.8806050758165241
4 | 000063.SZ,0.489495780528676,0.3911211266892482,1.0,0.27706243583541806
5 | 000069.SZ,0.8544688680937519,0.8806050758165241,0.27706243583541806,1.0
6 |
--------------------------------------------------------------------------------
/data/relation/2022-12-30.csv:
--------------------------------------------------------------------------------
1 | ,000001.SZ,000002.SZ,000063.SZ,000069.SZ
2 | 000001.SZ,1.0,0.70799906311983,0.7967784444477405,0.7205848661984243
3 | 000002.SZ,0.70799906311983,1.0,0.5748544637846457,0.8091412898515844
4 | 000063.SZ,0.7967784444477405,0.5748544637846457,1.0,0.6498760383308898
5 | 000069.SZ,0.7205848661984243,0.8091412898515844,0.6498760383308898,1.0
6 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from torch.utils import data
4 | import pickle
5 |
class AllGraphDataSampler(data.Dataset):
    """Map-style dataset that eagerly loads pickled daily graph snapshots.

    Each file in ``base_dir`` is a pickle holding one day's sample (see
    utils/generate_data.py). File names are sorted (dates sort correctly as
    strings), then sliced into a train window ``[data_start, data_middle)``
    or a validation window ``[data_middle, data_end)``.
    """

    def __init__(self, base_dir, gname_list=None, data_start=None, data_middle=None, data_end=None, mode="train"):
        """
        Args:
            base_dir: directory containing the pickled snapshots.
            gname_list: optional explicit list of file names to load; when
                given, the directory listing/slicing logic is skipped.
            data_start: first index (inclusive) of the train slice.
            data_middle: end of the train slice / start of the val slice.
            data_end: end index (exclusive) of the val slice.
            mode: "train" or "val"; any other value keeps the full listing.
        """
        self.data_dir = os.path.join(base_dir)
        self.mode = mode
        self.data_start = data_start
        self.data_middle = data_middle
        self.data_end = data_end
        if gname_list is None:
            self.gnames_all = os.listdir(self.data_dir)
            self.gnames_all.sort()
            if mode == "train":
                self.gnames_all = self.gnames_all[self.data_start:self.data_middle]
            elif mode == "val":
                self.gnames_all = self.gnames_all[self.data_middle:self.data_end]
        else:
            # Bug fix: a caller-supplied gname_list was previously ignored and
            # self.gnames_all was never set, crashing in load_state().
            self.gnames_all = list(gname_list)
        self.data_all = self.load_state()

    def __len__(self):
        return len(self.data_all)

    def load_state(self):
        """Unpickle every selected file, printing a simple progress line.

        NOTE(review): unpickling is only safe on trusted files.
        """
        data_all = []
        length = len(self.gnames_all)
        for i in range(length):
            sys.stdout.flush()
            sys.stdout.write('{} data loading: {:.2f}%{}'.format(self.mode, i*100/length, '\r'))
            # Use a context manager so file handles are not leaked
            # (the original open() was never closed).
            with open(os.path.join(self.data_dir, self.gnames_all[i]), "rb") as f:
                data_all.append(pickle.load(f))
        print('{} data loaded!'.format(self.mode))
        return data_all

    def __getitem__(self, idx):
        return self.data_all[idx]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from trainer.trainer import *
2 | from data_loader import *
3 | from model.Thgnn import *
4 | import warnings
5 | import torch
6 | import os
7 | from torch.utils.data import DataLoader
8 | import torch.optim as optim
9 | from torch.optim.lr_scheduler import StepLR
10 | import pandas as pd
11 | from pandas.core.frame import DataFrame
12 | from tqdm import tqdm
13 |
14 | warnings.filterwarnings("ignore")
15 | t_float = torch.float64
16 | torch.multiprocessing.set_sharing_strategy('file_system')
17 |
class Args:
    """Hyper-parameter container for THGNN training/prediction.

    NOTE(review): __init__ reads the module-level names ``data_start``,
    ``data_middle``, ``data_end`` and ``pre_data``, which are only bound in
    the ``__main__`` block of this script — constructing Args from another
    module will raise NameError. TODO: pass them as constructor arguments.
    """

    def __init__(self, gpu=0, subtask="regression"):
        # device
        self.gpu = str(gpu)
        self.device = 'cpu'
        # data settings
        adj_threshold = 0.1
        self.adj_str = str(int(100*adj_threshold))
        self.pos_adj_dir = "pos_adj_" + self.adj_str
        self.neg_adj_dir = "neg_adj_" + self.adj_str
        self.feat_dir = "features"
        self.label_dir = "label"
        self.mask_dir = "mask"
        self.data_start = data_start
        self.data_middle = data_middle
        self.data_end = data_end
        self.pre_data = pre_data
        # epoch settings
        self.max_epochs = 60
        self.epochs_eval = 10
        # learning rate settings
        self.lr = 0.0002
        self.gamma = 0.3
        # model settings
        self.hidden_dim = 128
        self.num_heads = 8
        self.out_features = 32
        self.model_name = "StockHeteGAT"
        self.batch_size = 1
        self.loss_fcn = mse_loss
        # save model settings
        # NOTE(review): os.path.join discards the first component when the
        # second is absolute, so this resolves to the absolute path itself.
        self.save_path = os.path.join(os.path.abspath('.'), "/home/THGNN-main/data/model_saved/")
        self.load_path = self.save_path
        self.save_name = self.model_name + "_hidden_" + str(self.hidden_dim) + "_head_" + str(self.num_heads) + \
                         "_outfeat_" + str(self.out_features) + "_batchsize_" + str(self.batch_size) + "_adjth_" + \
                         str(self.adj_str)
        self.epochs_save_by = 60
        self.sub_task = subtask
        # Fix: dispatch via getattr instead of eval() — same behavior for the
        # four valid subtasks, without executing arbitrary text.
        getattr(self, self.sub_task)()

    def regression(self):
        """Regression on raw labels with MSE loss."""
        self.save_name = self.save_name + "_reg_rank_"
        self.loss_fcn = mse_loss
        self.label_dir = self.label_dir + "_regression"
        self.mask_dir = self.mask_dir + "_regression"

    def regression_binary(self):
        """Regression against two-class labels with MSE loss."""
        self.save_name = self.save_name + "_reg_binary_"
        self.loss_fcn = mse_loss
        self.label_dir = self.label_dir + "_twoclass"
        self.mask_dir = self.mask_dir + "_twoclass"

    def classification_binary(self):
        """Binary classification with BCE loss."""
        self.save_name = self.save_name + "_clas_binary_"
        self.loss_fcn = bce_loss
        self.label_dir = self.label_dir + "_twoclass"
        self.mask_dir = self.mask_dir + "_twoclass"

    def classification_tertiary(self):
        """Three-class variant; NOTE(review): still uses bce_loss as written."""
        self.save_name = self.save_name + "_clas_tertiary_"
        self.loss_fcn = bce_loss
        self.label_dir = self.label_dir + "_threeclass"
        self.mask_dir = self.mask_dir + "_threeclass"
81 |
82 |
def fun_train_predict(data_start, data_middle, data_end, pre_data):
    """Train StockHeteGAT on snapshots [data_start, data_middle), validate on
    [data_middle, data_end), then write per-stock scores for the validation
    days to data/prediction/pred.csv.

    Args:
        data_start: index of the first training snapshot (sorted file order).
        data_middle: end of the training slice / start of the validation slice.
        data_end: end of the validation slice.
        pre_data: date tag used in the checkpoint file name.
    """
    args = Args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    dataset = AllGraphDataSampler(base_dir="/home/THGNN-main/data/data_train_predict/", data_start=data_start,
                                  data_middle=data_middle, data_end=data_end)
    val_dataset = AllGraphDataSampler(base_dir="/home/THGNN-main/data/data_train_predict/", mode="val",
                                      data_start=data_start, data_middle=data_middle, data_end=data_end)
    dataset_loader = DataLoader(dataset, batch_size=args.batch_size, pin_memory=True, collate_fn=lambda x: x)
    val_dataset_loader = DataLoader(val_dataset, batch_size=1, pin_memory=True)
    # Fix: look the class up by name instead of eval() on a string.
    model_cls = globals()[args.model_name]
    model = model_cls(hidden_dim=args.hidden_dim, num_heads=args.num_heads,
                      out_features=args.out_features).to(args.device)

    # train
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    cold_scheduler = StepLR(optimizer=optimizer, step_size=5000, gamma=0.9, last_epoch=-1)
    default_scheduler = cold_scheduler
    print('start training')
    for epoch in range(args.max_epochs):
        train_loss = train_epoch(epoch=epoch, args=args, model=model, dataset_train=dataset_loader,
                                 optimizer=optimizer, scheduler=default_scheduler, loss_fcn=mse_loss)
        if epoch % args.epochs_eval == 0:
            eval_loss, _ = eval_epoch(args=args, model=model, dataset_eval=val_dataset_loader, loss_fcn=mse_loss)
            print('Epoch: {}/{}, train loss: {:.6f}, val loss: {:.6f}'.format(epoch + 1, args.max_epochs, train_loss,
                                                                              eval_loss))
        else:
            print('Epoch: {}/{}, train loss: {:.6f}'.format(epoch + 1, args.max_epochs, train_loss))
        if (epoch + 1) % args.epochs_save_by == 0:
            print("save model!")
            state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch + 1}
            torch.save(state, os.path.join(args.save_path, pre_data + "_epoch_" + str(epoch + 1) + ".dat"))

    # predict
    # Fix: map_location so a checkpoint saved on another device still loads.
    checkpoint = torch.load(os.path.join(args.load_path, pre_data + "_epoch_" + str(epoch + 1) + ".dat"),
                            map_location=args.device)
    model.load_state_dict(checkpoint['model'])
    data_code = sorted(os.listdir('/home/THGNN-main/data/daily_stock'))
    # Validation-day stock lists line up with val_dataset by sorted position.
    data_code_last = data_code[data_middle:data_end]
    df_score = pd.DataFrame()
    # Bug fix: inference previously ran with model.train(); use eval mode and
    # disable autograd so predictions are deterministic and memory-light.
    model.eval()
    for i in tqdm(range(len(val_dataset))):
        df = pd.read_csv('/home/THGNN-main/data/daily_stock/' + data_code_last[i], dtype=object)
        tmp_data = val_dataset[i]
        pos_adj, neg_adj, features, labels, mask = extract_data(tmp_data, args.device)
        with torch.no_grad():
            logits = model(features, pos_adj, neg_adj)
        result = logits.data.cpu().numpy().tolist()
        result_new = [row[0] for row in result]
        # Relies on default RangeIndex alignment between df and the new frame.
        df['score'] = DataFrame({"score": result_new})
        df_score = pd.concat([df_score, df])

    df_score.to_csv('/home/THGNN-main/data/prediction/pred.csv')
    print(df_score)
139 |
if __name__ == "__main__":
    # Slice positions into the sorted snapshot file list:
    # train on files [20, 39), validate/predict on [39, 43).
    data_start = 20
    data_middle = 39
    data_end = data_middle+4
    # Date tag used in the saved/loaded checkpoint file name.
    pre_data = '2022-12-29'
    fun_train_predict(data_start, data_middle, data_end, pre_data)
--------------------------------------------------------------------------------
/main.sh:
--------------------------------------------------------------------------------
# Run the full pipeline, redirecting stdout and stderr to train.log.
python main.py > train.log 2>&1
--------------------------------------------------------------------------------
/model/Thgnn.py:
--------------------------------------------------------------------------------
1 | from torch.nn.parameter import Parameter
2 | from torch.nn.modules.module import Module
3 | import torch.nn as nn
4 | import torch
5 | import math
6 |
class GraphAttnMultiHead(Module):
    """Multi-head graph attention layer (GAT-style) over a dense adjacency matrix.

    Projects node features into num_heads subspaces, builds pairwise attention
    logits from two learned per-head vectors, masks them with the adjacency
    matrix, and aggregates neighbour features with softmax-normalised weights.
    """

    def __init__(self, in_features, out_features, negative_slope=0.2, num_heads=4, bias=True, residual=True):
        """
        Args:
            in_features: size of each input node embedding.
            out_features: per-head output size.
            negative_slope: slope of the LeakyReLU on attention logits.
            num_heads: number of attention heads.
            bias: add a learned bias to the aggregated output.
            residual: add a linear projection of the input to the output.
        """
        super(GraphAttnMultiHead, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        # Shared projection for all heads: (in_features, heads * out_features).
        self.weight = Parameter(torch.FloatTensor(in_features, num_heads * out_features))
        # Per-head attention vectors for the two endpoints of an edge.
        self.weight_u = Parameter(torch.FloatTensor(num_heads, out_features, 1))
        self.weight_v = Parameter(torch.FloatTensor(num_heads, out_features, 1))
        self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)
        self.residual = residual
        if self.residual:
            self.project = nn.Linear(in_features, num_heads*out_features)
        else:
            self.project = None
        if bias:
            self.bias = Parameter(torch.FloatTensor(1, num_heads * out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init scaled by 1/sqrt(last dim) for all learned tensors."""
        stdv = 1. / math.sqrt(self.weight.size(-1))
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
        self.weight.data.uniform_(-stdv, stdv)
        stdv = 1. / math.sqrt(self.weight_u.size(-1))
        self.weight_u.data.uniform_(-stdv, stdv)
        self.weight_v.data.uniform_(-stdv, stdv)

    def forward(self, inputs, adj_mat, requires_weight=False):
        """Aggregate neighbour features with masked multi-head attention.

        Args:
            inputs: node embeddings, shape (N, in_features) (torch.mm requires 2-D).
            adj_mat: dense mask broadcastable against the per-head (N, N)
                logits; zero entries remove the edge.
            requires_weight: also return the dense attention weights.

        Returns:
            (output, attn_weights_or_None); output is (N, num_heads * out_features).
        """
        support = torch.mm(inputs, self.weight)
        # (N, heads*out) -> (heads, N, out)
        support = support.reshape(-1, self.num_heads, self.out_features).permute(dims=(1, 0, 2))
        f_1 = torch.matmul(support, self.weight_u).reshape(self.num_heads, 1, -1)
        f_2 = torch.matmul(support, self.weight_v).reshape(self.num_heads, -1, 1)
        # Broadcast sum yields per-head (N, N) pairwise logits.
        logits = f_1 + f_2
        weight = self.leaky_relu(logits)
        # Entries zeroed by the adjacency mask are dropped by to_sparse(), so
        # the sparse softmax normalises only over the retained (adjacent)
        # entries rather than over all N columns.
        masked_weight = torch.mul(weight, adj_mat).to_sparse()
        attn_weights = torch.sparse.softmax(masked_weight, dim=2).to_dense()
        support = torch.matmul(attn_weights, support)
        # (heads, N, out) -> (N, heads*out)
        support = support.permute(dims=(1, 0, 2)).reshape(-1, self.num_heads * self.out_features)
        if self.bias is not None:
            support = support + self.bias
        if self.residual:
            support = support + self.project(inputs)
        if requires_weight:
            return support, attn_weights
        else:
            return support, None
55 |
56 |
class PairNorm(nn.Module):
    """Pair normalisation over a batch of row vectors.

    Modes:
        'None'   -- identity.
        'PN'     -- centre columns, rescale by the mean row norm.
        'PN-SI'  -- centre columns, rescale each row individually.
        'PN-SCS' -- rescale each row individually, then subtract the column mean.
    """

    def __init__(self, mode='PN', scale=1):
        assert mode in ['None', 'PN', 'PN-SI', 'PN-SCS']
        super(PairNorm, self).__init__()
        self.mode = mode
        self.scale = scale

    def forward(self, x):
        if self.mode == 'None':
            return x
        # Column mean of the *input*, shared by every mode below.
        col_mean = x.mean(dim=0)
        if self.mode == 'PN':
            centered = x - col_mean
            # Scalar: root of the mean squared row norm (+eps for stability).
            mean_row_norm = (1e-6 + centered.pow(2).sum(dim=1).mean()).sqrt()
            return self.scale * centered / mean_row_norm
        if self.mode == 'PN-SI':
            centered = x - col_mean
            # Per-row norms, kept as a column for broadcasting.
            row_norms = (1e-6 + centered.pow(2).sum(dim=1, keepdim=True)).sqrt()
            return self.scale * centered / row_norms
        # 'PN-SCS': scale rows of the raw input, then shift by the column mean.
        row_norms = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
        return self.scale * x / row_norms - col_mean
80 |
81 |
class GraphAttnSemIndividual(Module):
    """Semantic-level attention: fuses several embeddings per node into one.

    Scores each embedding with a small MLP, softmax-normalises the scores
    along dim 1, and returns the weighted sum of the inputs.
    """

    def __init__(self, in_features, hidden_size=128, act=nn.Tanh()):
        super(GraphAttnSemIndividual, self).__init__()
        # Two-layer scorer: in_features -> hidden_size -> scalar (no bias).
        self.project = nn.Sequential(nn.Linear(in_features, hidden_size),
                                     act,
                                     nn.Linear(hidden_size, 1, bias=False))

    def forward(self, inputs, requires_weight=False):
        scores = self.project(inputs)
        attn = torch.softmax(scores, dim=1)
        fused = (attn * inputs).sum(1)
        if requires_weight:
            return fused, attn
        return fused, None
96 |
97 |
class StockHeteGAT(nn.Module):
    """THGNN stock model: GRU temporal encoder, positive/negative graph
    attention, semantic attention fusion, and a sigmoid score predictor.
    """

    def __init__(self, in_features=6, out_features=8, num_heads=8, hidden_dim=64, num_layers=1):
        super(StockHeteGAT, self).__init__()
        # Temporal encoder over each stock's daily feature sequence.
        # NOTE(review): nn.GRU applies dropout only *between* stacked layers,
        # so with the default num_layers=1 this dropout=0.1 has no effect
        # (PyTorch emits a warning).
        self.encoding = nn.GRU(
            input_size=in_features,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
            dropout=0.1
        )
        # Attention over the positively-correlated stock graph.
        self.pos_gat = GraphAttnMultiHead(
            in_features=hidden_dim,
            out_features=out_features,
            num_heads=num_heads
        )
        # Attention over the negatively-correlated stock graph.
        self.neg_gat = GraphAttnMultiHead(
            in_features=hidden_dim,
            out_features=out_features,
            num_heads=num_heads
        )
        # Project each relation view back to hidden_dim before fusion.
        self.mlp_self = nn.Linear(hidden_dim, hidden_dim)
        self.mlp_pos = nn.Linear(out_features*num_heads, hidden_dim)
        self.mlp_neg = nn.Linear(out_features*num_heads, hidden_dim)
        self.pn = PairNorm(mode='PN-SI')
        # Semantic attention fuses the (self, pos, neg) views per stock.
        self.sem_gat = GraphAttnSemIndividual(in_features=hidden_dim,
                                              hidden_size=hidden_dim,
                                              act=nn.Tanh())
        # Final per-stock score squashed into (0, 1).
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Re-initialise every Linear layer with small-gain Xavier uniform.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.02)

    def forward(self, inputs, pos_adj, neg_adj, requires_weight=False):
        """Score each stock.

        Args:
            inputs: per-stock feature sequences; presumably
                (num_stocks, seq_len, in_features) given batch_first=True —
                TODO confirm against the data builder.
            pos_adj, neg_adj: dense adjacency matrices of the two graphs.
            requires_weight: also return all attention weights.
        """
        # Use the GRU's final hidden state as the stock embedding; squeeze()
        # drops the leading num_layers*directions axis (assumes it equals 1).
        _, support = self.encoding(inputs)
        support = support.squeeze()
        pos_support, pos_attn_weights = self.pos_gat(support, pos_adj, requires_weight)
        neg_support, neg_attn_weights = self.neg_gat(support, neg_adj, requires_weight)
        support = self.mlp_self(support)
        pos_support = self.mlp_pos(pos_support)
        neg_support = self.mlp_neg(neg_support)
        # Stack the three views -> (num_stocks, 3, hidden_dim), then fuse.
        all_embedding = torch.stack((support, pos_support, neg_support), dim=1)
        all_embedding, sem_attn_weights = self.sem_gat(all_embedding, requires_weight)
        all_embedding = self.pn(all_embedding)
        if requires_weight:
            return self.predictor(all_embedding), (pos_attn_weights, neg_attn_weights, sem_attn_weights)
        else:
            return self.predictor(all_embedding)
150 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.0
2 | pandas==1.5.3
3 | tqdm==4.65.0
4 | numpy==1.23.5
5 | networkx==3.1
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
def mse_loss(logits, targets):
    """Mean squared error between squeezed model outputs and targets."""
    return nn.MSELoss()(logits.squeeze(), targets)
8 |
9 |
def bce_loss(logits, targets):
    """Binary cross-entropy between squeezed probabilities and 0/1 targets."""
    return nn.BCELoss()(logits.squeeze(), targets)
14 |
15 |
def evaluate(model, features, adj_pos, adj_neg, labels, mask, loss_func=nn.L1Loss()):
    """Run the model in eval mode and score its predictions against labels.

    NOTE(review): ``mask`` is accepted but not applied — the loss is computed
    over all outputs.

    Returns:
        (loss, logits) where logits are the raw model outputs.
    """
    model.eval()
    with torch.no_grad():
        logits = model(features, adj_pos, adj_neg)
        loss = loss_func(logits, labels)
    return loss, logits
23 |
24 |
def extract_data(data_dict, device):
    """Unpack one snapshot dict, moving tensors to ``device`` and squeezing
    singleton dimensions; the mask is passed through untouched.

    Returns:
        (pos_adj, neg_adj, features, labels, mask)
    """
    tensors = tuple(data_dict[key].to(device).squeeze()
                    for key in ('pos_adj', 'neg_adj', 'features', 'labels'))
    return tensors + (data_dict['mask'],)
32 |
33 |
def train_epoch(epoch, args, model, dataset_train, optimizer, scheduler, loss_fcn):
    """Run one training pass over ``dataset_train``; return the tracked loss
    averaged over batches.

    Fix: this function (and eval_epoch below) was defined twice, byte-identically,
    in this module — the later copies silently shadowed the earlier ones. The
    duplicates are collapsed into a single definition each.

    Args:
        epoch: current epoch index (unused; kept for interface compatibility).
        args: object providing ``device``.
        dataset_train: iterable of batches, each a list of snapshot dicts.
        loss_fcn: loss over the mask-selected logits/labels.
    """
    model.train()
    loss_return = 0
    for batch_data in dataset_train:
        for batch_idx, data in enumerate(batch_data):
            model.zero_grad()
            pos_adj, neg_adj, features, labels, mask = extract_data(data, args.device)
            logits = model(features, pos_adj, neg_adj)
            loss = loss_fcn(logits[mask], labels[mask])
            loss.backward()
            optimizer.step()
            # Scheduler stepped per sample, matching the original behavior.
            scheduler.step()
            # Only the first sample's loss per batch is tracked for reporting.
            if batch_idx == 0:
                loss_return += loss.data
    return loss_return/len(dataset_train)


def eval_epoch(args, model, dataset_eval, loss_fcn):
    """Evaluate on the first batch of ``dataset_eval`` only (early break).

    Returns:
        (loss, logits) from evaluate(); (0., None) if the loader is empty.
    """
    loss = 0.
    logits = None
    for batch_idx, data in enumerate(dataset_eval):
        pos_adj, neg_adj, features, labels, mask = extract_data(data, args.device)
        loss, logits = evaluate(model, features, pos_adj, neg_adj, labels, mask, loss_func=loss_fcn)
        break
    return loss, logits
--------------------------------------------------------------------------------
/utils/generate_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import numpy as np
5 | from tqdm import tqdm
6 | import networkx as nx
7 | import pandas as pd
8 | from torch.autograd import Variable
9 |
# Per-stock daily input features used to build each training sample.
feature_cols = ['open','high','low','close','to','vol']

# Full CSI300 daily panel: rows keyed by (code, dt), read below.
# NOTE(review): unpickling is only safe on trusted files.
path1 = "/home/THGNN-main/data/csi300.pkl"
df1 = pickle.load(open(path1, 'rb'), encoding='utf-8')
# Correlation-matrix CSVs, one per month-end trading day.
relation = os.listdir('/home/THGNN-main/data/relation/')
relation = sorted(relation)
# Sorted list of all trading days present in the panel (as strings).
date_unique=df1['dt'].unique()
stock_trade_data=date_unique.tolist()
stock_trade_data.sort()

# Convert dates for the range filters in fun().
df1['dt']=df1['dt'].astype('datetime64')
21 |
def fun(relation_dt, start_dt_month, end_dt_month,df1):
    """Build one training sample per trading day in [start_dt_month, end_dt_month].

    For each day: a 20-day feature window per stock, the day's label, and the
    month's positive/negative correlation graphs. Each sample is pickled to
    data/data_train_predict/<day>.pkl and the stock list written to
    data/daily_stock/<day>.csv.

    Args:
        relation_dt: name (date) of the correlation CSV to threshold.
        start_dt_month: first trading day of the month.
        end_dt_month: last trading day of the month.
        df1: full daily panel with 'code', 'dt', 'label' and feature_cols.
    """
    # Length of the look-back feature window, in trading days.
    prev_date_num = 20
    adj_all = pd.read_csv('/home/THGNN-main/data/relation/'+relation_dt+'.csv', index_col=0)
    adj_stock_set = list(adj_all.index)
    # Positive graph: edges where correlation > 0.1; drop self-loops.
    pos_g = nx.Graph(adj_all > 0.1)
    pos_adj = nx.adjacency_matrix(pos_g).toarray()
    pos_adj = pos_adj - np.diag(np.diag(pos_adj))
    pos_adj = torch.from_numpy(pos_adj).type(torch.float32)
    # Negative graph: edges where correlation < -0.1, binarised; drop self-loops.
    neg_g = nx.Graph(adj_all < -0.1)
    neg_adj = nx.adjacency_matrix(neg_g)
    neg_adj.data = np.ones(neg_adj.data.shape)
    neg_adj = neg_adj.toarray()
    neg_adj = neg_adj - np.diag(np.diag(neg_adj))
    neg_adj = torch.from_numpy(neg_adj).type(torch.float32)
    print('neg_adj over')
    print(neg_adj.shape)
    # All trading days of the target month, inclusive of both ends.
    dts = stock_trade_data[stock_trade_data.index(start_dt_month):stock_trade_data.index(end_dt_month)+1]
    print(dts)
    for i in tqdm(range(len(dts))):
        end_data=dts[i]
        # Window start: 19 trading days before end_data (20 days total).
        start_data = stock_trade_data[stock_trade_data.index(end_data)-(prev_date_num - 1)]
        df2 = df1.loc[df1['dt'] <= end_data]
        df2 = df2.loc[df2['dt'] >= start_data]
        code = adj_stock_set
        feature_all = []
        mask = []
        labels = []
        day_last_code = []
        for j in range(len(code)):
            df3 = df2.loc[df2['code'] == code[j]]
            y = df3[feature_cols].values
            # Keep only stocks with a complete 20-day window.
            if y.T.shape[1] == prev_date_num:
                one = []
                feature_all.append(y)
                # NOTE(review): dead write — mask is rebuilt below after the loop.
                mask.append(True)
                label = df3.loc[df3['dt'] == end_data]['label'].values
                labels.append(label[0])
                one.append(code[j])
                one.append(end_data)
                day_last_code.append(one)
        feature_all = np.array(feature_all)
        features = torch.from_numpy(feature_all).type(torch.float32)
        # Overwrites the per-stock appends above; every kept stock is unmasked.
        mask = [True]*len(labels)
        labels = torch.tensor(labels, dtype=torch.float32)
        result = {'pos_adj': Variable(pos_adj), 'neg_adj': Variable(neg_adj), 'features': Variable(features),
                  'labels': Variable(labels), 'mask': mask}
        with open('/home/THGNN-main/data/data_train_predict/'+end_data+'.pkl', 'wb') as f:
            pickle.dump(result, f)
        # Record which stocks (rows) this day's sample covers, in order.
        df = pd.DataFrame(columns=['code', 'dt'], data=day_last_code)
        df.to_csv('/home/THGNN-main/data/daily_stock/'+end_data+'.csv', header=True, index=False, encoding='utf_8_sig')
72 |
73 | #The first parameter and third parameters indicate the last trading day of each month, and the second parameter indicates the first trading day of each month.
74 | # for i in ['2020','2021','2022']:
75 | # for j in ['01','02','03','04','05','06','07','08','09','10','11','12']:
76 | # stock_m=[k for k in stock_trade_data if k>i+'-'+j and k 1:
# NOTE(review): this is the tail of stock_cor_matrix(ref_dict, codes, n,
# processes=...); the enclosing `def` line and the `if processes > 1:`
# header were lost in extraction. The first six lines below appear to be
# the multiprocessing branch, the last four the serial fallback -- confirm
# against the full source.
pool = mp.Pool(processes=processes)
# One task per stock: correlate that stock's window against all stocks.
args_all = [(ref_dict[code], ref_dict, n) for code in codes]
results = [pool.apply_async(calculate_pccs, args=args) for args in args_all]
output = [o.get() for o in results]
data = np.stack(output)
return pd.DataFrame(data=data, index=codes, columns=codes)
# Serial path: fill the correlation matrix row by row.
data = np.zeros([len(codes), len(codes)])
for i in tqdm(range(len(codes))):
    data[i, :] = calculate_pccs(ref_dict[codes[i]], ref_dict, n)
return pd.DataFrame(data=data, index=codes, columns=codes)
42 |
path1 = "/home/THGNN-main/data/csi300.pkl"
# Context manager closes the handle deterministically; the original
# `pickle.load(open(...))` leaked it until garbage collection.
with open(path1, 'rb') as f:
    df1 = pickle.load(f, encoding='utf-8')

# prev_date_num: number of trading days over which stock correlation is
# calculated (look-back window length).
prev_date_num = 20

# All trading dates in the data ('YYYY-MM-DD' strings), ascending.
date_unique = df1['dt'].unique()
stock_trade_data = date_unique.tolist()
stock_trade_data.sort()

# Total number of distinct stock codes in the dataset.
stock_num = df1.code.unique().shape[0]

# dt: the last trading day of each month to build a relation matrix for.
dt = ['2022-11-30', '2022-12-30']
53 | # for i in ['2020','2021','2022']:
54 | # for j in ['01','02','03','04','05','06','07','08','09','10','11','12']:
55 | # stock_m=[k for k in stock_trade_data if k>i+'-'+j and k= start_data]
# NOTE(review): interior of a per-month loop whose header (defining
# `end_data`, `start_data` and the first `df2` date filter) was lost in
# extraction -- confirm against the full source before editing.
# Stock codes present in the current 20-day window, sorted for stable order.
code = sorted(list(set(df2['code'].values.tolist())))
test_tmp = {}
for j in tqdm(range(len(code))):
    df3 = df2.loc[df2['code'] == code[j]]
    y = df3[feature_cols].values
    # Keep only stocks with a complete prev_date_num-day history; store the
    # transposed window (features x days) keyed by stock code.
    if y.T.shape[1] == prev_date_num:
        test_tmp[code[j]] = y.T
t1 = time.time()
# Pairwise correlation matrix across all complete stocks (serial mode).
result = stock_cor_matrix(test_tmp, list(test_tmp.keys()), prev_date_num, processes=1)
result=result.fillna(0)
# Force the diagonal to 1 (self-correlation).
for i in range(0,stock_num):
    result.iloc[i,i]=1
t2 = time.time()
print('time cost', t2 - t1, 's')
result.to_csv("/home/THGNN-main/data/relation/"+str(end_data)+".csv")
80 |
--------------------------------------------------------------------------------