├── .gitattributes ├── .gitignore ├── LICENSE.md ├── README.md ├── dev-requirements.txt ├── mapillm ├── .DS_Store ├── __init__.py ├── agent.py ├── app.py ├── mapi_tools.py ├── reaction_prediction.py ├── utils.py └── version.py ├── requirements.txt ├── setup.py └── tests ├── baselines ├── ICL.ipynb ├── KNN.ipynb ├── RF.ipynb ├── RNN.ipynb └── get_dataset.py └── test.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tflite filter=lfs diff=lfs merge=lfs -text 29 | *.tgz filter=lfs diff=lfs merge=lfs -text 30 | *.wasm filter=lfs diff=lfs merge=lfs -text 31 | *.xz filter=lfs diff=lfs merge=lfs -text 32 | *.zip filter=lfs diff=lfs 
merge=lfs -text 33 | *.zst filter=lfs diff=lfs merge=lfs -text 34 | *tfevents* filter=lfs diff=lfs merge=lfs -text 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock 
in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mayk Caldas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAPI_LLM 🧱🦜️🔗 2 | 3 | ## A LLM application developed during the LLM March *MADNESS* Hackathon 4 | - Developed by: 5 | - Mayk Caldas ([@maykcaldas](https://github.com/maykcaldas)), and 6 | - Sam Cox ([@SamCox822](https://github.com/SamCox822)) 7 | 8 | ### What is this? 9 | - This is a demo of an app that can answer questions about material science using the [LangChain🦜️🔗](https://github.com/hwchase17/langchain/) and the [Materials Project API](https://materialsproject.org/). 10 | - Its behavior is based on Large Language Models (LLM), and it aims to be a tool to help scientists with quick predictions of numerous properties of materials. 11 | 12 | A brief video presentation on how to use this app can be found [here](https://twitter.com/SamCox822/status/1641484192566460416) and [here](https://twitter.com/Kyam888/status/1641485895189639192). 13 | 14 | >It is a work in progress, so please be patient with it. 15 | >We are working on a systematic validation. 16 | 17 | 18 | Please, notice that MAPI_LLM is under the MIT license, but The Materials Project has its own [terms of use](https://materialsproject.org/about/terms) and any usage of their data is subject to the appropriate [terms of use](https://materialsproject.org/about/terms). 
19 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | cloudpickle 2 | datasets 3 | ase 4 | pymatgen 5 | scikit-learn 6 | xgboost 7 | transformers 8 | torch -------------------------------------------------------------------------------- /mapillm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maykcaldas/MAPI_LLM/10b40c85f2c8fa132418014c4d5fe7711566cd48/mapillm/.DS_Store -------------------------------------------------------------------------------- /mapillm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .agent import Agent 3 | -------------------------------------------------------------------------------- /mapillm/agent.py: -------------------------------------------------------------------------------- 1 | from .mapi_tools import mapi_tools 2 | from .utils import common_tools 3 | # from reaction_prediction import SynthesisReactions 4 | from langchain import hub, agents 5 | from langchain.agents import AgentExecutor, create_react_agent 6 | from langchain_openai import ChatOpenAI 7 | import os 8 | 9 | # reaction = SynthesisReactions() 10 | 11 | print("apsidjaiposdhu") 12 | 13 | class Agent: 14 | def __init__(self, openai_api_key, mapi_api_key): 15 | self.llm = ChatOpenAI( 16 | temperature=0.1, 17 | model="gpt-3.5-turbo", 18 | streaming=True, 19 | ) 20 | self.tools = ( 21 | mapi_tools + 22 | # reaction.get_tools() + 23 | # agents.load_tools(["llm-math", "python_repl"], llm=self.llm) + 24 | common_tools 25 | ) 26 | 27 | self.prompt = hub.pull("hwchase17/react") 28 | self.agent = create_react_agent(self.llm, self.tools, self.prompt) 29 | self.agent_executor = AgentExecutor(agent=self.agent, tools=self.tools, verbose=True) 30 | 31 | def run(self, query: str): 32 
| return self.agent_executor.invoke({ 33 | 'input': query 34 | }) 35 | 36 | if __name__ == "__main__": 37 | import os 38 | from dotenv import load_dotenv 39 | load_dotenv(override=True) 40 | 41 | a = Agent(openai_api_key=os.getenv("OPENAI_API_KEY"), mapi_api_key=os.getenv("MAPI_API_KEY")) 42 | a.run("What's the band gap of Fe2O4?") 43 | -------------------------------------------------------------------------------- /mapillm/app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import numpy as np 3 | import agent 4 | import os 5 | 6 | css_style = """ 7 | .gradio-container { 8 | font-family: "IBM Plex Mono"; 9 | } 10 | """ 11 | 12 | def agent_run(q, openai_api_key, mapi_api_key, serp_api_key): 13 | os.environ["OPENAI_API_KEY"]=openai_api_key 14 | os.environ["MAPI_API_KEY"]=mapi_api_key 15 | os.environ["SERPAPI_API_KEY"]=serp_api_key 16 | agent_chain = agent.Agent(openai_api_key, mapi_api_key) 17 | try: 18 | out = agent_chain.run(q) 19 | except Exception as err: 20 | out = f"Something went wrong. Please try again.\nError: {err}" 21 | return out 22 | 23 | with gr.Blocks(css=css_style) as demo: 24 | gr.Markdown(f''' 25 | # A LLM application developed during the LLM March *MADNESS* Hackathon 26 | - Developed by: Mayk Caldas ([@maykcaldas](https://github.com/maykcaldas)) and Sam Cox ([@SamCox822](https://github.com/SamCox822)) 27 | 28 | ## What is this? 29 | - This is a demo of an app that can answer questions about material science using the [LangChain🦜️🔗](https://github.com/hwchase17/langchain/) and the [Materials Project API](https://materialsproject.org/). 30 | - Its behavior is based on Large Language Models (LLM), and it aims to be a tool to help scientists with quick predictions of numerous properties of materials. 31 | It is a work in progress, so please be patient with it. We are working on a systematic validation. 32 | 33 | 34 | ### Some keys are needed to use it: 35 | 1. 
An openAI API key ( [Check it here](https://platform.openai.com/account/api-keys) ) 36 | 2. A Material Project's API key ( [Check it here](https://materialsproject.org/api#api-key) ) 37 | 3. A SERP API key ( [Check it here](https://serpapi.com/account-api) ) 38 | - Only used if the chain runs a web search to answer the question. 39 | ''') 40 | with gr.Accordion("List of properties we developed tools for", open=False): 41 | gr.Markdown(f""" 42 | - Classification tasks: "Is the material AnByCz stable?" 43 | - Stable, 44 | - Magnetic, 45 | - Gap direct, and 46 | - Metal. 47 | - Regression tasks: "What is the band gap of the material AnByCz?" 48 | - Band gap, 49 | - Volume, 50 | - Density, 51 | - Atomic density, 52 | - Formation energy per atom, 53 | - Energy per atom, 54 | - Electronic energy, 55 | - Ionic energy, and 56 | - Total energy. 57 | - Reaction procedure for synthesis proposal: "Give me a reaction procedure to synthesize the material AnByCz"(under development) 58 | """) 59 | openai_api_key = gr.Textbox( 60 | label="OpenAI API Key", placeholder="sk-...", type="password") 61 | mapi_api_key = gr.Textbox( 62 | label="Material Project API Key", placeholder="...", type="password") 63 | serp_api_key = gr.Textbox( 64 | label="Serp API Key", placeholder="...", type="password") 65 | with gr.Tab("MAPI Query"): 66 | text_input = gr.Textbox(label="", placeholder="Enter question here...") 67 | text_output = gr.Textbox(placeholder="Your answer will appear here...") 68 | text_button = gr.Button("Ask!") 69 | 70 | text_button.click(agent_run, inputs=[text_input, openai_api_key, mapi_api_key, serp_api_key], outputs=text_output) 71 | 72 | demo.launch() 73 | -------------------------------------------------------------------------------- /mapillm/mapi_tools.py: -------------------------------------------------------------------------------- 1 | from mp_api.client import MPRester 2 | from emmet.core.summary import HasProps 3 | import openai 4 | import langchain 5 | from 
langchain_openai import ChatOpenAI 6 | from langchain.agents import initialize_agent 7 | from langchain.agents import Tool, tool 8 | from langchain.prompts.few_shot import FewShotPromptTemplate 9 | from langchain.prompts.prompt import PromptTemplate 10 | from langchain_community.vectorstores import FAISS 11 | from langchain_openai import OpenAIEmbeddings 12 | from langchain.prompts.example_selector import (MaxMarginalRelevanceExampleSelector, 13 | SemanticSimilarityExampleSelector) 14 | import requests 15 | import warnings 16 | from rdkit import Chem 17 | import pandas as pd 18 | import os 19 | 20 | class MAPITools: 21 | def __init__(self): 22 | self.model = 'gpt-4-turbo' #maybe change to gpt-4 when ready 23 | self.k=10 24 | 25 | def get_material_atoms(self, formula): 26 | f'''Receives a material formula and returns the atoms symbols present in it separated by comma.''' 27 | import re 28 | pattern = re.compile(r"([A-Z][a-z]*)(\d*)") 29 | matches = pattern.findall(formula) 30 | atoms = [] 31 | for m in matches: 32 | atom, count = m 33 | count = int(count) if count else 1 34 | atoms.append((atom, count)) 35 | return ",".join([a[0] for a in atoms]) 36 | 37 | def check_prop_by_formula(self, formula): 38 | raise NotImplementedError('Should be implemented in children classes') 39 | 40 | def search_similars_by_atom(self, atoms): 41 | f'''This function receives a string with the atoms separated by comma as input and returns a list of similar materials.''' 42 | atoms = atoms.replace(" ", "") 43 | with MPRester(os.getenv("MAPI_API_KEY")) as mpr: 44 | docs = mpr.materials.summary.search(elements=atoms.split(','), fields=["formula_pretty", self.prop]) 45 | return docs 46 | 47 | def create_context_prompt(self, formula): 48 | raise NotImplementedError('Should be implemented in children classes') 49 | 50 | def LLM_predict(self, prompt): 51 | f''' This function receives a prompt generate with context by the create_context_prompt tool and request a completion to a language model. 
Then returns the completion.''' 52 | llm = ChatOpenAI( 53 | model_name=self.model, 54 | temperature=0.1, 55 | n=5, 56 | # best_of=5, 57 | # stop=["\n\n", "###", "#", "##"], 58 | ) 59 | return llm.invoke([prompt]).generations[0][0].text 60 | 61 | def get_tools(self): 62 | return [ 63 | Tool( 64 | name = "Get atoms in material", 65 | func = self.get_material_atoms, 66 | description = ( 67 | "Receives a material formula and returns the atoms symbols present in it separated by comma." 68 | ) 69 | ), 70 | Tool( 71 | name = f"Checks if material is {self.prop_name} by formula", 72 | func = self.check_prop_by_formula, 73 | description = ( 74 | f"This functions searches in the material project's API for the formula and returns if it is {self.prop_name} or not." 75 | ) 76 | ), 77 | # Tool( 78 | # name = "Search similar materials by atom", 79 | # func = self.search_similars_by_atom, 80 | # description = ( 81 | # "This function receives a string with the atoms separated by comma as input and returns a list of similar materials." 82 | # ) 83 | # ), 84 | Tool( 85 | name = f"Create {self.prop_name} context to LLM search", 86 | func = self.create_context_prompt, 87 | description = ( 88 | f"This function received a material formula as input and create a prompt to be inputed in the LLM_predict tool to predict if the material is {self.prop_name}." 89 | if isinstance(self, MAPI_class_tools) else 90 | f"This function received a material formula as input and create a prompt to be inputed in the LLM_predict tool to predict the {self.prop_name} of a material." 91 | ) 92 | ), 93 | Tool(name = "LLM prediction", 94 | func = self.LLM_predict, 95 | description = ( 96 | "This function receives a prompt generate with context by the create_context_prompt tool and request a completion to a language model. 
Then returns the completion" 97 | ) 98 | ) 99 | ] 100 | 101 | class MAPI_class_tools(MAPITools): 102 | def __init__(self, prop, prop_name, p_label, n_label): 103 | super().__init__() 104 | self.prop = prop 105 | self.prop_name = prop_name 106 | self.p_label = p_label 107 | self.n_label = n_label 108 | 109 | def check_prop_by_formula(self, formula): 110 | f''' This functions searches in the material project's API for the formula and returns if it is {self.prop_name} or not.''' 111 | with MPRester(os.getenv("MAPI_API_KEY")) as mpr: 112 | docs = mpr.materials.summary.search(formula=formula, fields=["formula_pretty", self.prop]) 113 | if len(docs) > 1: 114 | warnings.warn(f"More than one material found for {formula}. Will use the first one. Please, check the results.") 115 | if docs: 116 | if docs[0].formula_pretty == formula: 117 | return self.p_label if docs[0].model_dump()[self.prop] else self.n_label 118 | return f"Could not find any material while searching {formula}" 119 | 120 | def create_context_prompt(self, formula): 121 | f'''This function received a material formula as input and create a prompt to be inputed in the LLM_predict tool to predict if the formula is a {self.prop_name} material.''' 122 | elements = self.get_material_atoms(formula) 123 | similars = self.search_similars_by_atom(elements) 124 | similars = [ 125 | {'formula': ex.formula_pretty, 126 | 'prop': self.p_label if ex.model_dump()[self.prop] else self.n_label 127 | } for ex in similars 128 | ] 129 | examples = pd.DataFrame(similars).drop_duplicates().to_dict(orient="records") 130 | example_selector = MaxMarginalRelevanceExampleSelector.from_examples( 131 | examples, 132 | OpenAIEmbeddings(), 133 | FAISS, 134 | k=self.k, 135 | ) 136 | 137 | prefix=( 138 | f'You are a bot who can predict if a material is {self.prop_name}.\n' 139 | f'Given this list of known materials and the information if they are {self.p_label} or {self.n_label}, \n' 140 | f'you need to answer the question if the last material 
is {self.prop_name}:' 141 | ) 142 | prompt_template=PromptTemplate( 143 | input_variables=["formula", "prop"], 144 | template=f"Is {{formula}} a {self.prop_name} material?@@@\n{{prop}}###", 145 | ) 146 | suffix = f"Is {{formula}} a {self.prop_name} material?@@@\n" 147 | prompt = FewShotPromptTemplate( 148 | # examples=examples, 149 | example_prompt=prompt_template, 150 | example_selector=example_selector, 151 | prefix=prefix, 152 | suffix=suffix, 153 | input_variables=["formula"]) 154 | 155 | return prompt.format(formula=formula) 156 | 157 | class MAPI_reg_tools(MAPITools): 158 | # TODO: deal with units 159 | def __init__(self, prop, prop_name): 160 | super().__init__() 161 | self.prop = prop 162 | self.prop_name = prop_name 163 | 164 | def check_prop_by_formula(self, formula): 165 | f''' This functions searches in the material project's API for the formula and returns the {self.prop_name}.''' 166 | with MPRester(os.getenv("MAPI_API_KEY")) as mpr: 167 | docs = mpr.materials.summary.search(formula=formula, fields=["formula_pretty", self.prop]) 168 | if len(docs) > 1: 169 | warnings.warn(f"More than one material found for {formula}. Will use the first one. 
Please, check the results.") 170 | if docs: 171 | if docs[0].formula_pretty == formula: 172 | return docs[0].model_dump()[self.prop] 173 | elif docs[0].model_dump()[self.prop] is None: 174 | return f"There is no record of {self.prop_name} for {formula}" 175 | return f"Could not find any material while searching {formula}" 176 | 177 | def create_context_prompt(self, formula): 178 | f'''This function received a material formula as input and create a prompt to be inputed in the LLM_predict tool to predict the {self.prop_name} of the material.''' 179 | elements = self.get_material_atoms(formula) 180 | similars = self.search_similars_by_atom(elements) 181 | similars = [ 182 | {'formula': ex.formula_pretty, 183 | 'prop': f"{ex.model_dump()[self.prop]:2f}" if ex.model_dump()[self.prop] is not None else None 184 | } for ex in similars 185 | ] 186 | examples = pd.DataFrame(similars).drop_duplicates().dropna().to_dict(orient="records") 187 | 188 | example_selector = MaxMarginalRelevanceExampleSelector.from_examples( 189 | examples, 190 | OpenAIEmbeddings(), 191 | FAISS, 192 | k=self.k, 193 | ) 194 | 195 | prefix=( 196 | f'You are a bot who can predict the {self.prop_name} of a material .\n' 197 | f'Given this list of known materials and the measurement of their {self.prop_name}, \n' 198 | f'you need to predict what is the {self.prop_name} of the material:' 199 | 'The answer should be numeric and finish with ###' 200 | ) 201 | prompt_template=PromptTemplate( 202 | input_variables=["formula", "prop"], 203 | template=f"What is the {self.prop_name} for {{formula}}?@@@\n{{prop}}###", 204 | ) 205 | suffix = f"What is the {self.prop_name} for {{formula}}?@@@\n" 206 | prompt = FewShotPromptTemplate( 207 | # examples=examples, 208 | example_prompt=prompt_template, 209 | example_selector=example_selector, 210 | prefix=prefix, 211 | suffix=suffix, 212 | input_variables=["formula"]) 213 | 214 | return prompt.format(formula=formula) 215 | 216 | 217 | # Now we create the tools 218 | 
stability = MAPI_class_tools( 219 | "is_stable","stable","Stable","Unstable" 220 | ) 221 | magnetism = MAPI_class_tools( 222 | "is_magnetic","magnetic","Magnetic","Not magnetic" 223 | ) 224 | metal = MAPI_class_tools( 225 | "is_metal","metallic","Metal","Not metal" 226 | ) 227 | gap_direct = MAPI_class_tools( 228 | "is_gap_direct","gap direct","Gap direct","Gap indirect" 229 | ) 230 | band_gap = MAPI_reg_tools( 231 | "band_gap","band gap" 232 | ) 233 | energy_per_atom = MAPI_reg_tools( 234 | "energy_per_atom","energy per atom gap" 235 | ) 236 | formation_energy_per_atom = MAPI_reg_tools( 237 | "formation_energy_per_atom","formation energy per atom gap" 238 | ) 239 | volume = MAPI_reg_tools( 240 | "volume","volume" 241 | ) 242 | density = MAPI_reg_tools( 243 | "density","density" 244 | ) 245 | atomic_density = MAPI_reg_tools( 246 | "density_atomic","atomic density" 247 | ) 248 | electronic_energy = MAPI_reg_tools( 249 | "e_electronic","electronic energy" 250 | ) 251 | ionic_energy = MAPI_reg_tools( 252 | "e_ion","cationic energy" 253 | ) 254 | total_energy = MAPI_reg_tools( 255 | "e_total","total energy" 256 | ) 257 | 258 | mapi_tools = [] 259 | for prop in [stability, magnetism, metal, gap_direct, band_gap, 260 | energy_per_atom, formation_energy_per_atom, volume, density, atomic_density, electronic_energy, ionic_energy, total_energy]: 261 | # for prop in [band_gap]: 262 | mapi_tools += prop.get_tools() -------------------------------------------------------------------------------- /mapillm/reaction_prediction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from langchain.agents import Tool, tool 4 | # from mp_api.client import MPRester 5 | from pymatgen.ext.matproj import MPRester 6 | from rxn_network.entries.entry_set import GibbsEntrySet 7 | from rxn_network.enumerators.basic import BasicEnumerator 8 | 9 | class SynthesisReactions: 10 | def __init__(self, temp=900, stabl=0.025, 
exclusive_precursors=False, exclusive_targets=False): 11 | self.temp = temp 12 | self.stabl = stabl 13 | self.exclusive_precursors = exclusive_precursors 14 | self.exclusive_targets = exclusive_targets 15 | 16 | def _split_string(self, s): 17 | if isinstance(s, list): 18 | s = "".join(s) 19 | parts = re.findall('[a-z]+|[A-Z][a-z]*', s) 20 | letters_only = [re.sub(r'\d+', '', part) for part in parts] 21 | unique_letters = list(set(letters_only)) 22 | result = "-".join(unique_letters) 23 | return result 24 | 25 | def _get_rxn_from_precursor(self, precursors_formulas): 26 | prec = precursors_formulas.split(',') if "," in precursors_formulas else precursors_formulas 27 | 28 | with MPRester(os.getenv("MAPI_API_KEY")) as mpr: 29 | entries = mpr.get_entries_in_chemsys(self._split_string(prec)) 30 | 31 | gibbs_entries = GibbsEntrySet.from_computed_entries(entries, self.temp) 32 | filtered_entries = gibbs_entries.filter_by_stability(self.stabl) 33 | 34 | prec = [prec] if isinstance(prec, str) else prec 35 | be = BasicEnumerator(precursors=prec, exclusive_precursors=self.exclusive_precursors) 36 | rxns = be.enumerate(filtered_entries) 37 | try: 38 | rxn_choice = next(iter(rxns)) 39 | return str(rxn_choice) 40 | except: 41 | return "Error: No reactions found." 
42 | 43 | def _get_rxn_from_target(self, targets_formulas): 44 | targets = targets_formulas.split(',') if "," in targets_formulas else targets_formulas 45 | 46 | with MPRester(os.getenv("MAPI_API_KEY")) as mpr: 47 | entries = mpr.get_entries_in_chemsys(self._split_string(targets)) 48 | 49 | gibbs_entries = GibbsEntrySet.from_computed_entries(entries, self.temp) 50 | filtered_entries = gibbs_entries.filter_by_stability(self.stabl) 51 | 52 | targets = [targets] if isinstance(targets, str) else targets 53 | 54 | be = BasicEnumerator(targets=targets, exclusive_targets=self.exclusive_targets) 55 | rxns = be.enumerate(filtered_entries) 56 | try: 57 | rxn_choice = next(iter(rxns)) 58 | return str(rxn_choice) 59 | except: 60 | return "Error: No reactions found." 61 | 62 | def _break_equation(self, equation): 63 | pattern = r'(\d*\.?\d*\s*[A-Za-z]+\d*|\+|\->)' 64 | pieces = re.findall(pattern, equation) 65 | equation_pieces = [] 66 | current_piece = '' 67 | for piece in pieces: 68 | if piece == '+' or piece == '->': 69 | equation_pieces.append(current_piece.strip()) 70 | equation_pieces.append(piece) 71 | current_piece = '' 72 | else: 73 | current_piece += piece + ' ' 74 | equation_pieces.append(current_piece.strip()) 75 | return equation_pieces 76 | 77 | def _convert_equation_pieces(self, equation_pieces): 78 | if '+' in equation_pieces: 79 | equation_pieces = [piece if piece != '+' else 'with' for piece in equation_pieces] 80 | equation_pieces = [piece if piece != '->' else 'to yield' for piece in equation_pieces] 81 | else: 82 | equation_pieces = [piece if piece != '->' else 'yields' for piece in equation_pieces] 83 | return equation_pieces 84 | 85 | def _split_equation_pieces(self, equation_pieces): 86 | new_pieces = [] 87 | for piece in equation_pieces: 88 | if piece in ["with", "to yield", "yields"]: 89 | new_pieces.append(piece) 90 | else: 91 | if re.match(r'^\d*\.\d+|\d+', piece): 92 | number_match = re.match(r'^\d*\.\d+|\d+', piece) 93 | number = 
number_match.group(0) 94 | rest = piece[len(number):] 95 | new_pieces.append(number) 96 | new_pieces.append(rest) 97 | else: 98 | new_pieces.append("1") 99 | new_pieces.append(piece) 100 | return new_pieces 101 | 102 | def _modify_mols(self, equation_pieces): 103 | for i, piece in enumerate(equation_pieces): 104 | if piece.replace('.', '', 1).isdigit(): 105 | equation_pieces[i] = f"{piece} mols" 106 | return equation_pieces 107 | 108 | def _combine_equation_pieces(self, equation_pieces): 109 | if 'with' in equation_pieces: 110 | equation_pieces.insert(0, 'mix') 111 | combined_string = ' '.join(equation_pieces) 112 | return combined_string 113 | 114 | def _process_equation(self, equation): 115 | equation_pieces = self._break_equation(equation) 116 | converted_pieces = self._convert_equation_pieces(equation_pieces) 117 | split_pieces = self._split_equation_pieces(converted_pieces) 118 | modified_pieces = self._modify_mols(split_pieces) 119 | combined_string = self._combine_equation_pieces(modified_pieces) 120 | return combined_string 121 | 122 | def get_reaction(self, input_string): 123 | input_parts = input_string.split(',', 1) 124 | if len(input_parts) != 2: 125 | raise ValueError("Invalid input format. Expected 'precursor' or 'target', followed by a comma, and then the list of formulas separated by a comma.") 126 | 127 | mode, formulas = input_parts 128 | mode = mode.lower().strip() 129 | 130 | if mode == "precursor": 131 | reaction = self._get_rxn_from_precursor(formulas) 132 | elif mode == "target": 133 | reaction = self._get_rxn_from_target(formulas) 134 | else: 135 | raise ValueError("Invalid mode. 
Expected 'precursor' or 'target'.") 136 | processed_reaction = self._process_equation(reaction) 137 | return processed_reaction 138 | 139 | def get_tools(self): 140 | return [ 141 | Tool( 142 | name = "Get a synthesis reaction for a material", 143 | func = self.get_reaction, 144 | description = ( 145 | "This function is useful for suggesting a synthesis reaction for a material. " 146 | "Give this tool a string containing either precursor or target, then a comma, followed by the formulas separated by comma as input and returns a synthesis reaction." 147 | "The mode is used to determine if the input is a precursor or a target material. " 148 | ) 149 | )] 150 | 151 | -------------------------------------------------------------------------------- /mapillm/utils.py: -------------------------------------------------------------------------------- 1 | from langchain.agents import Tool, tool 2 | import requests 3 | from langchain_community.llms import OpenAI 4 | from langchain.chains import LLMMathChain 5 | from langchain_community.utilities import SerpAPIWrapper 6 | import os 7 | from rdkit import Chem 8 | 9 | @tool 10 | def query2smiles(text): 11 | '''This function queries the one given molecule name and returns a SMILES string from the record''' 12 | try:#query the PubChem database 13 | r = requests.get('https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/' + text + '/property/IsomericSMILES/JSON') 14 | #convert the response to a json object 15 | data = r.json() 16 | #return the SMILES string 17 | smi = data['PropertyTable']['Properties'][0]['IsomericSMILES'] 18 | # remove salts 19 | return smi 20 | except: 21 | f"Could not find the IUPAC name for {text}" 22 | 23 | @tool 24 | def smiles2IUPAC(text): 25 | '''This function queries the one given smiles name and returns a IUPAC name from the record''' 26 | #query the PubChem database 27 | try: 28 | r = requests.get('https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/' + text + '/property/IUPACName/JSON') 29 | 
data = r.json() 30 | smi = data["PropertyTable"]["Properties"][0]["IUPACName"] 31 | return smi 32 | except: 33 | return f"Could not find the IUPAC name for {text}" 34 | 35 | @tool 36 | def formula2IUPAC(text): 37 | '''This function queries the one given chemical formula and returns a material name from the record.''' 38 | try: 39 | r = requests.get('https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/formula/' + text + '/property/IUPACName/JSON') 40 | data = r.json() 41 | print(data) 42 | smi = data["PropertyTable"]["Properties"][0]["IUPACName"] 43 | return smi 44 | except: 45 | return f"Could not find the IUPAC name for {text}" 46 | 47 | @tool 48 | def name2formula(text): 49 | '''This function queries the one given material name and returns a chemical formula from the record.''' 50 | try: 51 | r = requests.get('https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/' + text + '/property/MolecularFormula/JSON') 52 | data = r.json() 53 | print(data) 54 | smi = data["PropertyTable"]["Properties"][0]["MolecularFormula"] 55 | return smi 56 | except: 57 | return f"Could not find the molecular formula for {text}" 58 | 59 | @tool 60 | def canonicalizeSMILES(smiles): 61 | '''Given a smiles representation, this function returns a canonicalized version of the same smiles. 62 | It's better to search for molecules in its canonicalized form''' 63 | return Chem.MolToSmiles(Chem.MolFromSmiles(smiles)) 64 | 65 | @tool 66 | def web_search(keywords, search_engine="google"): 67 | '''Useful to do a simple google search. 68 | Use this tool to find general information from websites. 69 | Use keywords for your search. 70 | ''' 71 | return SerpAPIWrapper( 72 | serpapi_api_key=os.getenv("SERP_API_KEY"), 73 | search_engine=search_engine 74 | ).run(keywords) 75 | 76 | @tool 77 | def LLM_predict(prompt): 78 | ''' This function receives a prompt generate with context by the create_context_prompt tool and request a completion to a language model. 
Then returns the completion''' 79 | llm = OpenAI( 80 | model_name='text-ada-001', #TODO: Maybe change to gpt-4 when ready 81 | temperature=0.7, 82 | n=1, 83 | best_of=5, 84 | top_p=1.0, 85 | stop=["\n\n", "###", "#", "##"], 86 | # model_kwargs=kwargs, 87 | ) 88 | return llm.generate([prompt]).generations[0][0].text 89 | 90 | common_tools = [ 91 | query2smiles, 92 | smiles2IUPAC, 93 | # formula2IUPAC, 94 | # name2formula, 95 | canonicalizeSMILES, 96 | # web_search, 97 | LLM_predict 98 | ] -------------------------------------------------------------------------------- /mapillm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.2" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | openai 4 | langchain 5 | langchain_openai 6 | langchainhub 7 | mp_api 8 | requests 9 | rdkit 10 | transformers 11 | faiss-cpu 12 | pymatgen 13 | reaction-network 14 | python-dotenv 15 | numexpr -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from glob import glob 4 | from setuptools import setup, find_packages 5 | 6 | exec(open("mapillm/version.py").read()) 7 | 8 | with open("README.md", "r", encoding="utf-8") as fh: 9 | long_description = fh.read() 10 | 11 | setup( 12 | name="MAPI_LLM", 13 | version=__version__, 14 | description="A Python package for the MAPI_LLM project", 15 | author="Mayk Caldas", 16 | author_email="maykcaldas@gmail.edu", 17 | url="https://github.com/maykcaldas/MAPI_LLM", 18 | license="MIT", 19 | packages=['mapillm'], 20 | install_requires=[ 21 | "numpy", 22 | "pandas", 23 | "openai", 24 | "langchain", 25 | "langchain_openai", 26 | "langchainhub", 27 | "mp_api", 28 | "request", 29 | "rdkit", 30 | 
"transformers", 31 | "pymatgen", 32 | "faiss-cpu", 33 | "reaction-network", 34 | "python-dotenv", 35 | "numexpr", 36 | ], 37 | test_suite="tests", 38 | long_description=long_description, 39 | long_description_content_type="text/markdown", 40 | classifiers=[ 41 | "Programming Language :: Python :: 3", 42 | "License :: OSI Approved :: MIT License", 43 | "Operating System :: OS Independent", 44 | ], 45 | ) -------------------------------------------------------------------------------- /tests/baselines/ICL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 15, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "True" 12 | ] 13 | }, 14 | "execution_count": 15, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "from transformers import BertTokenizerFast\n", 21 | "import re \n", 22 | "import sys\n", 23 | "import os\n", 24 | "sys.path.append(os.path.join(os.path.dirname(\".\"), '../..'))\n", 25 | "\n", 26 | "import mapillm\n", 27 | "from mapillm.mapi_tools import MAPI_reg_tools\n", 28 | "from datasets import load_dataset\n", 29 | "\n", 30 | "from dotenv import load_dotenv\n", 31 | "load_dotenv(\"../.env\")\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 16, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "['nsites', 'nelements', 'formula_pretty', 'chemsys', 'volume', 'density', 'density_atomic', 'property_name', 'material_id', 'deprecation_reasons', 'last_updated', 'origins', 'warnings', 'structure', 'task_ids', 'uncorrected_energy_per_atom', 'energy_per_atom', 'formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'equilibrium_reaction_energy_per_atom', 'xas', 'grain_boundaries', 'band_gap', 'cbm', 'vbm', 'efermi', 'is_gap_direct', 'is_metal', 'es_source_calc_id', 'dos_energy_up', 
'dos_energy_down', 'is_magnetic', 'ordering', 'total_magnetization', 'total_magnetization_normalized_vol', 'total_magnetization_normalized_formula_units', 'num_magnetic_sites', 'num_unique_magnetic_sites', 'bulk_modulus', 'shear_modulus', 'universal_anisotropy', 'homogeneous_poisson', 'e_total', 'e_ionic', 'e_electronic', 'n', 'e_ij_max', 'weighted_surface_energy_EV_PER_ANG2', 'weighted_surface_energy', 'weighted_work_function', 'surface_anisotropy', 'shape_factor', 'has_reconstructed', 'has_props', 'theoretical', 'database_IDs', 'crystal_system', 'symbol', 'number', 'point_group']\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "dataset=load_dataset('ur-whitelab/mapi', token=os.environ['HF_TOKEN'])\n", 49 | "\n", 50 | "print(dataset['train'].column_names)\n", 51 | "\n", 52 | "target = [\"band_gap\"]\n", 53 | "# features=['formula_pretty', 'crystal_system', 'symbol', 'point_group']\n", 54 | "features=['formula_pretty', 'crystal_system', 'point_group']\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 17, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "train_dataset = dataset['train'].select_columns(features+target)\n", 64 | "test_dataset = dataset['test'].select_columns(features+target)\n", 65 | "\n", 66 | "def filter_none(example):\n", 67 | " return all(value is not None for value in example.values())\n", 68 | "\n", 69 | "train_dataset = train_dataset.filter(filter_none).to_pandas()\n", 70 | "test_dataset = test_dataset.filter(filter_none).to_pandas()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 18, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "from langchain.prompts.prompt import PromptTemplate\n", 80 | "from langchain.prompts.few_shot import FewShotPromptTemplate\n", 81 | "from langchain.prompts.example_selector import (MaxMarginalRelevanceExampleSelector, \n", 82 | " SemanticSimilarityExampleSelector)\n", 83 | "from langchain_community.vectorstores import FAISS\n", 84 | "from 
langchain_openai import OpenAIEmbeddings\n", 85 | "\n", 86 | "\n", 87 | "examples = train_dataset.astype(str).drop_duplicates().dropna().to_dict(orient=\"records\")\n", 88 | "\n", 89 | "example_selector = MaxMarginalRelevanceExampleSelector.from_examples(\n", 90 | " examples,\n", 91 | " OpenAIEmbeddings(),\n", 92 | " FAISS,\n", 93 | " k=10,\n", 94 | " )\n", 95 | "\n", 96 | "prompt_template=PromptTemplate(\n", 97 | " input_variables=[\"crystal_system\", \"formula_pretty\", \"point_group\", \"band_gap\"],\n", 98 | " template=f\"What is the band_gap for {{crystal_system}} {{formula_pretty}} with space group {{point_group}}?@@@\\n{{band_gap}}###\"\n", 99 | " )\n", 100 | " \n", 101 | "prefix=(\n", 102 | " f'You are a bot who can predict the band_gap of a material .\\n'\n", 103 | " f'Given this list of known materials and the measurement of their band gap, \\n'\n", 104 | " f'you need to predict what is the band gap of the material:'\n", 105 | " f'The answer should be numeric and finish with ###'\n", 106 | " )\n", 107 | "suffix = f\"What is the band_gap for {{crystal_system}} {{formula_pretty}} with space group {{point_group}}?@@@\\n\"\n", 108 | "prompt = FewShotPromptTemplate(\n", 109 | " # examples=examples,\n", 110 | " example_prompt=prompt_template,\n", 111 | " example_selector=example_selector,\n", 112 | " prefix=prefix,\n", 113 | " suffix=suffix,\n", 114 | " input_variables=[\"crystal_system\", \"formula_pretty\", \"point_group\"])\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 19, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "Get atoms in material\n", 127 | "\tReceives a material formula and returns the atoms symbols present in it separated by comma.\n", 128 | "Checks if material is band gap by formula\n", 129 | "\tThis functions searches in the material project's API for the formula and returns if it is band gap or not.\n", 130 | "Create band gap 
context to LLM search\n", 131 | "\tThis function received a material formula as input and create a prompt to be inputed in the LLM_predict tool to predict the band gap of a material.\n", 132 | "LLM prediction\n", 133 | "\tThis function receives a prompt generate with context by the create_context_prompt tool and request a completion to a language model. Then returns the completion\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "band_gap_tool = MAPI_reg_tools(\n", 139 | " \"band_gap\",\"band gap\"\n", 140 | ")\n", 141 | "tools = band_gap_tool.get_tools()\n", 142 | "for k in tools:\n", 143 | " print(f\"{k.name}\\n\\t{k.description}\")\n", 144 | "get_material_atoms = tools[0].func\n", 145 | "check_prop_by_formula = tools[1].func\n", 146 | "create_context_prompt = tools[2].func\n", 147 | "LLM_predict = tools[3].func\n", 148 | "\n", 149 | "# I want to evaluate the tools doing something like that. But as the tools aren't ready, I'll just evaluate the ICL\n", 150 | "# check_prop_by_formula(formula = \"LiCoO2\")\n", 151 | "# get_material_atoms(formula = \"LiCoO2\")\n", 152 | "# prompt = create_context_prompt(formula = \"LiCoO2\")\n", 153 | "# print(prompt)\n", 154 | "# LLM_predict(prompt=prompt)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 20, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "import numpy as np\n", 164 | "import pandas as pd\n", 165 | "import matplotlib.pyplot as plt\n", 166 | "import urllib.request\n", 167 | "import matplotlib as mpl\n", 168 | "import matplotlib.font_manager as font_manager\n", 169 | "urllib.request.urlretrieve('https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf', 'IBMPlexMono-Regular.ttf')\n", 170 | "fe = font_manager.FontEntry(\n", 171 | " fname='IBMPlexMono-Regular.ttf',\n", 172 | " name='plexmono')\n", 173 | "font_manager.fontManager.ttflist.append(fe)\n", 174 | "plt.rcParams.update({'axes.facecolor':'#f5f4e9',\n", 175 | " 'grid.color' : 
'#AAAAAA',\n", 176 | " 'axes.edgecolor':'#333333',\n", 177 | " 'figure.facecolor':'#FFFFFF',\n", 178 | " 'axes.grid': False,\n", 179 | " 'axes.prop_cycle': plt.cycler('color', plt.cm.Dark2.colors),\n", 180 | " 'font.family': fe.name,\n", 181 | " 'figure.figsize': (3.5,3.5 / 1.2),\n", 182 | " 'ytick.left': True,\n", 183 | " 'xtick.bottom': True\n", 184 | " })" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 21, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "from langchain_openai import ChatOpenAI\n", 194 | "\n", 195 | "llm = ChatOpenAI(\n", 196 | " model_name=\"gpt-3.5-turbo-0125\",\n", 197 | " temperature=0.1,\n", 198 | " n=5,\n", 199 | " # best_of=5,\n", 200 | " # stop=[\"\\n\\n\", \"###\", \"#\", \"##\"],\n", 201 | " )" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 22, 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "(100,) (100,)\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "k=8\n", 219 | "\n", 220 | "import numpy as np\n", 221 | "import matplotlib.pyplot as plt\n", 222 | "from sklearn.metrics import mean_absolute_error, mean_squared_error\n", 223 | "\n", 224 | "yhat=[]\n", 225 | "y=[]\n", 226 | "for k in range(100):\n", 227 | " try:\n", 228 | " y.append(test_dataset.iloc[k][target])\n", 229 | " \n", 230 | " formula_pretty = test_dataset.iloc[k][\"formula_pretty\"]\n", 231 | " crystal_system = test_dataset.iloc[k][\"crystal_system\"]\n", 232 | " point_group = test_dataset.iloc[k][\"point_group\"]\n", 233 | "\n", 234 | " p = prompt.format(crystal_system=crystal_system,\n", 235 | " formula_pretty=formula_pretty,\n", 236 | " point_group=point_group\n", 237 | " )\n", 238 | " completion = llm.invoke([p]).content\n", 239 | " pred = float(re.findall(r\"[-+]?\\d*\\.\\d+|\\d+\", completion)[0])\n", 240 | " yhat.append(pred)\n", 241 | " except Exception as e:\n", 242 | " print(k, e)\n", 243 | " 
yhat.append(-1)\n", 244 | "\n", 245 | "y = np.array(y).astype(float).flatten()\n", 246 | "yhat = np.array(yhat).astype(float).flatten()\n", 247 | "print(y.shape, yhat.shape)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 23, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "data": { 257 | "image/png": "", 258 | "text/plain": [ 259 | "
" 260 | ] 261 | }, 262 | "metadata": {}, 263 | "output_type": "display_data" 264 | } 265 | ], 266 | "source": [ 267 | "lim = (min(y),max(y))\n", 268 | "plt.xlabel('True')\n", 269 | "plt.ylabel('Predicted')\n", 270 | "plt.plot(y, yhat, 'o', alpha=0.4)\n", 271 | "plt.plot(lim, lim, '--')\n", 272 | "plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 1*0.1*(max(y)-min(y)), f\"correlation = {np.corrcoef(y, yhat)[0,1]:.3f}\")\n", 273 | "plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 2*0.1*(max(y)-min(y)), f\"MAE = {mean_absolute_error(y, yhat):.3f}\")\n", 274 | "plt.show()\n" 275 | ] 276 | } 277 | ], 278 | "metadata": { 279 | "kernelspec": { 280 | "display_name": "mapi", 281 | "language": "python", 282 | "name": "python3" 283 | }, 284 | "language_info": { 285 | "codemirror_mode": { 286 | "name": "ipython", 287 | "version": 3 288 | }, 289 | "file_extension": ".py", 290 | "mimetype": "text/x-python", 291 | "name": "python", 292 | "nbconvert_exporter": "python", 293 | "pygments_lexer": "ipython3", 294 | "version": "3.10.14" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 2 299 | } 300 | -------------------------------------------------------------------------------- /tests/baselines/KNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/Users/maykcaldas/miniconda3/envs/mapi/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | }, 16 | { 17 | "data": { 18 | "text/plain": [ 19 | "True" 20 | ] 21 | }, 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "output_type": "execute_result" 25 | } 26 | ], 27 | "source": [ 28 | "import os\n", 29 | "from datasets import load_dataset\n", 30 | "from sklearn.neighbors import KNeighborsRegressor\n", 31 | "import re\n", 32 | "\n", 33 | "from dotenv import load_dotenv\n", 34 | "load_dotenv(\"../.env\")\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "dataset=load_dataset('ur-whitelab/mapi', token=os.environ['HF_TOKEN'])\n", 44 | "\n", 45 | "target = [\"band_gap\"]\n", 46 | "features=['nsites', 'nelements', 'formula_pretty', 'chemsys', 'volume', 'density', 'density_atomic', 'crystal_system', 'symbol', 'number', 'point_group']\n", 47 | "\n", 48 | "train_dataset = dataset['train'].select_columns(features+target).to_pandas()\n", 49 | "test_dataset = dataset['test'].select_columns(features+target).to_pandas()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "def expand_formula(formula):\n", 59 | " # Pattern to identify groups in parentheses with their multipliers\n", 60 | " parenthetical_group_pattern = re.compile(r\"\\(([\\w]+)\\)(\\d+)\")\n", 61 | " match = parenthetical_group_pattern.search(formula)\n", 62 | " while match:\n", 63 | " # Extract the matched group and multiplier\n", 64 | " group, multiplier = match.group(1), int(match.group(2))\n", 65 | " # Expand the group by repeating the sequence inside the parentheses\n", 66 | " expanded_group = ''\n", 67 | " inner_matches = re.findall(r\"([A-Z][a-z]*)(\\d*)\", group)\n", 68 | " for elem, qty in inner_matches:\n", 69 | " if qty == '':\n", 70 | " qty = 1\n", 71 | " else:\n", 72 | " qty = 
int(qty)\n", 73 | " expanded_group += f\"{elem}{qty * multiplier}\"\n", 74 | " # Replace the original matched pattern with its expansion\n", 75 | " formula = formula[:match.start()] + expanded_group + formula[match.end():]\n", 76 | " # Search for the next match\n", 77 | " match = parenthetical_group_pattern.search(formula)\n", 78 | " return formula\n", 79 | "\n", 80 | "def parse_formula(formula):\n", 81 | " # Expand the formula first\n", 82 | " expanded_formula = expand_formula(formula)\n", 83 | " # Now parse the expanded formula\n", 84 | " element_pattern = re.compile(r\"([A-Z][a-z]*)(\\d*)\")\n", 85 | " matches = element_pattern.findall(expanded_formula)\n", 86 | " # Convert counts to integers, defaulting to 1 if absent\n", 87 | " parsed_matches = [(elem, int(qty) if qty else 1) for elem, qty in matches]\n", 88 | " # Simplify the representation by aggregating counts for each element\n", 89 | " composition = {}\n", 90 | " for element, count in parsed_matches:\n", 91 | " if element in composition:\n", 92 | " composition[element] += count\n", 93 | " else:\n", 94 | " composition[element] = count\n", 95 | " # Prepare the final aggregated result\n", 96 | " aggregated_matches = [(element, str(count)) for element, count in composition.items()]\n", 97 | " return expanded_formula, aggregated_matches\n", 98 | "\n", 99 | "def make_set(serie):\n", 100 | " all = \"-\".join(serie.to_list())\n", 101 | " all = all.split(\"-\")\n", 102 | " print(set(all))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "Index(['nsites', 'nelements', 'formula_pretty', 'chemsys', 'volume', 'density',\n", 115 | " 'density_atomic', 'crystal_system', 'symbol', 'number', 'point_group',\n", 116 | " 'band_gap'],\n", 117 | " dtype='object')\n", 118 | "nsites 54\n", 119 | "nelements 5\n", 120 | "formula_pretty K2LaTa6(Br5O)3\n", 121 | "chemsys 
Br-K-La-O-Ta\n", 122 | "volume 1442.071056\n", 123 | "density 5.871127\n", 124 | "density_atomic 26.70502\n", 125 | "crystal_system Trigonal\n", 126 | "symbol P-31c\n", 127 | "number 163\n", 128 | "point_group -3m\n", 129 | "band_gap 1.1473\n", 130 | "Name: 0, dtype: object\n" 131 | ] 132 | }, 133 | { 134 | "data": { 135 | "text/html": [ 136 | "
\n", 137 | "\n", 150 | "\n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " 
\n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | "
nsitesnelementsvolumedensitydensity_atomiccrystal_systemsymbolnumberpoint_groupband_gap...ThTcPaNpKrAcXeHeArNe
05451442.0710565.87112726.705020640163301.1473...0.00.00.00.00.00.00.00.00.00.0
1363553.52861714.66661115.37579543618990.0000...0.00.00.00.00.00.00.00.00.00.0
243702.9920010.686733175.74800072017170.0000...0.00.00.00.00.00.00.00.00.00.0
36355.0395551.1414049.173259399216186.5874...0.00.00.00.00.00.00.00.00.00.0
442111.0640859.28375527.7660211127139110.0000...0.00.00.00.00.00.00.00.00.00.0
..................................................................
124283103167.4654766.84361516.7465481188140110.0000...0.00.00.00.00.00.00.00.00.00.0
124284275613.3686266.19827422.7173572150121.9338...0.00.00.00.00.00.00.00.00.00.0
124285143134.1195944.6752019.57997168216010.5920...0.00.00.00.00.00.00.00.00.00.0
124286303304.3176824.02228010.143923511512240.5093...0.00.00.00.00.00.00.00.00.00.0
124287303452.4837746.56502715.0827921113127110.0000...0.00.00.00.00.00.00.00.00.00.0
\n", 444 | "

124283 rows × 99 columns

\n", 445 | "
" 446 | ], 447 | "text/plain": [ 448 | " nsites nelements volume density density_atomic \\\n", 449 | "0 54 5 1442.071056 5.871127 26.705020 \n", 450 | "1 36 3 553.528617 14.666611 15.375795 \n", 451 | "2 4 3 702.992001 0.686733 175.748000 \n", 452 | "3 6 3 55.039555 1.141404 9.173259 \n", 453 | "4 4 2 111.064085 9.283755 27.766021 \n", 454 | "... ... ... ... ... ... \n", 455 | "124283 10 3 167.465476 6.843615 16.746548 \n", 456 | "124284 27 5 613.368626 6.198274 22.717357 \n", 457 | "124285 14 3 134.119594 4.675201 9.579971 \n", 458 | "124286 30 3 304.317682 4.022280 10.143923 \n", 459 | "124287 30 3 452.483774 6.565027 15.082792 \n", 460 | "\n", 461 | " crystal_system symbol number point_group band_gap ... Th Tc \\\n", 462 | "0 6 40 163 30 1.1473 ... 0.0 0.0 \n", 463 | "1 4 36 189 9 0.0000 ... 0.0 0.0 \n", 464 | "2 7 201 71 7 0.0000 ... 0.0 0.0 \n", 465 | "3 3 99 216 18 6.5874 ... 0.0 0.0 \n", 466 | "4 1 127 139 11 0.0000 ... 0.0 0.0 \n", 467 | "... ... ... ... ... ... ... ... ... \n", 468 | "124283 1 188 140 11 0.0000 ... 0.0 0.0 \n", 469 | "124284 2 150 1 2 1.9338 ... 0.0 0.0 \n", 470 | "124285 6 82 160 1 0.5920 ... 0.0 0.0 \n", 471 | "124286 5 115 12 24 0.5093 ... 0.0 0.0 \n", 472 | "124287 1 113 127 11 0.0000 ... 0.0 0.0 \n", 473 | "\n", 474 | " Pa Np Kr Ac Xe He Ar Ne \n", 475 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 476 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 477 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 478 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 479 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 480 | "... ... ... ... ... ... ... ... ... 
\n", 481 | "124283 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 482 | "124284 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 483 | "124285 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 484 | "124286 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 485 | "124287 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 486 | "\n", 487 | "[124283 rows x 99 columns]" 488 | ] 489 | }, 490 | "execution_count": 4, 491 | "metadata": {}, 492 | "output_type": "execute_result" 493 | } 494 | ], 495 | "source": [ 496 | "print(train_dataset.columns)\n", 497 | "print(train_dataset.iloc[0])\n", 498 | "\n", 499 | "formula = train_dataset.iloc[0]['formula_pretty']\n", 500 | "all_elements = \"-\".join(train_dataset['chemsys'].to_list())\n", 501 | "all_elements = set(all_elements.split(\"-\"))\n", 502 | "elements_dics = {k: v for (v, k) in enumerate(all_elements, 1)}\n", 503 | "\n", 504 | "all_crystal = set(train_dataset['crystal_system'].to_list())\n", 505 | "crystal_dict = {k: v for (v, k) in enumerate(all_crystal, 1)}\n", 506 | "\n", 507 | "all_symbols = set(train_dataset['symbol'].to_list())\n", 508 | "symbols_dict = {k: v for (v, k) in enumerate(all_symbols, 1)}\n", 509 | "\n", 510 | "all_pgs = set(train_dataset['point_group'].to_list())\n", 511 | "pgs_dict = {k: v for (v, k) in enumerate(all_pgs, 1)}\n", 512 | "\n", 513 | "# map all entries in colum poing_group to a number using the dictionary\n", 514 | "train_dataset['point_group'] = train_dataset['point_group'].map(pgs_dict)\n", 515 | "train_dataset['symbol'] = train_dataset['symbol'].map(symbols_dict)\n", 516 | "train_dataset['crystal_system'] = train_dataset['crystal_system'].map(crystal_dict)\n", 517 | "for i, r in train_dataset.iterrows():\n", 518 | " formula, elements = parse_formula(r['formula_pretty'])\n", 519 | " for e in elements:\n", 520 | " train_dataset.at[i, e[0]] = int(e[1])\n", 521 | "\n", 522 | "train_dataset.drop(columns=['formula_pretty', 'chemsys'], inplace=True)\n", 523 | "train_dataset.dropna(axis=0, how='any', inplace=True, subset=target)\n", 524 | "train_dataset 
= train_dataset.fillna(0)\n", 525 | "train_dataset" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 5, 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "data": { 535 | "text/html": [ 536 | "
\n", 537 | "\n", 550 | "\n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " 
\n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | "
nsitesnelementsvolumedensitydensity_atomiccrystal_systemsymbolnumberpoint_groupband_gap...ThTcPaNpKrAcXeHeArNe
0144287.7238164.58633320.55170161166300.0000...0000000000
111451605.4244022.97814014.0826702150123.8216...0000000000
2406509.1562503.26835912.72890621142242.0588...0000000000
3123135.0803672.46453411.2566974120191200.0000...0000000000
4143190.3572425.16340613.59694651151240.0000...0000000000
..................................................................
31068363587.4439334.25447316.3178875841443.6927...0000000000
31069144224.5871255.47486216.04193768216010.0485...0000000000
31070143146.8792125.22691810.49137271906370.0000...0000000000
31071293587.2617572.06035020.25040568216010.0000...0000000000
31072224281.4617658.25635912.79371761166300.0000...0000000000
\n", 844 | "

31070 rows × 99 columns

\n", 845 | "
" 846 | ], 847 | "text/plain": [ 848 | " nsites nelements volume density density_atomic \\\n", 849 | "0 14 4 287.723816 4.586333 20.551701 \n", 850 | "1 114 5 1605.424402 2.978140 14.082670 \n", 851 | "2 40 6 509.156250 3.268359 12.728906 \n", 852 | "3 12 3 135.080367 2.464534 11.256697 \n", 853 | "4 14 3 190.357242 5.163406 13.596946 \n", 854 | "... ... ... ... ... ... \n", 855 | "31068 36 3 587.443933 4.254473 16.317887 \n", 856 | "31069 14 4 224.587125 5.474862 16.041937 \n", 857 | "31070 14 3 146.879212 5.226918 10.491372 \n", 858 | "31071 29 3 587.261757 2.060350 20.250405 \n", 859 | "31072 22 4 281.461765 8.256359 12.793717 \n", 860 | "\n", 861 | " crystal_system symbol number point_group band_gap ... Th Tc Pa \\\n", 862 | "0 6 1 166 30 0.0000 ... 0 0 0 \n", 863 | "1 2 150 1 2 3.8216 ... 0 0 0 \n", 864 | "2 2 114 2 24 2.0588 ... 0 0 0 \n", 865 | "3 4 120 191 20 0.0000 ... 0 0 0 \n", 866 | "4 5 115 12 4 0.0000 ... 0 0 0 \n", 867 | "... ... ... ... ... ... ... .. .. .. \n", 868 | "31068 5 84 14 4 3.6927 ... 0 0 0 \n", 869 | "31069 6 82 160 1 0.0485 ... 0 0 0 \n", 870 | "31070 7 190 63 7 0.0000 ... 0 0 0 \n", 871 | "31071 6 82 160 1 0.0000 ... 0 0 0 \n", 872 | "31072 6 1 166 30 0.0000 ... 0 0 0 \n", 873 | "\n", 874 | " Np Kr Ac Xe He Ar Ne \n", 875 | "0 0 0 0 0 0 0 0 \n", 876 | "1 0 0 0 0 0 0 0 \n", 877 | "2 0 0 0 0 0 0 0 \n", 878 | "3 0 0 0 0 0 0 0 \n", 879 | "4 0 0 0 0 0 0 0 \n", 880 | "... .. .. .. .. .. .. .. 
\n", 881 | "31068 0 0 0 0 0 0 0 \n", 882 | "31069 0 0 0 0 0 0 0 \n", 883 | "31070 0 0 0 0 0 0 0 \n", 884 | "31071 0 0 0 0 0 0 0 \n", 885 | "31072 0 0 0 0 0 0 0 \n", 886 | "\n", 887 | "[31070 rows x 99 columns]" 888 | ] 889 | }, 890 | "execution_count": 5, 891 | "metadata": {}, 892 | "output_type": "execute_result" 893 | } 894 | ], 895 | "source": [ 896 | "test_dataset['point_group'] = test_dataset['point_group'].map(pgs_dict)\n", 897 | "test_dataset['symbol'] = test_dataset['symbol'].map(symbols_dict)\n", 898 | "test_dataset['crystal_system'] = test_dataset['crystal_system'].map(crystal_dict)\n", 899 | "for new_c in train_dataset.columns[10:]:\n", 900 | " test_dataset[new_c] = 0\n", 901 | "\n", 902 | "for i, r in test_dataset.iterrows():\n", 903 | " formula, elements = parse_formula(r['formula_pretty'])\n", 904 | " for e in elements:\n", 905 | " test_dataset.at[i, e[0]] = int(e[1])\n", 906 | "\n", 907 | "test_dataset.drop(columns=['formula_pretty', 'chemsys'], inplace=True)\n", 908 | "test_dataset.dropna(axis=0, how='any', inplace=True, subset=target)\n", 909 | "test_dataset = test_dataset.fillna(0)\n", 910 | "test_dataset" 911 | ] 912 | }, 913 | { 914 | "cell_type": "code", 915 | "execution_count": 6, 916 | "metadata": {}, 917 | "outputs": [ 918 | { 919 | "name": "stdout", 920 | "output_type": "stream", 921 | "text": [ 922 | "0.2919239302554062\n" 923 | ] 924 | } 925 | ], 926 | "source": [ 927 | "features = [f for f in features if f not in ['formula_pretty', 'chemsys']]\n", 928 | "model = KNeighborsRegressor(n_neighbors=10)\n", 929 | "model.fit(train_dataset.drop(target, axis=1), train_dataset[target])\n", 930 | "print(model.score(test_dataset.drop(target, axis=1), test_dataset[target]))\n" 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "execution_count": 7, 936 | "metadata": {}, 937 | "outputs": [], 938 | "source": [ 939 | "import numpy as np\n", 940 | "import pandas as pd\n", 941 | "import matplotlib.pyplot as plt\n", 942 | "import urllib.request\n", 
943 | "import matplotlib as mpl\n", 944 | "import matplotlib.font_manager as font_manager\n", 945 | "urllib.request.urlretrieve('https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf', 'IBMPlexMono-Regular.ttf')\n", 946 | "fe = font_manager.FontEntry(\n", 947 | " fname='IBMPlexMono-Regular.ttf',\n", 948 | " name='plexmono')\n", 949 | "font_manager.fontManager.ttflist.append(fe)\n", 950 | "plt.rcParams.update({'axes.facecolor':'#f5f4e9',\n", 951 | " 'grid.color' : '#AAAAAA',\n", 952 | " 'axes.edgecolor':'#333333',\n", 953 | " 'figure.facecolor':'#FFFFFF',\n", 954 | " 'axes.grid': False,\n", 955 | " 'axes.prop_cycle': plt.cycler('color', plt.cm.Dark2.colors),\n", 956 | " 'font.family': fe.name,\n", 957 | " 'figure.figsize': (3.5,3.5 / 1.2),\n", 958 | " 'ytick.left': True,\n", 959 | " 'xtick.bottom': True\n", 960 | " })" 961 | ] 962 | }, 963 | { 964 | "cell_type": "code", 965 | "execution_count": 11, 966 | "metadata": {}, 967 | "outputs": [ 968 | { 969 | "name": "stdout", 970 | "output_type": "stream", 971 | "text": [ 972 | "(100,) (100,)\n" 973 | ] 974 | } 975 | ], 976 | "source": [ 977 | "k=8\n", 978 | "\n", 979 | "import numpy as np\n", 980 | "import matplotlib.pyplot as plt\n", 981 | "from sklearn.metrics import mean_absolute_error, mean_squared_error\n", 982 | "\n", 983 | "yhat=[]\n", 984 | "y=[]\n", 985 | "for k in range(100):\n", 986 | " try:\n", 987 | " y.append(test_dataset.iloc[k][target])\n", 988 | " yhat.append(model.predict(test_dataset.drop(target, axis=1).iloc[[k]]))\n", 989 | " except Exception as e:\n", 990 | " print(f\"Error at index {k}: {e}\")\n", 991 | "\n", 992 | "\n", 993 | "y = np.array(y).flatten()\n", 994 | "yhat = np.array(yhat).flatten()\n", 995 | "print(y.shape, yhat.shape)" 996 | ] 997 | }, 998 | { 999 | "cell_type": "code", 1000 | "execution_count": 12, 1001 | "metadata": {}, 1002 | "outputs": [ 1003 | { 1004 | "data": { 1005 | "image/png": "", 1006 | "text/plain": [ 1007 | "
" 1008 | ] 1009 | }, 1010 | "metadata": {}, 1011 | "output_type": "display_data" 1012 | } 1013 | ], 1014 | "source": [ 1015 | "lim = (min(y),max(y))\n", 1016 | "plt.xlabel('True')\n", 1017 | "plt.ylabel('Predicted')\n", 1018 | "plt.plot(y, yhat, 'o', alpha=0.2)\n", 1019 | "plt.plot(lim, lim, '--')\n", 1020 | "plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 1*0.1*(max(y)-min(y)), f\"correlation = {np.corrcoef(y, yhat)[0,1]:.3f}\")\n", 1021 | "plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 2*0.1*(max(y)-min(y)), f\"MAE = {mean_squared_error(y, yhat):.3f}\")\n", 1022 | "plt.show()\n" 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "code", 1027 | "execution_count": null, 1028 | "metadata": {}, 1029 | "outputs": [], 1030 | "source": [] 1031 | } 1032 | ], 1033 | "metadata": { 1034 | "kernelspec": { 1035 | "display_name": "mapi", 1036 | "language": "python", 1037 | "name": "python3" 1038 | }, 1039 | "language_info": { 1040 | "codemirror_mode": { 1041 | "name": "ipython", 1042 | "version": 3 1043 | }, 1044 | "file_extension": ".py", 1045 | "mimetype": "text/x-python", 1046 | "name": "python", 1047 | "nbconvert_exporter": "python", 1048 | "pygments_lexer": "ipython3", 1049 | "version": "3.10.14" 1050 | } 1051 | }, 1052 | "nbformat": 4, 1053 | "nbformat_minor": 2 1054 | } 1055 | -------------------------------------------------------------------------------- /tests/baselines/RNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/Users/maykcaldas/miniconda3/envs/mapi/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | }, 16 | { 17 | "data": { 18 | "text/plain": [ 19 | "True" 20 | ] 21 | }, 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "output_type": "execute_result" 25 | } 26 | ], 27 | "source": [ 28 | "import os\n", 29 | "from datasets import load_dataset\n", 30 | "from transformers import BertTokenizerFast\n", 31 | "import re\n", 32 | "import torch\n", 33 | "import torch.nn as nn\n", 34 | "from torch.utils.data import DataLoader\n", 35 | "from dataclasses import dataclass\n", 36 | "\n", 37 | "from dotenv import load_dotenv\n", 38 | "load_dotenv(\"../.env\")\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "['nsites', 'nelements', 'formula_pretty', 'chemsys', 'volume', 'density', 'density_atomic', 'property_name', 'material_id', 'deprecation_reasons', 'last_updated', 'origins', 'warnings', 'structure', 'task_ids', 'uncorrected_energy_per_atom', 'energy_per_atom', 'formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'equilibrium_reaction_energy_per_atom', 'xas', 'grain_boundaries', 'band_gap', 'cbm', 'vbm', 'efermi', 'is_gap_direct', 'is_metal', 'es_source_calc_id', 'dos_energy_up', 'dos_energy_down', 'is_magnetic', 'ordering', 'total_magnetization', 'total_magnetization_normalized_vol', 'total_magnetization_normalized_formula_units', 'num_magnetic_sites', 'num_unique_magnetic_sites', 'bulk_modulus', 'shear_modulus', 'universal_anisotropy', 'homogeneous_poisson', 'e_total', 'e_ionic', 'e_electronic', 'n', 'e_ij_max', 'weighted_surface_energy_EV_PER_ANG2', 'weighted_surface_energy', 'weighted_work_function', 'surface_anisotropy', 'shape_factor', 'has_reconstructed', 'has_props', 'theoretical', 'database_IDs', 'crystal_system', 'symbol', 'number', 'point_group']\n" 51 | ] 52 | } 53 
| ], 54 | "source": [ 55 | "dataset=load_dataset('ur-whitelab/mapi', token=os.environ['HF_TOKEN'])\n", 56 | "\n", 57 | "print(dataset['train'].column_names)\n", 58 | "\n", 59 | "target = [\"band_gap\"]\n", 60 | "features=['nsites', 'nelements', 'formula_pretty', 'chemsys', 'volume', 'density', 'density_atomic', 'crystal_system', 'symbol', 'number', 'point_group', 'structure']\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 11, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "Full Formula (Nb1 V2 Mo1)\n", 73 | "Reduced Formula: NbV2Mo\n", 74 | "abc : 10.220753 10.220753 10.220753\n", 75 | "angles: 128.933454 117.899846 84.471274\n", 76 | "pbc : True True True\n", 77 | "Sites (4)\n", 78 | " # SP a b c magmom\n", 79 | "--- ---- --- -------- -------- --------\n", 80 | " 0 Nb 0 0 0 0.516\n", 81 | " 1 V 0 0.251541 0.251541 1.349\n", 82 | " 2 V 0 0.748459 0.748459 1.349\n", 83 | " 3 Mo 0 0.5 0.5 1.103\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "print(dataset['train'][2]['structure'])" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "train_dataset = dataset['train'].select_columns(features+target).to_pandas()\n", 98 | "test_dataset = dataset['test'].select_columns(features+target).to_pandas()\n", 99 | "\n", 100 | "train_dataset = train_dataset.dropna(subset=target, axis=0)\n", 101 | "test_dataset = test_dataset.dropna(subset=target, axis=0)\n", 102 | "\n", 103 | "formula = train_dataset.iloc[0]['formula_pretty']\n", 104 | "all_elements = \"-\".join(train_dataset['chemsys'].to_list())\n", 105 | "all_elements = set(all_elements.split(\"-\"))\n", 106 | "elements_dics = {k: v for (v, k) in enumerate(all_elements, 1)}\n", 107 | "\n", 108 | "voc = \tlist(all_elements) + [str(i) for i in range(10)] + [\"Full Formula\", \"Reduced Formula\", \"abc\", \"angles\", \"pbc\", \"Sites\", 
\"True\", \"False\", \"magmom\", \".\", \"(\", \")\", \"-\", \"\\n\", \" \", \"a\", \"b\", \"c\"]\n", 109 | "\n", 110 | "\n", 111 | "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased') # Check if there's a cased. It might be important for elements\n", 112 | "tokenizer.add_tokens(voc)\n", 113 | "\n", 114 | "# def tokenize_function(examples):\n", 115 | " # return tokenizer(examples['structure'], padding=\"max_length\", truncation=True, return_tensors='pt')\n", 116 | "\n", 117 | "# tokenized_datasets = dataset.map(tokenize_function, batched=True)\n", 118 | "# tokenizer" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "train_dataset = dataset['train'].select_columns(['structure']+target)\n", 128 | "test_dataset = dataset['test'].select_columns(['structure']+target)\n", 129 | "\n", 130 | "def filter_none(example):\n", 131 | " return all(value is not None for value in example.values())\n", 132 | "\n", 133 | "train_dataset = train_dataset.filter(filter_none)\n", 134 | "test_dataset = test_dataset.filter(filter_none)\n", 135 | "\n", 136 | "train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n", 137 | "test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "@dataclass\n", 147 | "class KDESolConfig:\n", 148 | " vocab_size: int = 30618\n", 149 | " batch_size: int = 256\n", 150 | " buffer_size: int = 10000\n", 151 | " rnn_units: int = 1028\n", 152 | " hidden_dim: int = 512\n", 153 | " embedding_dim: int = tokenizer.model_max_length\n", 154 | " reg_strength: float = 0.01\n", 155 | " lr: float = 1e-4\n", 156 | " drop_rate: float = 0.35\n", 157 | " nmodels: int = 10\n", 158 | " adv_epsilon: float = 1e-3\n", 159 | " epochs: int = 150\n", 160 | " pad_to_len: int = 512\n", 161 | 
"\n", 162 | "class RNN(nn.Module):\n", 163 | " def __init__(self, config=KDESolConfig()):\n", 164 | " super(RNN, self).__init__()\n", 165 | " self.config = config\n", 166 | "\n", 167 | " self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=0)\n", 168 | " self.dropout = nn.Dropout(config.drop_rate)\n", 169 | " self.rnn1 = nn.LSTM(config.embedding_dim, config.rnn_units, batch_first=True, bidirectional=True)\n", 170 | " self.rnn2 = nn.LSTM(2 * config.rnn_units, config.rnn_units, batch_first=True, bidirectional=True)\n", 171 | " self.layer_norm = nn.LayerNorm(2 * config.rnn_units)\n", 172 | " self.dense1 = nn.Linear(2 * config.rnn_units, config.hidden_dim)\n", 173 | " self.dense2 = nn.Linear(config.hidden_dim, config.hidden_dim // 2)\n", 174 | " self.out_mu = nn.Linear(config.hidden_dim // 2, 1)\n", 175 | " self.out_std = nn.Linear(config.hidden_dim // 2, 1)\n", 176 | "\n", 177 | " self.softplus = nn.Softplus()\n", 178 | "\n", 179 | " def forward(self, x):\n", 180 | " x = self.embedding(x)\n", 181 | " x = self.dropout(x)\n", 182 | " x, _ = self.rnn1(x)\n", 183 | " x, _ = self.rnn2(x)\n", 184 | " x = self.layer_norm(x[:, -1, :]) \n", 185 | " x = nn.SiLU()(self.dense1(x))\n", 186 | " x = self.dropout(x)\n", 187 | " x = nn.SiLU()(self.dense2(x))\n", 188 | " x = self.dropout(x)\n", 189 | " mu = self.out_mu(x)\n", 190 | " std = self.softplus(self.out_std(x))\n", 191 | " return mu\n", 192 | " return torch.cat((mu, std), dim=-1)\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 202 | "\n", 203 | "model = RNN()\n", 204 | "model.to(device)\n", 205 | "\n", 206 | "loss_fn = nn.MSELoss()\n", 207 | "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", 208 | "\n", 209 | "losses = []\n", 210 | "\n", 211 | "for _ in range(10):\n", 212 | " print(f\"Starting epoch 
{_}.\")\n", 213 | " model.train()\n", 214 | " for batch, d in enumerate(train_dataloader):\n", 215 | " optimizer.zero_grad()\n", 216 | " size = len(train_dataloader.dataset)\n", 217 | " \n", 218 | " X = tokenizer(d['structure'], padding=\"max_length\", truncation=True, return_tensors='pt')\n", 219 | "\n", 220 | " X = X['input_ids'].to(device)\n", 221 | " y = d['band_gap'].to(device)\n", 222 | " \n", 223 | " pred = model(X)\n", 224 | " loss = loss_fn(torch.flatten(pred), y.to(torch.float32))\n", 225 | "\n", 226 | " loss.backward()\n", 227 | " optimizer.step()\n", 228 | " optimizer.zero_grad()\n", 229 | "\n", 230 | " if batch % 500 == 0:\n", 231 | " losses.append(loss.item())\n", 232 | " loss_item = loss.item()\n", 233 | " current = batch * len(X) + len(X)\n", 234 | " print(f\"\\tloss: {loss_item:>7f} [{current:>5d}/{size:>5d}]\")\n", 235 | " print(f\"Epoch {_} done.\")\n" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "import numpy as np\n", 245 | "import pandas as pd\n", 246 | "import matplotlib.pyplot as plt\n", 247 | "import urllib.request\n", 248 | "import matplotlib as mpl\n", 249 | "import matplotlib.font_manager as font_manager\n", 250 | "urllib.request.urlretrieve('https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf', 'IBMPlexMono-Regular.ttf')\n", 251 | "fe = font_manager.FontEntry(\n", 252 | " fname='IBMPlexMono-Regular.ttf',\n", 253 | " name='plexmono')\n", 254 | "font_manager.fontManager.ttflist.append(fe)\n", 255 | "plt.rcParams.update({'axes.facecolor':'#f5f4e9',\n", 256 | " 'grid.color' : '#AAAAAA',\n", 257 | " 'axes.edgecolor':'#333333',\n", 258 | " 'figure.facecolor':'#FFFFFF',\n", 259 | " 'axes.grid': False,\n", 260 | " 'axes.prop_cycle': plt.cycler('color', plt.cm.Dark2.colors),\n", 261 | " 'font.family': fe.name,\n", 262 | " 'figure.figsize': (3.5,3.5 / 1.2),\n", 263 | " 'ytick.left': True,\n", 264 | " 
'xtick.bottom': True\n", 265 | " })" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 2, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "0 name 'test_dataset' is not defined\n", 278 | "1 name 'test_dataset' is not defined\n", 279 | "2 name 'test_dataset' is not defined\n", 280 | "3 name 'test_dataset' is not defined\n", 281 | "4 name 'test_dataset' is not defined\n", 282 | "5 name 'test_dataset' is not defined\n", 283 | "6 name 'test_dataset' is not defined\n", 284 | "7 name 'test_dataset' is not defined\n", 285 | "8 name 'test_dataset' is not defined\n", 286 | "9 name 'test_dataset' is not defined\n", 287 | "10 name 'test_dataset' is not defined\n", 288 | "11 name 'test_dataset' is not defined\n", 289 | "12 name 'test_dataset' is not defined\n", 290 | "13 name 'test_dataset' is not defined\n", 291 | "14 name 'test_dataset' is not defined\n", 292 | "15 name 'test_dataset' is not defined\n", 293 | "16 name 'test_dataset' is not defined\n", 294 | "17 name 'test_dataset' is not defined\n", 295 | "18 name 'test_dataset' is not defined\n", 296 | "19 name 'test_dataset' is not defined\n", 297 | "20 name 'test_dataset' is not defined\n", 298 | "21 name 'test_dataset' is not defined\n", 299 | "22 name 'test_dataset' is not defined\n", 300 | "23 name 'test_dataset' is not defined\n", 301 | "24 name 'test_dataset' is not defined\n", 302 | "25 name 'test_dataset' is not defined\n", 303 | "26 name 'test_dataset' is not defined\n", 304 | "27 name 'test_dataset' is not defined\n", 305 | "28 name 'test_dataset' is not defined\n", 306 | "29 name 'test_dataset' is not defined\n", 307 | "30 name 'test_dataset' is not defined\n", 308 | "31 name 'test_dataset' is not defined\n", 309 | "32 name 'test_dataset' is not defined\n", 310 | "33 name 'test_dataset' is not defined\n", 311 | "34 name 'test_dataset' is not defined\n", 312 | "35 name 'test_dataset' is not defined\n", 313 | 
"36 name 'test_dataset' is not defined\n", 314 | "37 name 'test_dataset' is not defined\n", 315 | "38 name 'test_dataset' is not defined\n", 316 | "39 name 'test_dataset' is not defined\n", 317 | "40 name 'test_dataset' is not defined\n", 318 | "41 name 'test_dataset' is not defined\n", 319 | "42 name 'test_dataset' is not defined\n", 320 | "43 name 'test_dataset' is not defined\n", 321 | "44 name 'test_dataset' is not defined\n", 322 | "45 name 'test_dataset' is not defined\n", 323 | "46 name 'test_dataset' is not defined\n", 324 | "47 name 'test_dataset' is not defined\n", 325 | "48 name 'test_dataset' is not defined\n", 326 | "49 name 'test_dataset' is not defined\n", 327 | "50 name 'test_dataset' is not defined\n", 328 | "51 name 'test_dataset' is not defined\n", 329 | "52 name 'test_dataset' is not defined\n", 330 | "53 name 'test_dataset' is not defined\n", 331 | "54 name 'test_dataset' is not defined\n", 332 | "55 name 'test_dataset' is not defined\n", 333 | "56 name 'test_dataset' is not defined\n", 334 | "57 name 'test_dataset' is not defined\n", 335 | "58 name 'test_dataset' is not defined\n", 336 | "59 name 'test_dataset' is not defined\n", 337 | "60 name 'test_dataset' is not defined\n", 338 | "61 name 'test_dataset' is not defined\n", 339 | "62 name 'test_dataset' is not defined\n", 340 | "63 name 'test_dataset' is not defined\n", 341 | "64 name 'test_dataset' is not defined\n", 342 | "65 name 'test_dataset' is not defined\n", 343 | "66 name 'test_dataset' is not defined\n", 344 | "67 name 'test_dataset' is not defined\n", 345 | "68 name 'test_dataset' is not defined\n", 346 | "69 name 'test_dataset' is not defined\n", 347 | "70 name 'test_dataset' is not defined\n", 348 | "71 name 'test_dataset' is not defined\n", 349 | "72 name 'test_dataset' is not defined\n", 350 | "73 name 'test_dataset' is not defined\n", 351 | "74 name 'test_dataset' is not defined\n", 352 | "75 name 'test_dataset' is not defined\n", 353 | "76 name 'test_dataset' is not 
defined\n", 354 | "77 name 'test_dataset' is not defined\n", 355 | "78 name 'test_dataset' is not defined\n", 356 | "79 name 'test_dataset' is not defined\n", 357 | "80 name 'test_dataset' is not defined\n", 358 | "81 name 'test_dataset' is not defined\n", 359 | "82 name 'test_dataset' is not defined\n", 360 | "83 name 'test_dataset' is not defined\n", 361 | "84 name 'test_dataset' is not defined\n", 362 | "85 name 'test_dataset' is not defined\n", 363 | "86 name 'test_dataset' is not defined\n", 364 | "87 name 'test_dataset' is not defined\n", 365 | "88 name 'test_dataset' is not defined\n", 366 | "89 name 'test_dataset' is not defined\n", 367 | "90 name 'test_dataset' is not defined\n", 368 | "91 name 'test_dataset' is not defined\n", 369 | "92 name 'test_dataset' is not defined\n", 370 | "93 name 'test_dataset' is not defined\n", 371 | "94 name 'test_dataset' is not defined\n", 372 | "95 name 'test_dataset' is not defined\n", 373 | "96 name 'test_dataset' is not defined\n", 374 | "97 name 'test_dataset' is not defined\n", 375 | "98 name 'test_dataset' is not defined\n", 376 | "99 name 'test_dataset' is not defined\n", 377 | "(0,) (0,)\n" 378 | ] 379 | } 380 | ], 381 | "source": [ 382 | "k=8\n", 383 | "\n", 384 | "import numpy as np\n", 385 | "import matplotlib.pyplot as plt\n", 386 | "from sklearn.metrics import mean_absolute_error, mean_squared_error\n", 387 | "\n", 388 | "yhat=[]\n", 389 | "y=[]\n", 390 | "for k in range(100):\n", 391 | " try:\n", 392 | " y.append(test_dataset.iloc[k][target])\n", 393 | " yhat.append(model.predict([test_dataset.iloc[k][features]]))\n", 394 | " except Exception as e:\n", 395 | " print(k, e)\n", 396 | "\n", 397 | "y = np.array(y).flatten()\n", 398 | "yhat = np.array(yhat).flatten()\n", 399 | "print(y.shape, yhat.shape)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "lim = (min(y),max(y))\n", 409 | "plt.xlabel('True')\n", 410 | 
"plt.ylabel('Predicted')\n", 411 | "plt.plot(y, yhat, 'o', alpha=0.2)\n", 412 | "plt.plot(lim, lim, '--')\n", 413 | "plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 1*0.1*(max(y)-min(y)), f\"correlation = {np.corrcoef(y, yhat)[0,1]:.3f}\")\n", 414 | "plt.text(lim[0] + 0.1*(max(y)-min(y)), lim[1] - 2*0.1*(max(y)-min(y)), f\"MAE = {mean_squared_error(y, yhat):.3f}\")\n", 415 | "plt.show()\n" 416 | ] 417 | } 418 | ], 419 | "metadata": { 420 | "kernelspec": { 421 | "display_name": "mapi", 422 | "language": "python", 423 | "name": "python3" 424 | }, 425 | "language_info": { 426 | "codemirror_mode": { 427 | "name": "ipython", 428 | "version": 3 429 | }, 430 | "file_extension": ".py", 431 | "mimetype": "text/x-python", 432 | "name": "python", 433 | "nbconvert_exporter": "python", 434 | "pygments_lexer": "ipython3", 435 | "version": "3.10.14" 436 | } 437 | }, 438 | "nbformat": 4, 439 | "nbformat_minor": 2 440 | } 441 | -------------------------------------------------------------------------------- /tests/baselines/get_dataset.py: -------------------------------------------------------------------------------- 1 | # from pymatgen import MPRester 2 | from mp_api.client import MPRester 3 | from emmet.core.summary import HasProps 4 | import requests 5 | import pandas as pd 6 | import os 7 | 8 | from datasets import Dataset, DatasetDict 9 | import sys, cloudpickle 10 | 11 | from dotenv import load_dotenv 12 | load_dotenv("../.env") 13 | 14 | with MPRester(os.environ['MAPI_API_KEY']) as mpr: 15 | docs = mpr.materials.summary.search( 16 | # Needed to remove some fields with dates because some entries had invalid dates. 
This was breaking the workflow 17 | fields=['formula_pretty', 'symmetry', 'density', 'nsites', 'elements', 'nelements', 'composition', 'composition_reduced', 'chemsys', 'volume', 'density_atomic', 'property_name', 'material_id', 'structure', 'uncorrected_energy_per_atom', 'energy_per_atom', 'formation_energy_per_atom', 'energy_above_hull', 'is_stable', 'equilibrium_reaction_energy_per_atom', 'decomposes_to', 'grain_boundaries', 'band_gap', 'efermi', 'is_gap_direct', 'is_metal', 'bandstructure', 'dos', 'dos_energy_up', 'dos_energy_down', 'is_magnetic', 'total_magnetization', 'total_magnetization_normalized_vol', 'total_magnetization_normalized_formula_units', 'num_magnetic_sites', 'num_unique_magnetic_sites', 'types_of_magnetic_species', 'bulk_modulus', 'shear_modulus', 'universal_anisotropy', 'homogeneous_poisson', 'e_total', 'e_ionic', 'e_electronic', 'n', 'e_ij_max', 'weighted_surface_energy_EV_PER_ANG2', 'weighted_surface_energy', 'weighted_work_function', 'surface_anisotropy', 'shape_factor', 'has_reconstructed', 'possible_species', 'theoretical'] 18 | ) 19 | 20 | cloudpickle.dump(docs, open('./docs.pkl', 'wb')) 21 | 22 | df = pd.DataFrame(d.model_dump() for d in docs) 23 | df['crystal_system'] = [d.model_dump()['symmetry']['crystal_system'].value for d in docs] 24 | df['symbol'] = [d.model_dump()['symmetry']['symbol'] for d in docs] 25 | df['number'] = [d.model_dump()['symmetry']['number'] for d in docs] 26 | df['point_group'] = [d.model_dump()['symmetry']['point_group'] for d in docs] 27 | df['structure'] = [d.model_dump()['structure'].__str__() for d in docs] 28 | # TODO: Process elements list to save it as well 29 | 30 | to_remove = [ 31 | #removed because they use fancier types 32 | 'builder_meta', 33 | 'composition', 34 | 'composition_reduced', 35 | 'symmetry', 36 | 'formula_anonymous', 37 | 'fields_not_requested', 38 | 'deprecated', 39 | 'decomposes_to', 40 | 'bandstructure', 41 | 'dos', 42 | 'types_of_magnetic_species', 43 | 'possible_species', 44 | 
'elements', #Elements is a list of Elements. Need to convert it to a comma-separated string. But using regex on `formula_pretty` work as well 45 | # removed because they're useless 46 | 'warning', 47 | 'has_props', 48 | 'task_ids', 49 | 'database_IDs', 50 | 'property_name' 51 | ] 52 | 53 | dataset = Dataset.from_pandas(df.drop(to_remove, axis=1)) 54 | train_test_split = dataset.train_test_split( 55 | test_size=0.2, shuffle=True, seed=8 56 | ) 57 | 58 | # Created a frozen dataset. Now the train/test split will be constant as we work on different models 59 | dataset_dict = DatasetDict({ 60 | "train": train_test_split["train"], 61 | "test": train_test_split["test"], 62 | }) 63 | dataset_dict.push_to_hub(repo_id='ur-whitelab/mapi', private=True, token=os.environ['HF_TOKEN']) -------------------------------------------------------------------------------- /tests/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/Users/maykcaldas/miniconda3/envs/mapi/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "No module named 'phonopy'\n", 21 | "No module named 'phonopy'\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "# import sys\n", 27 | "# import os\n", 28 | "# sys.path.append(os.path.join(os.path.dirname(\".\"), '..'))\n", 29 | "\n", 30 | "import mapillm" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 11, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "\n", 43 | "\n", 44 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 45 | "\u001b[32;1m\u001b[1;3mI need to first check if Fe3O4 is a band gap material and then predict its band gap.\n", 46 | "Action: Checks if material is band gap by formula\n", 47 | "Action Input: Fe3O4\u001b[0m" 48 | ] 49 | }, 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | "Retrieving SummaryDoc documents: 100%|██████████| 29/29 [00:00<00:00, 460737.94it/s]" 55 | ] 56 | }, 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "\u001b[33;1m\u001b[1;3m0.0\u001b[0m" 62 | ] 63 | }, 64 | { 65 | "name": "stderr", 66 | "output_type": "stream", 67 | "text": [ 68 | "\n", 69 | "/Users/maykcaldas/miniconda3/envs/mapi/lib/python3.11/site-packages/mapillm/mapi_tools.py:169: UserWarning: More than one material found for Fe3O4. Will use the first one. Please, check the results.\n", 70 | " warnings.warn(f\"More than one material found for {formula}. Will use the first one. 
Please, check the results.\")\n" 71 | ] 72 | }, 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "\u001b[32;1m\u001b[1;3mFe3O4 is not a band gap material.\n", 78 | "Action: Create band gap context to LLM search\n", 79 | "Action Input: Fe3O4\u001b[0m" 80 | ] 81 | }, 82 | { 83 | "name": "stderr", 84 | "output_type": "stream", 85 | "text": [ 86 | "Retrieving SummaryDoc documents: 100%|██████████| 8901/8901 [00:00<00:00, 10893.11it/s]\n" 87 | ] 88 | }, 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "\u001b[38;5;200m\u001b[1;3mYou are a bot who can predict the band gap of a material .\n", 94 | "Given this list of known materials and the measurement of their band gap, \n", 95 | "you need to predict what is the band gap of the material:The answer should be numeric and finish with ###\n", 96 | "\n", 97 | "What is the band gap for Fe3O4?@@@\n", 98 | "0.000000###\n", 99 | "\n", 100 | "What is the band gap for Fe3O4?@@@\n", 101 | "1.068900###\n", 102 | "\n", 103 | "What is the band gap for Fe3O4?@@@\n", 104 | "0.325100###\n", 105 | "\n", 106 | "What is the band gap for Fe3O4?@@@\n", 107 | "0.089500###\n", 108 | "\n", 109 | "What is the band gap for Fe3O4?@@@\n", 110 | "0.020100###\n", 111 | "\n", 112 | "What is the band gap for Fe3O4?@@@\n", 113 | "0.301200###\n", 114 | "\n", 115 | "What is the band gap for Fe3O4?@@@\n", 116 | "0.243500###\n", 117 | "\n", 118 | "What is the band gap for Fe3O4?@@@\n", 119 | "0.001900###\n", 120 | "\n", 121 | "What is the band gap for Fe3O4?@@@\n", 122 | "0.777000###\n", 123 | "\n", 124 | "What is the band gap for Fe3O4?@@@\n", 125 | "0.637000###\n", 126 | "\n", 127 | "What is the band gap for Fe3O4?@@@\n", 128 | "\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", 129 | "Final Answer: The band gap of Fe3O4 is 0.637 eV.\u001b[0m\n", 130 | "\n", 131 | "\u001b[1m> Finished chain.\u001b[0m\n" 132 | ] 133 | }, 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "{'input': 
\"What's the band gap of Fe3O4?\",\n", 138 | " 'output': 'The band gap of Fe3O4 is 0.637 eV.'}" 139 | ] 140 | }, 141 | "execution_count": 11, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | } 145 | ], 146 | "source": [ 147 | "import os\n", 148 | "from dotenv import load_dotenv\n", 149 | "load_dotenv(\".env\", override=True)\n", 150 | "\n", 151 | "a = mapillm.Agent(openai_api_key=os.getenv(\"OPENAI_API_KEY\"), mapi_api_key=os.getenv(\"MAPI_API_KEY\"))\n", 152 | "a.run(\"What's the band gap of Fe3O4?\")" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 3, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | 162 | 163 | "name": "stdout", 164 | "output_type": "stream", 165 | "text": [ 166 | "\n", 167 | "\n", 168 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 169 | "\u001b[32;1m\u001b[1;3mI need to determine the band gap of the orthorrombic Fe2O4.\n", 170 | "Action: Checks if material is band gap by formula\n", 171 | "Action Input: Fe2O4\u001b[0m" 172 | ] 173 | }, 174 | { 175 | "name": "stderr", 176 | "output_type": "stream", 177 | "text": [ 178 | "Retrieving SummaryDoc documents: 100%|██████████| 24/24 [00:00<00:00, 122610.59it/s]\n", 179 | "/Users/maykcaldas/Documents/WhiteLab/MAPI_LLM/tests/../mapillm/mapi_tools.py:160: UserWarning: More than one material found for Fe2O4. Will use the first one. Please, check the results.\n", 180 | " warnings.warn(f\"More than one material found for {formula}. Will use the first one. 
Please, check the results.\")\n" 181 | ] 182 | }, 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "\u001b[36;1m\u001b[1;3mCould not find any material while searching Fe2O4\u001b[0m\u001b[32;1m\u001b[1;3mI need to input the correct formula for Fe2O4.\n", 188 | "Action: Create band gap context to LLM completion\n", 189 | "Action Input: Fe2O4\u001b[0m" 190 | ] 191 | }, 192 | { 193 | "name": "stderr", 194 | "output_type": "stream", 195 | "text": [ 196 | "Retrieving SummaryDoc documents: 100%|██████████| 8901/8901 [00:01<00:00, 7605.50it/s]\n" 197 | ] 198 | }, 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "\u001b[33;1m\u001b[1;3mYou are a bot who can predict the band gap of a material .\n", 204 | "Given this list of known materials and the measurement of their band gap, \n", 205 | "you need to predict what is the band gap of the material:The answer should be numeric and finish with ###\n", 206 | "\n", 207 | "What is the band gap for Monoclinic Fe2(PO4)3 with space group 2/m?@@@\n", 208 | "0.000000###\n", 209 | "\n", 210 | "What is the band gap for Tetragonal Fe2O3 with space group 422?@@@\n", 211 | "1.567300###\n", 212 | "\n", 213 | "What is the band gap for Orthorhombic Fe2(MoO4)3 with space group mmm?@@@\n", 214 | "2.664900###\n", 215 | "\n", 216 | "What is the band gap for Triclinic Fe2O5 with space group -1?@@@\n", 217 | "0.000000###\n", 218 | "\n", 219 | "What is the band gap for Tetragonal Fe2Ir2O5 with space group 4/mmm?@@@\n", 220 | "0.000000###\n", 221 | "\n", 222 | "What is the band gap for Orthorhombic Fe2O3 with space group mmm?@@@\n", 223 | "0.000000###\n", 224 | "\n", 225 | "What is the band gap for Monoclinic Fe2O3 with space group 2?@@@\n", 226 | "1.346500###\n", 227 | "\n", 228 | "What is the band gap for Trigonal Fe2O3 with space group 3?@@@\n", 229 | "0.000000###\n", 230 | "\n", 231 | "What is the band gap for Triclinic Fe2O3 with space group 1?@@@\n", 232 | "0.220200###\n", 
233 | "\n", 234 | "What is the band gap for Monoclinic Fe2CoO6 with space group 2?@@@\n", 235 | "0.000000###\n", 236 | "\n", 237 | "What is the band gap for Fe2O4?@@@\n", 238 | "\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", 239 | "Final Answer: The band gap for Fe2O4 is 0.000000\u001b[0m\n", 240 | "\n", 241 | "\u001b[1m> Finished chain.\u001b[0m\n" 242 | ] 243 | }, 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "{'input': \"What's the band gap of the orthorrombic Fe2O4?\",\n", 248 | " 'output': 'The band gap for Fe2O4 is 0.000000'}" 249 | ] 250 | }, 251 | "execution_count": 3, 252 | "metadata": {}, 253 | "output_type": "execute_result" 254 | } 255 | ], 256 | "source": [ 257 | "import os\n", 258 | "from dotenv import load_dotenv\n", 259 | "load_dotenv(\".env\", override=True)\n", 260 | "\n", 261 | "a = mapillm.Agent(openai_api_key=os.getenv(\"OPENAI_API_KEY\"), mapi_api_key=os.getenv(\"MAPI_API_KEY\"))\n", 262 | "a.run(\"What's the band gap of the orthorrombic Fe2O4?\")\n", 263 | "# a.run(\"Can you compute bandgap?\")\n" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 3, 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "\n", 277 | "\n", 278 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 279 | "\u001b[32;1m\u001b[1;3mI need to first check if Fe2O4 is a band gap material and then predict its band gap if it is.\n", 280 | "Action: Checks if material is band gap by formula\n", 281 | "Action Input: Fe2O4\u001b[0m" 282 | ] 283 | }, 284 | { 285 | "name": "stderr", 286 | "output_type": "stream", 287 | "text": [ 288 | "Retrieving SummaryDoc documents: 100%|██████████| 24/24 [00:00<00:00, 103456.62it/s]\n", 289 | "/Users/maykcaldas/miniconda3/envs/mapi/lib/python3.11/site-packages/mapillm/mapi_tools.py:169: UserWarning: More than one material found for Fe2O4. Will use the first one. 
Please, check the results.\n", 290 | " warnings.warn(f\"More than one material found for {formula}. Will use the first one. Please, check the results.\")\n" 291 | ] 292 | }, 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "\u001b[33;1m\u001b[1;3mCould not find any material while searching Fe2O4\u001b[0m\u001b[32;1m\u001b[1;3mFe2O4 might not be the correct formula, I should try a different approach.\n", 298 | "Action: Get atoms in material\n", 299 | "Action Input: Fe2O4\u001b[0m\u001b[36;1m\u001b[1;3mFe,O\u001b[0m\u001b[32;1m\u001b[1;3mFe2O4 is actually FeO2, I should create a band gap context to LLM search for FeO2.\n", 300 | "Action: Create band gap context to LLM search\n", 301 | "Action Input: FeO2\u001b[0m" 302 | ] 303 | }, 304 | { 305 | "name": "stderr", 306 | "output_type": "stream", 307 | "text": [ 308 | "Retrieving SummaryDoc documents: 100%|██████████| 8901/8901 [00:01<00:00, 8620.61it/s] \n" 309 | ] 310 | }, 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "\u001b[38;5;200m\u001b[1;3mYou are a bot who can predict the band gap of a material .\n", 316 | "Given this list of known materials and the measurement of their band gap, \n", 317 | "you need to predict what is the band gap of the material:The answer should be numeric and finish with ###\n", 318 | "\n", 319 | "What is the band gap for FeO2F?@@@\n", 320 | "0.000000###\n", 321 | "\n", 322 | "What is the band gap for ScFeO3?@@@\n", 323 | "1.365100###\n", 324 | "\n", 325 | "What is the band gap for CaFeO2?@@@\n", 326 | "2.506900###\n", 327 | "\n", 328 | "What is the band gap for Lu(FeO2)2?@@@\n", 329 | "0.000000###\n", 330 | "\n", 331 | "What is the band gap for FeO?@@@\n", 332 | "2.087800###\n", 333 | "\n", 334 | "What is the band gap for Dy(FeO2)2?@@@\n", 335 | "0.000000###\n", 336 | "\n", 337 | "What is the band gap for CaFeO2?@@@\n", 338 | "1.134200###\n", 339 | "\n", 340 | "What is the band gap for CaFeO2?@@@\n", 341 | 
"0.040200###\n", 342 | "\n", 343 | "What is the band gap for FeO3?@@@\n", 344 | "0.305100###\n", 345 | "\n", 346 | "What is the band gap for Ti(FeO2)2?@@@\n", 347 | "0.000000###\n", 348 | "\n", 349 | "What is the band gap for FeO2?@@@\n", 350 | "\u001b[0m\u001b[32;1m\u001b[1;3mAfter predicting the band gap for FeO2, I can now provide the final answer.\n", 351 | "Final Answer: The band gap for FeO2 is 0.000000\u001b[0m\n", 352 | "\n", 353 | "\u001b[1m> Finished chain.\u001b[0m\n" 354 | ] 355 | }, 356 | { 357 | "data": { 358 | "text/plain": [ 359 | "{'input': \"What's the band gap of Fe2O4?\",\n", 360 | " 'output': 'The band gap for FeO2 is 0.000000'}" 361 | ] 362 | }, 363 | "execution_count": 2, 364 | "metadata": {}, 365 | "output_type": "execute_result" 366 | } 367 | ], 368 | "source": [ 369 | "import os\n", 370 | "from dotenv import load_dotenv\n", 371 | "load_dotenv(\".env\", override=True)\n", 372 | "\n", 373 | "a = mapillm.Agent(openai_api_key=os.getenv(\"OPENAI_API_KEY\"), mapi_api_key=os.getenv(\"MAPI_API_KEY\"))\n", 374 | "a.run(\"What's the band gap of Fe2O4?\")\n" 375 | ] 376 | } 377 | ], 378 | "metadata": { 379 | "kernelspec": { 380 | "display_name": "mapi", 381 | "language": "python", 382 | "name": "python3" 383 | }, 384 | "language_info": { 385 | "codemirror_mode": { 386 | "name": "ipython", 387 | "version": 3 388 | }, 389 | "file_extension": ".py", 390 | "mimetype": "text/x-python", 391 | "name": "python", 392 | "nbconvert_exporter": "python", 393 | "pygments_lexer": "ipython3", 394 | "version": "3.11.8" 395 | } 396 | }, 397 | "nbformat": 4, 398 | "nbformat_minor": 2 399 | } 400 | --------------------------------------------------------------------------------