├── .github
│   └── workflows
│       └── codeql.yml
├── .gitignore
├── BRIDGE
│   ├── example_initial_templates.txt
│   ├── llm4text2ts.py
│   ├── llm_agents
│   │   ├── __init__.py
│   │   ├── agent.py
│   │   ├── llm.py
│   │   └── tools
│   │       ├── base.py
│   │       ├── bing.py
│   │       ├── hackernews.py
│   │       ├── python_repl.py
│   │       ├── search.py
│   │       └── wikipedia.py
│   ├── requirements.txt
│   ├── self_refine
│   │   ├── data.py
│   │   ├── eval.py
│   │   ├── feedback.py
│   │   ├── make_challenging.py
│   │   ├── multi_agent.py
│   │   ├── predictor.py
│   │   ├── prompt_building.py
│   │   ├── refiner.py
│   │   ├── run.py
│   │   ├── task_init.py
│   │   └── task_iterate.py
│   ├── self_refine_main.py
│   ├── template_extractor.py
│   ├── ts_metrics.py
│   └── ts_to_text.py
├── CODE_OF_CONDUCT.md
├── DiGA
│   ├── README.md
│   ├── agent
│   │   ├── interactive_replay_agent.py
│   │   ├── meta_agent.py
│   │   ├── rl_agent.py
│   │   └── utils
│   │       ├── chartist_state.py
│   │       ├── ctrl_loader.py
│   │       ├── meta_oracle.py
│   │       ├── rl_state.py
│   │       └── trade_info_state.py
│   ├── diffusion
│   │   ├── cont_ctrl_net.py
│   │   ├── ddpm.py
│   │   └── disc_ctrl_net.py
│   ├── environment.yml
│   ├── generate.py
│   ├── rltask
│   │   ├── envs
│   │   │   ├── base_market_env.py
│   │   │   ├── ctrl_market_env.py
│   │   │   └── replay_market_env.py
│   │   └── train_test_rl.py
│   ├── train.py
│   └── utils
│       ├── metrics_utils.py
│       ├── pkl_utils.py
│       └── test_utils.py
├── LICENSE
├── README.md
├── SECURITY.md
├── TarDiff
│   ├── README.md
│   ├── classifier
│   │   ├── __init__.py
│   │   ├── classifier_train.py
│   │   ├── model.py
│   │   └── train.sh
│   ├── configs
│   │   └── base
│   │       └── mimic_icustay_base.yaml
│   ├── data_preprocess
│   │   └── README.md
│   ├── environment.yaml
│   ├── generation.sh
│   ├── guidance_generation.py
│   ├── images
│   │   └── overview.png
│   ├── ldm
│   │   ├── data
│   │   │   └── tsg_dataset.py
│   │   ├── lr_scheduler.py
│   │   ├── models
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion
│   │   │       ├── __init__.py
│   │   │       ├── classifier.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       ├── guided_ddim.py
│   │   │       ├── plms.py
│   │   │       └── uni_csg.py
│   │   ├── modules
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── unet1d.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── guidance_scorer.py
│   │   │   └── losses
│   │   │       ├── __init__.py
│   │   │       ├── contperceptual.py
│   │   │       └── vqperceptual.py
│   │   └── util.py
│   ├── train.sh
│   ├── train_main.py
│   └── utils
│       ├── __init__.py
│       └── callback_utils.py
├── TimeDP
│   ├── .env
│   ├── .gitignore
│   ├── README.md
│   ├── configs
│   │   └── multi_domain_timedp.yaml
│   ├── environment.yml
│   ├── figure
│   │   └── TimeDP_Overview.jpg
│   ├── ldm
│   │   ├── data
│   │   │   └── tsg_dataset.py
│   │   ├── lr_scheduler.py
│   │   ├── models
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion
│   │   │       ├── ddim_time.py
│   │   │       └── ddpm_time.py
│   │   ├── modules
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── ts_unet.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   └── encoders
│   │   │       └── modules.py
│   │   └── util.py
│   ├── main_train.py
│   ├── metrics
│   │   ├── feature_distance_eval.py
│   │   └── metrics_sets.py
│   ├── train.sh
│   ├── utils
│   │   ├── callback_utils.py
│   │   ├── cli_utils.py
│   │   ├── data_utils.py
│   │   ├── init_utils.py
│   │   ├── pkl_utils.py
│   │   ├── prepare_datasets.py
│   │   └── test_utils.py
│   └── visualize.py
├── diffusion
│   ├── classifier
│   │   ├── __init__.py
│   │   ├── classifier_train.py
│   │   ├── model.py
│   │   └── train.sh
│   ├── configs
│   │   ├── mimic_icustay_base.yaml
│   │   └── text_control.yaml
│   ├── environment.yml
│   ├── inference.py
│   ├── ldm
│   │   ├── data
│   │   │   └── tsg_dataset.py
│   │   ├── lr_scheduler.py
│   │   ├── models
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion
│   │   │       ├── conditioning_mlp.py
│   │   │       ├── ddim_time.py
│   │   │       ├── ddpm_time.py
│   │   │       └── guided_ddim_time.py
│   │   ├── modules
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules
│   │   │   │   ├── ts_unet.py
│   │   │   │   └── util.py
│   │   │   ├── distributions
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders
│   │   │   │   └── modules.py
│   │   │   └── guidance_scorer.py
│   │   └── util.py
│   ├── main.py
│   ├── metrics
│   │   ├── feature_distance_eval.py
│   │   └── metrics_sets.py
│   ├── train.sh
│   ├── utils
│   │   ├── callback_utils.py
│   │   ├── cli_utils.py
│   │   ├── data_utils.py
│   │   ├── init_utils.py
│   │   ├── pkl_utils.py
│   │   ├── test_utils.py
│   │   └── text_encoder.py
│   └── visualize.py
├── env.txt
├── environment.yml
├── figures
│   ├── BRIDGE.jpeg
│   ├── TarDiff_result.png
│   ├── TextPreparation.jpeg
│   ├── TimeCraft.png
│   ├── TimeCraft2.png
│   ├── overview_2.png
│   ├── prototype_like_words.jpeg
│   ├── pt_like_word_small.png
│   └── timedp_indomain.png
├── process
│   ├── dataset_split.py
│   ├── prompt_bank.js
│   ├── text_templates_example.json
│   └── ts_to_text.py
├── supplementary
│   ├── dataset_split.md
│   ├── examples.md
│   ├── inference_guidance.md
│   ├── inference_prototype.md
│   ├── inference_prototype_text.md
│   ├── mimiciii_prepare.md
│   └── training_details.md
└── train_inference.py

--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL Advanced"

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]
  schedule:
    - cron: '26 16 * * 0'

jobs:
  analyze:
    name: Analyze (${{ matrix.language }})
    # Runner size impacts CodeQL analysis time. To learn more, please see:
    #   - https://gh.io/recommended-hardware-resources-for-running-codeql
    #   - https://gh.io/supported-runners-and-hardware-resources
    #   - https://gh.io/using-larger-runners (GitHub.com only)
    # Consider using larger runners or machines with greater resources for possible analysis time improvements.
    runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-22.04' }}
    permissions:
      # required for all workflows
      security-events: write

      # required to fetch internal or private CodeQL packs
      packages: read

      # only required for workflows in private repositories
      actions: read
      contents: read

    strategy:
      fail-fast: false
      matrix:
        include:
        - language: python
          build-mode: none
        # CodeQL supports the following values for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
        # Use 'c-cpp' to analyze code written in C, C++ or both
        # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
        # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
        # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
        # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
        # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
        # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
    steps:
    - name: Checkout repository
      uses: actions/checkout@v4

    # Initializes the CodeQL tools for scanning.
    - name: Initialize CodeQL
      uses: github/codeql-action/init@v3
      with:
        languages: ${{ matrix.language }}
        build-mode: ${{ matrix.build-mode }}
        # If you wish to specify custom queries, you can do so here or in a config file.
        # By default, queries listed here will override any specified in a config file.
        # Prefix the list here with "+" to use these queries and those in the config file.

        # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
        # queries: security-extended,security-and-quality

    # If the analyze step fails for one of the languages you are analyzing with
    # "We were unable to automatically build your code", modify the matrix above
    # to set the build mode to "manual" for that language. Then modify this step
    # to build your code.
    # ℹ️ Command-line programs to run using the OS shell.
    # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
    - if: matrix.build-mode == 'manual'
      shell: bash
      run: |
        echo 'If you are using a "manual" build mode for one or more of the' \
          'languages you are analyzing, replace this with the commands to build' \
          'your code, for example:'
        echo '  make bootstrap'
        echo '  make release'
        exit 1

    - name: Perform CodeQL Analysis
      uses: github/codeql-action/analyze@v3
      with:
        category: "/language:${{matrix.language}}"

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
**/__pycache__/
*.pyc
*.pyo
*.pyd
*.pyw
*.pyz

logs/

--------------------------------------------------------------------------------
/BRIDGE/example_initial_templates.txt:
--------------------------------------------------------------------------------
1. The graph illustrates {metric} trends over time from {start_year} to {end_year}. Overall, {metric} {describe_general_trend}. In the initial years, {detail_initial_years}. As time progressed, {change_description}, culminating in {end_description} by {end_year}.
2. This chart presents {metric} variations per {unit} in {location} between {start_year} and {end_year}. {metric} remained highest in {location} throughout the period. Initially, {description_initial_period}. Subsequently, {description_mid_period}, while {location} experienced a {change_type} from {value_start} to {value_end}.
3. The diagram outlines {metric} per {unit} in {location} from {start_year} to {end_year}. Overall, {metric} {describe_overall_trend}. Initially, {location1}, {location2}, and {location3} showed {initial_behavior}. Over time, {location1} and {location2} {mid_behavior}, whereas {location3} {end_behavior}.
4. This line graph compares {metric} across different {location} from {start_year} to {end_year}. Throughout the entire period, {location} consistently showed {metric_behavior}. Initially, {location1}, {location2}, and {location3} had {initial_metric}. By {end_year}, {location1} and {location2} {mid_metric}, whereas {location3} {end_metric}.
5. The chart illustrates the progression of {metric} in {location} from {start_year} to {end_year}. {metric} remained highest in {location} throughout the observed period. Initially, {description_initial_period}. As years passed, {description_mid_period}, leading to {final_value} by {end_year}.
6. This graph depicts {metric} over time in {location} between {start_year} and {end_year}. {metric} showed distinct trends across the regions. At the start, {initial_description}. Over the years, {mid_description}, whereas {location3} {end_description}.
7. The diagram presents {metric} per {unit} in {location} from {start_year} to {end_year}. Overall, {metric} {overall_trend}. At the beginning, {location1}, {location2}, and {location3} had {initial_metric}. Over time, {location1} and {location2} {mid_metric}, while {location3} {end_metric}.
8. This line graph shows {metric} in {location} from {start_year} to {end_year}. {metric} remained highest in {location} throughout the entire period. Initially, {initial_description}. By {end_year}, {mid_description}, with {location3} {end_description}.
9. The chart illustrates the {metric} variations per {unit} in {location} over the period from {start_year} to {end_year}. {metric} showed diverse patterns across the regions. At the start, {initial_behavior}. Subsequently, {mid_behavior}, while {location3} {end_behavior}.
10. This graph outlines the {metric} trends in {location} between {start_year} and {end_year}. {metric} exhibited varying patterns across different regions. Initially, {initial_description}. As time progressed, {mid_description}, culminating in {final_value} by {end_year}.

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from llm_agents.agent import Agent
from llm_agents.llm import ChatLLM
from llm_agents.tools.python_repl import PythonREPLTool
from llm_agents.tools.hackernews import HackerNewsSearchTool
from llm_agents.tools.search import AzureSearchTool
from llm_agents.tools.bing import AzureBingSearchTool
from llm_agents.tools.wikipedia import WikipediaAPITool

__all__ = ['Agent', 'ChatLLM', 'PythonREPLTool', 'HackerNewsSearchTool',
           'AzureSearchTool', 'AzureBingSearchTool', 'WikipediaAPITool']

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/agent.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import datetime
import re

from pydantic import BaseModel
from typing import List, Dict, Tuple
from llm_agents.llm import ChatLLM
from llm_agents.tools.base import ToolInterface
from llm_agents.tools.python_repl import PythonREPLTool


FINAL_ANSWER_TOKEN = "Final Answer:"
OBSERVATION_TOKEN = "Observation:"
THOUGHT_TOKEN = "Thought:"
PROMPT_TEMPLATE = """Today is {today} and you can use tools to get new information. Answer the question as best as you can using the following tools:

{tool_description}

# Use the following format:

Question: the input question you must answer
Thought: comment on what you want to do next
Action: the action to take, exactly one element of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation repeats N times, use it until you are sure of the answer)
Thought: I now know the final answer
Final Answer: your final answer to the original input question

# Attention:
You will come up with solutions for any task or problem by following these steps:
1. You should first understand, analyze, and break down the human's problem/task.
2. You should then select the appropriate toolset ({tool_names}) to solve the problem/task.
3. You should act as an expert-level ChatGPT prompt engineer and planner with expertise in multiple fields, so that you can better develop a problem-solving plan and provide the best answer.
4. The execution plan should consist of multiple steps that solve the problem progressively. Make the plan as detailed as possible to ensure the accuracy and completeness.
5. Final answer should have the source link if the answer is from the internet.

# Begin!

Question: {question}
Thought: {previous_responses}
"""
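
# Illustrative (hypothetical) trace of the format the template asks the LLM to
# follow, with a single Python REPL step. The values below are made up:
#
#   Question: What is 5 * 7?
#   Thought: I should compute this with the Python REPL.
#   Action: Python REPL
#   Action Input: print(5 * 7)
#   Observation: 35
#   Thought: I now know the final answer
#   Final Answer: 35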

class Agent(BaseModel):
    llm: ChatLLM
    tools: List[ToolInterface]
    prompt_template: str = PROMPT_TEMPLATE
    max_loops: int = 10
    # Stop generation at the observation token, so the LLM does not hallucinate
    # the tool's observation itself
    stop_pattern: List[str] = [f'\n{OBSERVATION_TOKEN}', f'\n\t{OBSERVATION_TOKEN}']

    @property
    def tool_description(self) -> str:
        return "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])

    @property
    def tool_names(self) -> str:
        return ",".join([tool.name for tool in self.tools])

    @property
    def tool_by_names(self) -> Dict[str, ToolInterface]:
        return {tool.name: tool for tool in self.tools}

    def run(self, question: str):
        previous_responses = []
        num_loops = 0
        prompt = self.prompt_template.format(
            today=datetime.date.today(),
            tool_description=self.tool_description,
            tool_names=self.tool_names,
            question=question,
            previous_responses='{previous_responses}'
        )
        print(prompt.format(previous_responses=''))
        while num_loops < self.max_loops:
            num_loops += 1
            curr_prompt = prompt.format(previous_responses='\n'.join(previous_responses))
            generated, tool, tool_input = self.decide_next_action(curr_prompt)
            if tool == 'Final Answer':
                return tool_input
            if tool not in self.tool_by_names:
                raise ValueError(f"Unknown tool: {tool}")
            tool_result = self.tool_by_names[tool].use(tool_input)
            generated += f"\n{OBSERVATION_TOKEN} {tool_result}\n{THOUGHT_TOKEN}"
            print(generated)
            previous_responses.append(generated)

    def decide_next_action(self, prompt: str) -> Tuple[str, str, str]:
        generated = self.llm.generate(prompt, stop=self.stop_pattern)
        tool, tool_input = self._parse(generated)
        return generated, tool, tool_input

    def _parse(self, generated: str) -> Tuple[str, str]:
        if FINAL_ANSWER_TOKEN in generated:
            return "Final Answer", generated.split(FINAL_ANSWER_TOKEN)[-1].strip()
        regex = r"Action: [\[]?(.*?)[\]]?[\n]*Action Input:[\s]*(.*)"
        match = re.search(regex, generated, re.DOTALL)
        if not match:
            raise ValueError(f"Output of LLM is not parsable for next tool use: `{generated}`")
        tool = match.group(1).strip()
        tool_input = match.group(2)
        return tool, tool_input.strip(" ").strip('"')
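
if __name__ == '__main__':
    # Minimal usage sketch (assumes OPENAI_API_KEY is set in the environment).
    # The question and expected tool use are illustrative only.
    agent = Agent(llm=ChatLLM(), tools=[PythonREPLTool()])
    answer = agent.run("What is 7 * 13 + 2?")
    print(answer)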

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/llm.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import openai
import os
from pydantic import BaseModel
from typing import List, Optional

# os.environ['OPENAI_API_KEY'] = "xxx"

class ChatLLM(BaseModel):
    model: str = 'gpt-4o-2024-05-13'
    temperature: float = 0.0
    api_key: Optional[str] = None

    def __init__(self, **data):
        super().__init__(**data)
        if self.api_key is None:
            self.api_key = os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("No API key provided. Please set the OPENAI_API_KEY environment variable or provide it explicitly.")
        openai.api_key = self.api_key

    def generate(self, prompt: str, stop: Optional[List[str]] = None):
        # Uses the legacy (pre-1.0) openai client interface, matching the
        # `openai==0.28` pin in requirements.txt.
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=self.temperature,
            stop=stop
        )
        return response.choices[0].message.content
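
if __name__ == '__main__':
    # Minimal usage sketch (assumes OPENAI_API_KEY is set in the environment).
    llm = ChatLLM()
    print(llm.generate("Reply with the single word: ready"))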

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/tools/base.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pydantic import BaseModel

class ToolInterface(BaseModel):
    name: str
    description: str

    def use(self, input_text: str) -> str:
        raise NotImplementedError("use() method not implemented")  # Implement in subclass
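
# Minimal illustrative subclass (hypothetical, not part of the toolset in this
# package); it shows the two declared fields and the use() override that every
# concrete tool provides.
class EchoTool(ToolInterface):
    name: str = "echo"
    description: str = "Returns the input text unchanged."

    def use(self, input_text: str) -> str:
        return input_text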

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/tools/bing.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from typing import Any, Dict, List
from llm_agents.tools.base import ToolInterface
import requests


def _bing_search_results(query: str, max_results: int = 10) -> List[Dict[str, Any]]:
    """Query Bing Web Search API (Azure Cognitive Services) and return results."""

    subscription_key = os.getenv("AZURE_BING_SEARCH_KEY")  # Azure API key
    endpoint = os.getenv("AZURE_BING_SEARCH_ENDPOINT")  # e.g. https://<your-resource-name>.api.cognitive.microsoft.com

    if not subscription_key or not endpoint:
        raise EnvironmentError("Azure Bing credentials not found in environment variables.")

    search_url = f"{endpoint}/bing/v7.0/search"
    headers = {"Ocp-Apim-Subscription-Key": subscription_key}
    params = {"q": query, "count": max_results}

    response = requests.get(search_url, headers=headers, params=params)
    response.raise_for_status()

    return response.json().get("webPages", {}).get("value", [])


def search(query: str) -> str:
    """Perform a web search using Azure Bing and return formatted text results."""
    results = _bing_search_results(query)

    if not results:
        return "No relevant Bing Search Result was found."

    toret = []
    for i, result in enumerate(results, 1):
        title = result.get("name")
        url = result.get("url")
        snippet = result.get("snippet")
        toret.append(f"Result {i}:\nTitle: {title}\nURL: {url}\nSnippet: {snippet}\n")

    return "\n".join(toret)


class AzureBingSearchTool(ToolInterface):
    """Tool for web search using Azure Bing Search API."""

    name: str = "Azure Bing Search"
    description: str = (
        "Searches the web using Azure Bing Search API. "
        "Provide a natural language question and receive summarized web information."
    )

    def use(self, input_text: str) -> str:
        return search(input_text)


if __name__ == '__main__':
    tool = AzureBingSearchTool()
    res = tool.use("Who was the pope in 2023?")
    print(res)

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/tools/hackernews.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import requests
from bs4 import BeautifulSoup
from llm_agents.tools.base import ToolInterface


# Redacted in this repository; expects a Hacker News search endpoint compatible
# with the Algolia-style query parameters used below.
ENDPOINT = "***"


def extract_text_from(url, max_len: int = 2000) -> str:
    html = requests.get(url).text
    soup = BeautifulSoup(html, features="html.parser")
    text = soup.get_text()

    lines = (line.strip() for line in text.splitlines())
    return '\n'.join(line for line in lines if line)[:max_len]


def search_hn(query: str, crawl_urls: bool) -> str:
    params = {
        "query": query,
        "tags": "story",
        "numericFilters": "points>100"
    }

    response = requests.get(ENDPOINT, params=params)

    hits = response.json()["hits"]

    result = ""
    for hit in hits[:5]:
        title = hit["title"]
        url = hit["url"]
        result += f"Title: {title}\n"

        if url is not None and crawl_urls:
            result += f"\tExcerpt: {extract_text_from(url)}\n"
        else:
            objectID = hit["objectID"]
            comments_url = f"{ENDPOINT}?tags=comment,story_{objectID}&hitsPerPage=1"
            comments_response = requests.get(comments_url)
            comment = comments_response.json()["hits"][0]['comment_text']

            result += f"\tComment: {comment}\n"
    return result


class HackerNewsSearchTool(ToolInterface):
    """Tool to get some insight from Hacker News users."""

    name: str = "hacker news search"
    description: str = ("Get insight from hacker news users to specific search terms. "
                        "Input should be a search term (e.g. How to get rich?). "
                        "The output will be the most recent stories related to it with a user comment.")
    crawl_urls: bool = False

    def use(self, input_text: str) -> str:
        return search_hn(input_text, self.crawl_urls)


if __name__ == '__main__':
    hn = HackerNewsSearchTool()
    res = hn.use('GPT-4')
    print(res)

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/tools/python_repl.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
from io import StringIO
from typing import Dict, Optional

from pydantic import BaseModel, Field
from llm_agents.tools.base import ToolInterface


# Taken from https://github.com/hwchase17/langchain/blob/master/langchain/python.py
class PythonREPL(BaseModel):
    """Simulates a standalone Python REPL."""

    globals: Optional[Dict] = Field(default_factory=dict, alias="_globals")
    locals: Optional[Dict] = Field(default_factory=dict, alias="_locals")

    def run(self, command: str) -> str:
        """Run command with own globals/locals and returns anything printed."""
        old_stdout = sys.stdout
        sys.stdout = mystdout = StringIO()
        try:
            exec(command, self.globals, self.locals)
            sys.stdout = old_stdout
            output = mystdout.getvalue()
        except Exception as e:
            sys.stdout = old_stdout
            output = str(e)
        return output


def _get_default_python_repl() -> PythonREPL:
    return PythonREPL(_globals=globals(), _locals=None)


class PythonREPLTool(ToolInterface):
    """A tool for running python code in a REPL."""

    name: str = "Python REPL"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "If you want to see the output of a value, you should print it out "
        "with `print(...)`."
    )
    python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)

    def use(self, input_text: str) -> str:
        input_text = input_text.strip().strip("```")
        return self.python_repl.run(input_text)


if __name__ == '__main__':
    repl_tool = PythonREPLTool()
    result = repl_tool.use('print(5 * 7)')
    assert result == "35\n"
    print(result)

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/tools/search.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Based on https://github.com/hwchase17/langchain/blob/master/langchain/utilities/serpapi.py

import os
import sys
from typing import Any

from llm_agents.tools.base import ToolInterface

def search(query: str) -> str:
    # Placeholder implementation.
    # Replace this function with an actual Bing Web Search API call using the Azure SDK or REST API.
    # Example: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python

    return f"Search results for '{query}' would appear here using Bing Web Search API."

def _process_response(res: dict) -> str:
    """Process response from SerpAPI."""
    focus = ['title', 'snippet', 'link']
    get_focused = lambda x: {i: j for i, j in x.items() if i in focus}

    if "error" in res.keys():
        raise ValueError(f"Got error from SerpAPI: {res['error']}")
    if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
        toret = res["answer_box"]["answer"]
    elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
        toret = res["answer_box"]["snippet"]
    elif (
        "answer_box" in res.keys()
        and "snippet_highlighted_words" in res["answer_box"].keys()
    ):
        toret = res["answer_box"]["snippet_highlighted_words"][0]
    elif (
        "sports_results" in res.keys()
        and "game_spotlight" in res["sports_results"].keys()
    ):
        toret = res["sports_results"]["game_spotlight"]
    elif (
        "knowledge_graph" in res.keys()
        and "description" in res["knowledge_graph"].keys()
    ):
        toret = res["knowledge_graph"]["description"]
    elif res.get("organic_results") and "snippet" in res["organic_results"][0].keys():
        toret = res["organic_results"][0]["snippet"]
    else:
        toret = "No good search result found"

    toret_l = []
    if res.get("organic_results"):
        for i, result in enumerate(res["organic_results"], 1):
            focused_info = get_focused(result)
            toret_l.append(f"Result {i}: {focused_info.get('title')}\n"
                           f"Link: {focused_info.get('link')}\n"
                           f"Snippet: {focused_info.get('snippet')}\n")

    return toret + '\n'.join(toret_l)

class HiddenPrints:
    """Context manager to hide prints."""

    def __enter__(self) -> None:
        """Open file to pipe stdout to."""
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, *_: Any) -> None:
        """Close file that stdout was piped to."""
        sys.stdout.close()
        sys.stdout = self._original_stdout


class AzureSearchTool(ToolInterface):
    """Tool for performing web search using Microsoft Azure Bing Search."""

    name: str = "Azure Bing Search"
    description: str = (
        "Use this tool to retrieve information from the web based on a natural language query. "
        "Input should be a question like 'How to add numbers in Clojure?'. "
        "The output will be a concise and relevant answer based on Bing Web Search results."
    )

    def use(self, input_text: str) -> str:
        return search(input_text)
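
if __name__ == '__main__':
    # Minimal usage sketch. Note that search() above is still a placeholder and
    # returns canned text until a real Bing Web Search call is wired in.
    tool = AzureSearchTool()
    print(tool.use("How to add numbers in Clojure?"))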

--------------------------------------------------------------------------------
/BRIDGE/llm_agents/tools/wikipedia.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from llm_agents.tools.base import ToolInterface
import wikipedia


def search(query: str) -> str:
    try:
        summary = wikipedia.summary(query, sentences=1)
        page = wikipedia.page(query)
        return f"Summary: {summary}\nLink: {page.url}"
    except wikipedia.exceptions.DisambiguationError as e:
        options = e.options[:3]  # Limiting to top 3 options
        return f"Disambiguation: {', '.join(options)}"
    except wikipedia.exceptions.PageError:
        return "No Wikipedia page found"
    except wikipedia.exceptions.WikipediaException as e:
        return f"Wikipedia error: {str(e)}"


class WikipediaAPITool(ToolInterface):
    """Tool for Wikipedia search results."""

    name: str = "Wikipedia Search"
    description: str = ("Get summary and link from Wikipedia for a given query. "
                        "Input should be a topic name or question like 'Python programming'.")

    def use(self, input_text: str) -> str:
        return search(input_text)


if __name__ == '__main__':
    s = WikipediaAPITool()
    res = s.use("Python programming")
    print(res)

--------------------------------------------------------------------------------
/BRIDGE/requirements.txt:
--------------------------------------------------------------------------------
serpapi
aiohttp
openai==0.28
pydantic>=1.10.5
requests>=2.28.2
google-api-python-client>=2.83.0
google-search-results>=2.4.2
bs4
wikipedia
pandas
scikit-learn
nltk
statsmodels
multiprocess
tiktoken
darts
transformers
datasets
gdown
jsonlines

--------------------------------------------------------------------------------
/BRIDGE/self_refine/eval.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pandas as pd

def run(path: str):

    df = pd.read_json(path, lines=True, orient="records")
    df = df[df['status'] != "error"]
    print(f"Loaded {len(df)} rows")
    for i, row in df.iterrows():
        direct_output = row["sent_to_fb"][0]
        iter_output = row["sent_to_fb"][-1]
        df.loc[i, 'direct_concept_success'] = direct_output["concept_feedback"][0].lower() == "none"
        df.loc[i, 'direct_commonsense_success'] = direct_output["commonsense_feedback"].lower() == "none"
        df.loc[i, 'direct_success'] = direct_output["concept_feedback"][0].lower() == "none" and direct_output["commonsense_feedback"].lower() == "none"
        df.loc[i, 'iter_concept_success'] = iter_output["concept_feedback"][0].lower() == "none"
        df.loc[i, 'iter_commonsense_success'] = iter_output["commonsense_feedback"].lower() == "none"
        df.loc[i, 'iter_success'] = iter_output["concept_feedback"][0].lower() == "none" and iter_output["commonsense_feedback"].lower() == "none"

    # direct wins
    num_direct_concept_wins = len(df[(df['direct_concept_success'] == True) & (df['iter_concept_success'] == False)])
    num_direct_commonsense_wins = len(df[(df['direct_commonsense_success'] == True) & (df['iter_commonsense_success'] == False)])
    num_iter_concept_wins = len(df[(df['direct_concept_success'] == False) & (df['iter_concept_success'] == True)])
    num_iter_commonsense_wins = len(df[(df['direct_commonsense_success'] == False) & (df['iter_commonsense_success'] == True)])
    num_direct_wins = len(df[(df['direct_success'] == True) & (df['iter_success'] == False)])
    num_iter_wins = len(df[(df['direct_success'] == False) & (df['iter_success'] == True)])

    num_commonsense_ties = len(df) - num_direct_commonsense_wins - num_iter_commonsense_wins
    num_concept_ties = len(df) - num_direct_concept_wins - num_iter_concept_wins

    # normalize everything and print a nice report

    print(f"Direct concept wins: {num_direct_concept_wins / len(df):.2f}")
    print(f"Direct commonsense wins: {num_direct_commonsense_wins / len(df):.2f}")
    print(f"Direct overall wins: {num_direct_wins / len(df):.2f}")
    print(f"Iter concept wins: {num_iter_concept_wins / len(df):.2f}")
    print(f"Iter commonsense wins: {num_iter_commonsense_wins / len(df):.2f}")
    print(f"Iter overall wins: {num_iter_wins / len(df):.2f}")
    print(f"Concept ties: {num_concept_ties / len(df):.2f}")
    print(f"Commonsense ties: {num_commonsense_ties / len(df):.2f}")


if __name__ == '__main__':
    import argparse
    args = argparse.ArgumentParser()
    args.add_argument("path", type=str)
    args = args.parse_args()

    run(path=args.path)
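
# Expected input schema (inferred from the fields read above), one JSON object
# per line of the input file; values are illustrative:
# {"status": "ok",
#  "sent_to_fb": [{"concept_feedback": ["none"], "commonsense_feedback": "none"},
#                 {"concept_feedback": ["none"], "commonsense_feedback": "none"}]}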

--------------------------------------------------------------------------------
/BRIDGE/self_refine/make_challenging.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import random
from itertools import chain

import pandas as pd

df = pd.read_json("xxx.jsonl", lines=True, orient="records")

all_concepts = set(chain(*df['concepts'].tolist()))

# challenging data with 20-30 concepts per example

random.seed(42)

n_samples = 200
res = []
for i in range(n_samples):
    k = random.randint(20, 30)
    # random.sample requires a sequence, not a set; sort for deterministic sampling
    concepts = random.sample(sorted(all_concepts), k=k)
    res.append({"concepts": concepts})

pd.DataFrame(res).to_json("data/commongen_very_challenging.jsonl", lines=True, orient="records")
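
# Each output row has the form (illustrative values):
# {"concepts": ["dog", "frisbee", "park", ...]}   # 20-30 sampled concepts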

--------------------------------------------------------------------------------
/BRIDGE/self_refine/prompt_building.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re
import pandas as pd

class TimeSeriesFeedbackPrompt:
    prompt: str = """We want to evaluate each text description based on how well it describes the given time series on five qualities: i) accuracy of trend description, ii) mention of seasonality, iii) reference to external factors, iv) clarity of description, v) completeness of information.

Here are some examples of this scoring rubric:

Time Series: [100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 1000, 1050, 1100, 1150, 1200, 1250]

Text Description: "The data shows a steady increase in values over time, indicating a strong upward trend. There are no visible seasonal patterns or fluctuations."

Scores:

* Accuracy of trend description: The description accurately identifies the steady increase in the time series. 5/5
* Mention of seasonality: The description correctly notes the absence of seasonality in the data. 5/5
* Reference to external factors: The description does not mention any external factors, which may or may not be relevant. 3/5
* Clarity of description: The description is clear and easy to understand. 5/5
* Completeness of information: The description covers the main aspects of the time series but could mention the exact rate of increase. 4/5

* Total score: 22/25

Time Series: [30, 32, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155, 160, 165, 170, 175, 180, 185, 190, 195, 200, 205, 210, 215, 220, 225, 230, 235, 240, 245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 295, 300, 305, 310, 315, 320, 325, 330, 335, 340, 345, 350, 355, 360, 365]

Text Description: "The temperatures rise and fall cyclically throughout the year, demonstrating clear seasonal patterns. There is a gradual increase in temperature during the summer months."

Scores:

* Accuracy of trend description: The description accurately identifies the cyclical pattern in the time series. 5/5
* Mention of seasonality: The description correctly notes the presence of seasonality in the data. 5/5
* Reference to external factors: The description mentions the seasons, which are relevant external factors. 5/5
* Clarity of description: The description is clear and easy to understand. 5/5
* Completeness of information: The description is comprehensive and covers all necessary aspects of the time series. 5/5

* Total score: 25/25
"""

def timeseries_iterate_prompt_to_json(output_file="./feedback.jsonl", allow_empty_feedback=True):
    prompt = TimeSeriesFeedbackPrompt.prompt
    res = []

    # The rubric examples above are separated by a blank line before each
    # "Time Series:" block (there are no "###" markers in the prompt), so split
    # on those boundaries; the leading rubric text yields no match and is skipped.
    examples = re.split(r"\n\s*\n(?=Time Series:)", prompt)
    for example in examples:
        try:
            if not example:
                continue
            example = example.strip()

            time_series_match = re.search(r"Time Series: \[(.*)\]", example)
            if not time_series_match:
                continue
            time_series = time_series_match.group(1)
            time_series_list = [float(x) for x in time_series.split(",")]

            text_description = re.search(r'Text Description: "(.*)"', example).group(1)

            try:
                accuracy_of_trend = re.search(r"Accuracy of trend description: (.*)/5", example).group(1)
                mention_of_seasonality = re.search(r"Mention of seasonality: (.*)/5", example).group(1)
                reference_to_external_factors = re.search(r"Reference to external factors: (.*)/5", example).group(1)
                clarity_of_description = re.search(r"Clarity of description: (.*)/5", example).group(1)
                completeness_of_information = re.search(r"Completeness of information: (.*)/5", example).group(1)
                total_score = re.search(r"Total score: (.*)/25", example).group(1)
                feedback_text = ""

            except Exception as feedback_extraction_error:
                if allow_empty_feedback:
                    feedback_text = ""
                    print("[Warning] Feedback missing in example, filled with empty string.")
                else:
                    raise ValueError(f"Feedback extraction failed and allow_empty_feedback=False: {feedback_extraction_error}")

            res.append({
                "time_series": time_series_list,
                "text_description": text_description,
                "accuracy_of_trend": accuracy_of_trend,
                "mention_of_seasonality": mention_of_seasonality,
                "reference_to_external_factors": reference_to_external_factors,
                "clarity_of_description": clarity_of_description,
                "completeness_of_information": completeness_of_information,
                "total_score": total_score,
                "feedback": feedback_text
            })

        except Exception as e:
            print(f"Error parsing example: {e}")

    df = pd.DataFrame(res)
    df.to_json(output_file, orient="records", lines=True)
    print(f"Saved feedback examples to {output_file}")


if __name__ == "__main__":
    timeseries_iterate_prompt_to_json(output_file="./my_feedback_examples.jsonl")

--------------------------------------------------------------------------------
/BRIDGE/self_refine/refiner.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from typing import Optional

import requests

from feedback import TimeSeriesFeedback

# === Azure OpenAI Call Wrapper ===
def call_azure_openai(prompt: str, system_msg: Optional[str] = None, max_tokens: int = 512) -> str:
    endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
    key = os.getenv("AZURE_OPENAI_API_KEY")
    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")
    api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01")  # REST API version, not the model version

    url = f"{endpoint}/openai/deployments/{deployment}/chat/completions?api-version={api_version}"
    headers = {
        "Content-Type": "application/json",
        "api-key": key
    }

    messages = []
    if system_msg:
        messages.append({"role": "system", "content": system_msg})
    messages.append({"role": "user", "content": prompt})

    payload = {
        "messages": messages,
        "temperature": 0.0,
        "max_tokens": max_tokens
    }

    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"].strip()


def auto_refine_text(
    generated_series,
    actual_series,
    initial_text,
    model="gpt-4o",
    iterations=5,
    score_threshold=25
):
    feedback_tool = TimeSeriesFeedback(model=model)

    current_text = initial_text
    best_text = current_text
    best_score = 0
    history = []

    for iteration in range(iterations):
        print(f"\n{'='*30}")
        print(f" Iteration {iteration+1}")
        print(f"{'='*30}")

        feedback_result = feedback_tool.evaluate(
            generated_series=generated_series,
            actual_series=actual_series,
            text_description=current_text
        )

        text_quality_scores = feedback_result["text_quality_scores"]
        suggestions = feedback_result["suggestions"]
        summary_feedback = feedback_result["text_feedback_summary"]

        total_score = 0
        print("\n--- Text Quality Scores ---")
        for key, val in text_quality_scores.items():
            score_value = int(val.split("/")[0])
            total_score += score_value
            print(f"{key}: {val} ({suggestions.get(key, 'No suggestion')})")

        if total_score > best_score:
            best_score = total_score
            best_text = current_text

        print(f"\nCurrent Score: {total_score}/25")
        print(f"Summary Feedback: {summary_feedback}")

        if best_score >= score_threshold:
            print(f"\nScore threshold reached ({best_score}/25). Stopping refinement.")
            break

        refined_text = apply_suggestions(current_text, suggestions)

        print(f"\n--- Refined Text ---\n{refined_text}")

        history.append({
            "iteration": iteration + 1,
            "text": current_text,
            "refined_text": refined_text,
            "total_score": total_score,
            "summary_feedback": summary_feedback,
            "suggestions": suggestions
        })

        current_text = refined_text

    print(f"\n{'='*30}")
    print(f" Final Best Text After {iteration+1} Iterations:")
    print(f"{'-'*30}")
    print(best_text)
    print(f"Score: {best_score}/25")
    print(f"{'='*30}")

    return best_text, best_score, history


def apply_suggestions(text, suggestions):
    suggestion_summary = ". ".join([f"{k}: {v}" for k, v in suggestions.items()])
    prompt = f"""
Original description: {text}

Suggestions for improvement:
{suggestion_summary}

Please revise the original description to address the suggestions, ensuring clarity and completeness.
"""

    return call_azure_openai(
        prompt,
        system_msg="You are an expert technical writer improving time series data descriptions."
    )

--------------------------------------------------------------------------------
/BRIDGE/self_refine/task_init.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List
import pandas as pd
from llm_agents import ChatLLM

class TimeSeriesTaskInit:
    def __init__(self, prompt_examples: str, model: ChatLLM) -> None:
        self.model = model
        self.setup_prompt_from_examples_file(prompt_examples)

    def generate_time_series(self, text_description: str) -> List[float]:
        prompt = self.create_prompt(text_description)
        response = self.model.generate(prompt)
        time_series = self.extract_time_series(response)
        return time_series

    def setup_prompt_from_examples_file(self, instances_path: str) -> None:
        TEMPLATE = """Text Description: {text_description}
Time Series: {time_series}"""

        instance_df = pd.read_json(instances_path, orient="records", lines=True)
        prompt = []
        for _, row in instance_df.iterrows():
            example = TEMPLATE.format(
                text_description=row["text_description"],
                time_series=", ".join(map(str, row["time_series"]))
            )
            prompt.append(example)
        self.prompt_examples = "\n\n###\n\n".join(prompt)

    def create_prompt(self, text_description: str) -> str:
        return f"{self.prompt_examples}\n\n###\n\nText Description: {text_description}\nTime Series:"

    def extract_time_series(self, response: str) -> List[float]:
        time_series_str = response.split("Time Series:")[-1].strip()
        time_series = [float(value) for value in time_series_str.split(',') if value.strip()]
        return time_series

    def __call__(self, text_description: str) -> List[float]:
        return self.generate_time_series(text_description)

--------------------------------------------------------------------------------
/BRIDGE/self_refine/task_iterate.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Dict
import pandas as pd
import re
import os
import requests

class TimeSeriesTaskIterate:
    def __init__(self, model, prompt_examples: str, engine: str) -> None:
        self.model = model
        self.engine = engine
        self.prompt = self.make_prompt(prompt_examples=prompt_examples)

    def make_prompt(self, prompt_examples: str) -> str:
        header = """Concepts: {concepts}\n"""
        example_template = """Sentence: {sentence}

What concepts from the concept list are missing from the sentence?

Concept Feedback: {concept_feedback}

Any feedback on commonsense?

Commonsense Feedback: {commonsense_feedback}"""
        instr = "\n\nOkay, improve the sentence using the feedback:\n\n"

        examples_df = pd.read_json(prompt_examples, orient="records", lines=True)
        prompt = []

        for example in examples_df.to_dict(orient="records"):
            single_example = []
            for step in example["sentence_to_feedback"]:
                single_example.append(
                    example_template.format(
                        sentence=step["sentence"],
                        concept_feedback=step["concept_feedback"],
                        commonsense_feedback=step["commonsense_feedback"]
                    )
                )
            prompt.append(header.format(concepts=example["concepts"]) + instr.join(single_example))

        return "\n\n###\n\n".join(prompt) + "\n\n###\n\n"

    def make_one_iterate_example(self, concepts: List[str], sent_to_fb: List[Dict]) -> str:
        header = """Concepts: {concepts}\n"""
        example_template = """Sentence: {sentence}

What concepts from the concept list are missing from the sentence?

Concept Feedback: {concept_feedback}

Any feedback on commonsense?

Commonsense Feedback: {commonsense_feedback}"""
        instr = "\n\nOkay, improve the sentence using the feedback:\n\n"

        single_example = []
        for example in sent_to_fb:
            single_example.append(
                example_template.format(
                    sentence=example["sentence"],
                    concept_feedback=example["concept_feedback"],
                    commonsense_feedback=example["commonsense_feedback"]
                )
            )

        return header.format(concepts=concepts) + instr.join(single_example)

    def make_query(self, concepts: List[str], sent_to_fb: List[Dict]) -> str:
        query_example = self.make_one_iterate_example(concepts=concepts, sent_to_fb=sent_to_fb)
        return f"{self.prompt}\n\n###\n\n{query_example}\n\n###\n\nOkay, improve the sentence using the feedback:\n\n"

    def call_azure_openai_completion(self, prompt: str) -> str:
        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        key = os.getenv("AZURE_OPENAI_API_KEY")
        deployment = self.engine
        api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-01")  # REST API version, not the model version

        url = f"{endpoint}/openai/deployments/{deployment}/completions?api-version={api_version}"
        headers = {
            "Content-Type": "application/json",
            "api-key": key
        }

        payload = {
            "prompt": prompt,
            "max_tokens": 300,
            "temperature": 0.7,
            "stop": ["###"]
        }

        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()["choices"][0]["text"].strip()

    def __call__(self, concepts: List[str], sent_to_fb: List[Dict]) -> str:
        transfer_query = self.make_query(concepts=concepts, sent_to_fb=sent_to_fb)
        response = self.call_azure_openai_completion(transfer_query)
        match = re.search("Sentence: (.*)", response)
        improved_sentence = match.group(1).strip() if match else response.split("\n")[0].strip()
        return improved_sentence

    def refine_text(self, text_description: str, feedback: Dict) -> str:
        return self.model.refine(text_description, feedback)


if __name__ == "__main__":
    obj = TimeSeriesTaskIterate(
        model="your-model-instance",
        prompt_examples="data/prompt/commongen/iterate.v1.jsonl",
        engine=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4o")
    )

    concepts = ["trend", "increase", "data"]
    sent_to_fb = [
        {"sentence": "The data shows an increasing trend over time.", "concept_feedback": "None", "commonsense_feedback": "The sentence is clear."},
        {"sentence": "Data shows trend.", "concept_feedback": "Missing 'increasing'", "commonsense_feedback": "Incomplete sentence."}
    ]

    refined_sentence = obj(concepts, sent_to_fb)
    print(refined_sentence)

--------------------------------------------------------------------------------
/BRIDGE/template_extractor.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
from typing import List

from llm_agents import ChatLLM
from self_refine.feedback import TimeSeriesFeedback


class TemplateExtractor:
    def __init__(self, llm: ChatLLM, feedback_tool: TimeSeriesFeedback, output_template_file="templates/description_template.json"):
        self.llm = llm
        self.feedback_tool = feedback_tool
        self.output_template_file = output_template_file
        self.templates = []

    def load_final_text_from_refinement_result(self, refinement_json_path: str) -> str:
        """
        Load the `final_text` field from multi_agent_refinement_result.json.
        """
        print(f"Loading refinement result from {refinement_json_path}...")
        with open(refinement_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        final_text = data.get("final_text", "")
        if not final_text:
            raise ValueError("No 'final_text' found in the provided JSON.")
        return final_text

    def generate_template_from_final_text(self, final_text: str) -> str:
        """
        Generate a generalized template from the `final_text`.
        """
        prompt = (
            "You are an expert in time series data description. Given the following finalized description:\n\n"
            f"{final_text}\n\n"
            "Extract a generalized description template. Replace all specific values like numbers, dates, and category names with placeholders in curly braces, "
            "such as {dataset_name}, {frequency}, {start_date}, {end_date}, {prediction_length}, {min_value}, {max_value}, {mean_value}, {std_value}, {peak_steps}, {dip_steps}, {variability_summary}.\n\n"
            "Return ONLY the generalized template in natural language, no explanation."
        )

        try:
            response = self.llm.generate(prompt)
            if not response or len(response.strip()) < 10:
                raise ValueError("LLM returned an empty or too short response.")

            print("\nExtracted template:\n")
            print(response.strip())

            return response.strip()

        except Exception as e:
            print(f"[ERROR] LLM error during template generation: {e}")
            return ""

    def evaluate_template_candidate(self, candidate: str, dummy_series: List[float]) -> bool:
        """
        Evaluate a single sentence to check if it works as a template for time series descriptions.
        """
        print(f"Evaluating template: {candidate}")

        # Step 1: Replace variables in the candidate template with example values
        test_description = (
            candidate.replace("{dataset_name}", "ETTh1")
            .replace("{frequency}", "hourly")
            .replace("{data_description}", "electricity consumption")
            .replace("{start_date}", "January 1, 2020")
            .replace("{end_date}", "December 31, 2022")
            .replace("{prediction_length}", "24")
            .replace("{min_value}", "100")
            .replace("{max_value}", "250")
            .replace("{mean_value}", "175")
            .replace("{std_value}", "30")
            .replace("{peak_steps}", "June 15, 2021")
            .replace("{dip_steps}", "September 10, 2021")
            .replace("{variability_summary}", "moderate fluctuations")
        )

        # Step 2: Evaluate the text quality with the feedback tool
        try:
            feedback_score = self.feedback_tool(dummy_series, test_description)
            print(f"Feedback result: {feedback_score}")
        except Exception as e:
            print(f"Failed evaluation on: {test_description}")
            print(f"[ERROR]: {e}")
            return False

        return True  # You can add stricter rules here if necessary

    def extract_templates_from_refinement_result(self, refinement_json_path: str):
        """
        Full pipeline: loads `final_text` -> generalizes -> evaluates -> saves template.
        """
        final_text = self.load_final_text_from_refinement_result(refinement_json_path)

        # Step 1: Generate generalized template
        template = self.generate_template_from_final_text(final_text)

        if not template:
            print("No valid template generated.")
            return

        # Step 2: Evaluate its quality with a dummy series
        dummy_series = [100, 150, 200, 250, 300, 350, 400, 450, 500]

        if self.evaluate_template_candidate(template, dummy_series):
            self.templates.append(template)
            print(f"Accepted template: {template}")
        else:
            print("Template rejected after evaluation.")

        # Step 3: Save templates to JSON
        os.makedirs(os.path.dirname(self.output_template_file), exist_ok=True)
        with open(self.output_template_file, "w", encoding="utf-8") as f:
            json.dump(self.templates, f, indent=2)

        print(f"\nSaved {len(self.templates)} template(s) to {self.output_template_file}")
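
if __name__ == '__main__':
    # Minimal usage sketch (assumes OPENAI_API_KEY is set, and that a refinement
    # result JSON with a "final_text" field exists at this hypothetical path).
    extractor = TemplateExtractor(
        llm=ChatLLM(),
        feedback_tool=TimeSeriesFeedback(model="gpt-4o"),
    )
    extractor.extract_templates_from_refinement_result("multi_agent_refinement_result.json")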

--------------------------------------------------------------------------------
/BRIDGE/ts_metrics.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np
import pandas as pd
import ast
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from statsmodels.tsa.seasonal import STL

class TimeSeriesMetrics:
    def __init__(self, predictions, ground_truth):
        self.predictions = np.array(predictions)
        self.ground_truth = np.array(ground_truth)

    def compute_mse(self):
        return mean_squared_error(self.ground_truth, self.predictions)

    def compute_mae(self):
        return mean_absolute_error(self.ground_truth, self.predictions)

    def compute_correlation(self):
        if len(self.predictions) != len(self.ground_truth):
            raise ValueError("Predictions and ground truth must have the same length")

        mean_predictions = np.mean(self.predictions)
        mean_ground_truth = np.mean(self.ground_truth)

        numerator = np.sum((self.predictions - mean_predictions) * (self.ground_truth - mean_ground_truth))
        denominator = np.sqrt(np.sum((self.predictions - mean_predictions)**2) * np.sum((self.ground_truth - mean_ground_truth)**2))

        if denominator == 0:
            raise ValueError("Denominator in correlation calculation is zero, cannot divide by zero")

        correlation = numerator / denominator
        return correlation

    def compare_distributions(self, interval=1, bar_width=0.2, plot=True):
        min_value = min(np.min(self.predictions), np.min(self.ground_truth))
        max_value = max(np.max(self.predictions), np.max(self.ground_truth))

        bins = np.arange(min_value, max_value + interval, interval)

        pred_hist, _ = np.histogram(self.predictions, bins=bins, density=True)
        gt_hist, _ = np.histogram(self.ground_truth, bins=bins, density=True)

        if plot:
            plt.figure(figsize=(10, 5))
            bin_centers = 0.5 * (bins[:-1] + bins[1:])
            plt.bar(bin_centers - bar_width/2, pred_hist, width=bar_width, alpha=0.5, label='Predictions')
            plt.bar(bin_centers + bar_width/2, gt_hist, width=bar_width, alpha=0.5, label='Ground Truth')
            plt.legend(loc='upper right')
            plt.title('Distribution Comparison')
            plt.xlabel('Value')
            plt.ylabel('Density')
            plt.show()

        return pred_hist, gt_hist

    def compute_cosine_similarity(self):
        pred_reshaped = self.predictions.reshape(1, -1)
        gt_reshaped = self.ground_truth.reshape(1, -1)
        return cosine_similarity(pred_reshaped, gt_reshaped)[0][0]

    def compute_js_distance(self):
        pred_hist, gt_hist = self.compare_distributions(plot=False)
        return jensenshannon(pred_hist, gt_hist)

    def compute_ks_test(self):
        """
        Kolmogorov-Smirnov test for the equality of distributions.
        Returns the KS statistic and p-value.
        """
        statistic, p_value = ks_2samp(self.predictions, self.ground_truth)
        return {"ks_statistic": statistic, "p_value": p_value}

    def compute_wasserstein_distance(self):
        """
        Compute the first Wasserstein distance (also called Earth Mover's Distance).
        """
        return wasserstein_distance(self.predictions, self.ground_truth)

    @staticmethod
    def perform_stl_decomposition(time_series, period):
        if not isinstance(time_series, pd.Series):
            raise ValueError("Input must be a pandas Series")
        if not isinstance(time_series.index, pd.DatetimeIndex):
            raise ValueError("Input Series must have a DatetimeIndex")

        stl = STL(time_series, period=period)
        result = stl.fit()

        trend = result.trend
        seasonal = result.seasonal
        residual = result.resid

        decomposed = pd.DataFrame({
            'original': time_series,
            'trend': trend,
            'seasonal': seasonal,
            'residual': residual
        })

        return decomposed

    @staticmethod
    def load_and_prepare_data(file_path, column_name='history_data', num_steps=96):
        df = pd.read_csv(file_path)
        df[column_name] = df[column_name].apply(ast.literal_eval)
        time_series = pd.Series(df[column_name].iloc[0][:num_steps])

        date_range = pd.date_range(start='2021-01-01', periods=num_steps, freq='H')
        time_series.index = date_range

        print(time_series)
        return time_series
self.pre_price = self.current_price
46 |         self.update_historical_data(ret, time=time)
47 | 
48 |     def check_valid_return(self, ret: float):
49 |         assert -0.5 < ret < 0.5  # if an error is encountered here, you can temporarily disable this assertion
50 |         pass
51 | 
52 |     def get_mid_price(self):
53 |         if self.last_trade_info is None:
54 |             return self.pre_price
55 |         lob_snapshot = self.last_trade_info.lob_snapshot
56 |         if lob_snapshot.ask_prices and lob_snapshot.bid_prices:
57 |             mid_price = (lob_snapshot.ask_prices[0] + lob_snapshot.bid_prices[0]) / 2
58 |             mid_price /= self.price_scale
59 |             return mid_price
60 |         return self.pre_price
61 | 
62 |     def get_current_return(self):
63 |         if self.current_price is None or self.pre_price is None:
64 |             return 0
65 |         assert self.current_price > 0
66 |         assert self.pre_price > 0
67 |         # formula (2) in reg paper
68 |         current_return: float = np.log(self.current_price / self.pre_price)
69 |         assert current_return is not None
70 |         return current_return
71 | 
72 |     def is_ready(self):
73 |         return self._is_ready
74 | 
75 |     def estimate_return(self, horizon: int):
76 |         avg_return: float = self.return_history[-int(horizon):]["return"].mean()  # type: ignore
77 |         self.check_valid_return(avg_return)
78 |         return avg_return
79 | 
80 |     def cal_return_variance(self, horizon: int):
81 |         avg_return: float = self.estimate_return(horizon)
82 |         var_return = sum(map(lambda x: (x - avg_return) ** 2, self.return_history[-int(horizon):]["return"].tolist())) / horizon  # type: ignore
83 |         self._return_vars.append(var_return)
84 |         return var_return
85 | 
86 |     # Return and Time
87 |     def update_historical_data(self, x: float, time: Timestamp):
88 |         self.check_valid_return(x)
89 |         # Initialization
90 |         if not self.update_time:
91 |             self.return_history = pd.concat([self.return_history, pd.DataFrame([pd.Series({"return": x}, name=time)])])  # type: ignore
92 |             self.return_history.dropna(inplace=True)  # type: ignore
93 |             self.update_time = time
94 |             return
95 | 
96 |         time_delta_from_last_msg = int((time - self.update_time).total_seconds() / self.delta_time)
97 |         assert time_delta_from_last_msg >= 0
98 | 
99 |         if time_delta_from_last_msg == 0:
100 |             return
101 |         elif time_delta_from_last_msg > 1:
102 |             self._time_deltas.append(time_delta_from_last_msg)
103 |             # Time interpolation: fill skipped steps with zero returns
104 |             for idx_delta in range(time_delta_from_last_msg - 1):
105 |                 self.return_history = pd.concat([self.return_history, pd.DataFrame([pd.Series({"return": 0}, name=self.update_time + pd.Timedelta("{}s".format(self.delta_time * idx_delta)))])])
106 |                 self._num_returns_by_1s += 1
107 |         self._num_returns += 1
108 |         # update return
109 |         self.return_history = pd.concat([self.return_history, pd.DataFrame([pd.Series({"return": x}, name=time)])])  # type: ignore
110 |         self.return_history.dropna(inplace=True)  # type: ignore
111 |         self.update_time = time
112 | 
113 |         self._clear_queue()
114 |         self._is_ready = len(self.return_history) > self.init_len  # type: ignore
115 | 
116 |     def register_horizon(self, x: int) -> None:
117 |         x = int(x)
118 |         if x > self.horizon:
119 |             self.horizon = x
120 |             # print("Maximum chartist horizon: {}".format(self.horizon))
121 |         return
122 | 
123 |     # prevent the return-history queue from growing too long
124 |     def _clear_queue(self):
125 |         if random.random() < 0.1:
126 |             self.return_history = self.return_history[-self.horizon - 1 :]
127 | 
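A minimal usage sketch of the state above (not from the repo; the price stream and timestamps are invented, and prices are pre-divided by `price_scale` because the state tracks scaled prices internally):

```python
import pandas as pd

from agent.utils.chartist_state import ChartistState  # run from the DiGA root

state = ChartistState(delta_time=1.0, init_len=20, init_price=100000.0)
state.register_horizon(30)  # longest horizon any chartist will query

ts = pd.Timestamp("2000-01-01 09:30:00")
for i, price in enumerate([100010.0, 100020.0, 99990.0, 100050.0] * 10):
    # feed one observation per second, scaled the same way as init_price
    state.observe_price(price / state.price_scale, ts + pd.Timedelta(seconds=i))

if state.is_ready():  # True once more than init_len returns are recorded
    mu = state.estimate_return(horizon=30)       # mean log-return over the window
    var = state.cal_return_variance(horizon=30)  # dispersion used for risk terms
```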
--------------------------------------------------------------------------------
/DiGA/agent/utils/ctrl_loader.py:
--------------------------------------------------------------------------------
1 | # import json
2 | import numpy as np
3 | import pandas as pd
4 | 
5 | 
6 | class CtrlLoader:
7 |     """Generation data loader: loads the fundamental price path and the per-minute number of orders from ground-truth controls."""
8 | 
9 |     def __init__(self, duration, ctrls, symbol="000001"):
10 |         self.symbol = symbol
11 |         self.mkt_open = duration[0]
12 |         self.mkt_close = duration[1]
13 |         self.date = str(self.mkt_open.year) + "{:02d}".format(self.mkt_open.month) + "{:02d}".format(self.mkt_open.day)
14 | 
15 |         self.ctrls = ctrls
16 |         if 'fundamental' in self.ctrls.keys():
17 |             self.load_fdmtl(self.ctrls['fundamental'])
18 |         if 'n_orders' in self.ctrls.keys():
19 |             self.load_n_orders(self.ctrls['n_orders'])
20 | 
21 | 
22 |     def load_fdmtl(self, fdmtl):
23 |         # fdmtl: 1-d array of MINUTE mid prices used as the fundamental; expand it into a second-wise dataframe
24 |         mkt_open, mkt_close = self.mkt_open, self.mkt_close
25 |         date_range = pd.date_range(mkt_open, mkt_close, inclusive="left", freq='1s', name='ts')
26 |         price_df = pd.DataFrame(index=date_range).reset_index(names=['ts'])
27 |         sec_trading_filter = price_df['ts'].apply(lambda x: not ((x.hour==11 and x.minute>=30) or x.hour==12))
28 |         min_date_range = pd.date_range(mkt_open, mkt_close, inclusive="left", freq='1min', name='ts')
29 |         min_price_df = pd.DataFrame(index=min_date_range).reset_index(names=['ts'])
30 |         min_trading_filter = min_price_df['ts'].apply(lambda x: not ((x.hour==11 and x.minute>=30) or x.hour==12))
31 | 
32 | 
33 |         trading_min = min_price_df[min_trading_filter].copy()
34 |         trading_sec = price_df[sec_trading_filter].copy()
35 |         row_prices = fdmtl // 50 * 50
36 |         ffill_row = np.pad(row_prices, (0, trading_min.shape[0]-row_prices.shape[0]), 'constant', constant_values=row_prices[-1])
37 |         trading_min['price'] = ffill_row
38 |         trading_sec = trading_sec.merge(trading_min, how='left', on='ts').ffill()
39 |         self._midPriceTb = trading_sec
40 | 
41 |     def load_n_orders(self, n):
42 |         mkt_open, mkt_close = self.mkt_open, self.mkt_close
43 |         min_date_range = pd.date_range(mkt_open, mkt_close, inclusive="left", freq='1min', name='ts')
44 |         min_price_df = pd.DataFrame(index=min_date_range).reset_index(names=['ts'])
45 |         min_trading_filter = min_price_df['ts'].apply(lambda x: not ((x.hour==11 and x.minute>=30) or x.hour==12))
46 | 
47 |         trading_min = min_price_df[min_trading_filter].copy()
48 |         # `n` should already have the per-minute shape; shorter arrays are padded with 0 (then floored to 100), which may mask errors
49 |         ffill_row = np.pad(n, (0, trading_min.shape[0]-n.shape[0]), 'constant', constant_values=0)
50 |         ffill_row = np.maximum(ffill_row, 100)
51 |         trading_min['n_orders'] = ffill_row
52 |         self._nOrdersTb = trading_min
53 | 
54 |     def getMidPrice(self):
55 |         return self._midPriceTb
56 | 
57 |     def getNOrders(self):
58 |         return self._nOrdersTb
59 | 
60 |     def generateSeed(self):
61 |         return int(self._midPriceTb['price'].mean() // 10)
62 | 
--------------------------------------------------------------------------------
/DiGA/agent/utils/rl_state.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from mlib.core.state import State
3 | from mlib.core.trade_info import TradeInfo
4 | import numpy as np
5 | 
6 | def rpadton(x, v=0, n=10):
7 |     # right-pad x with value v to length n
8 |     pad_len = n - len(x)
9 |     return np.pad(x, (0, pad_len), 'constant', constant_values=(v, v))
10 | 
11 | class 
RLState(State): 12 | 13 | def __init__(self, window=20, init_price=100000, tick_size=100) -> None: 14 | super().__init__() 15 | self.signal = 0 16 | self.price_history = pd.DataFrame(columns=['price'], index=pd.to_datetime([])) 17 | self.last_price = init_price 18 | self.prev_price = init_price 19 | self.initilized = False 20 | self.prev_update_time = None 21 | self.window = window 22 | self.tick_size = tick_size 23 | 24 | def on_open(self, trade_info: TradeInfo=None, cancel_transactions = None, lob_snapshot = None, match_trans = None): 25 | if trade_info is not None: 26 | self.update_price_history() 27 | self.update_orderbook(trade_info.lob_snapshot) 28 | 29 | def on_trading(self, trade_info: TradeInfo): 30 | super().on_trading(trade_info) 31 | self.update_price_history() 32 | self.update_orderbook(trade_info.lob_snapshot) 33 | self.initilized = True 34 | 35 | def update_price_history(self): 36 | now_second = self.time.ceil('s') 37 | if self.prev_update_time is None: 38 | self.price_history.loc[now_second] = self.last_price 39 | else: 40 | prev_second = self.prev_update_time.ceil('s') 41 | if now_second == prev_second: 42 | self.price_history.loc[now_second] = self.last_price 43 | else: 44 | delta_second = now_second - prev_second 45 | num_of_intervals = int(delta_second.total_seconds() / 60) 46 | if num_of_intervals > 1: 47 | for i in range(1, num_of_intervals): 48 | self.price_history.loc[prev_second + pd.Timedelta(seconds=i)] = self.prev_price 49 | 50 | self.price_history.loc[now_second] = self.last_price 51 | 52 | self.prev_update_time = self.time 53 | self.prev_price = self.last_price 54 | 55 | def update_orderbook(self, orderbook): 56 | self.bid_prices = orderbook.bid_prices[:10] 57 | self.ask_prices = orderbook.ask_prices[:10] 58 | self.bid_volumes = orderbook.bid_volumes[:10] 59 | self.ask_volumes = orderbook.ask_volumes[:10] 60 | 61 | def get_state(self): 62 | # right pad to n 63 | history = rpadton(np.log(self.price_history.iloc[-self.window:].values / self.last_price).squeeze(-1), n=self.window) 64 | lob_prices = np.log(np.concatenate([rpadton(self.bid_prices, self.last_price),rpadton(self.ask_prices, self.last_price)])) / np.log(self.last_price) 65 | lob_volumes = (np.array(np.concatenate([rpadton(self.bid_volumes),rpadton(self.ask_volumes)]))-1000)/1000 66 | return np.concatenate([history, lob_prices, lob_volumes], axis=0) 67 | 68 | 69 | -------------------------------------------------------------------------------- /DiGA/agent/utils/trade_info_state.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from mlib.core.state import State 4 | from mlib.core.trade_info import TradeInfo 5 | 6 | 7 | class TradeInfoState(State): 8 | """A state contains all trade infos.""" 9 | 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self.trade_infos: List[TradeInfo] = [] 13 | self.call_auction_trade_infos: List[TradeInfo] = [] 14 | 15 | def on_trading(self, trade_info: TradeInfo): 16 | super().on_trading(trade_info) 17 | self.trade_infos.append(trade_info) 18 | 19 | def on_call_auction_trading(self, trade_info: TradeInfo): 20 | super().on_call_auction_trading(trade_info) 21 | self.call_auction_trade_infos.append(trade_info) 22 | -------------------------------------------------------------------------------- /DiGA/environment.yml: -------------------------------------------------------------------------------- 1 | name: diga 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - 
python=3.8 8 | - pip=20.3 9 | - numpy 10 | - pandas 11 | - scipy 12 | - matplotlib 13 | - statsmodels 14 | - seaborn 15 | - pytorch 16 | - pytorch-cuda=12.1 17 | - torchvision 18 | - wandb 19 | - tqdm 20 | - lightning 21 | - einops 22 | - rich 23 | - pip: 24 | - gymnasium 25 | - stable-baselines3 26 | # - market_simulation 27 | 28 | -------------------------------------------------------------------------------- /DiGA/rltask/envs/base_market_env.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import gymnasium as gym 3 | import numpy as np 4 | from gymnasium import spaces 5 | from typing import Generator 6 | 7 | from agent.rl_agent import RLAgent 8 | from mlib.core.env import Env 9 | from mlib.core.observation import Observation 10 | 11 | 12 | class BaseMarketEnv(gym.Env): 13 | """Custom Environment that follows gym interface.""" 14 | market_env: Env = None 15 | rl_agent: RLAgent = None 16 | obs_generator: Generator[Observation, None, None] = None 17 | 18 | def __init__(self, config, mode='normal', show_progress=True, discrete_action=True): 19 | super().__init__() 20 | self.config = config # each time we create a env, the config is fixed. Maybe we can modify this setting later. 21 | if discrete_action: 22 | self.action_space = spaces.Discrete(2 * 5 * 10 + 1) 23 | else: 24 | self.action_space = spaces.Box(low=0, high=1, shape=(2 * 5 * 10 + 1,)) 25 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 26 | shape=(config["window"]+40+3,)) 27 | self.reward_mode = config['reward_mode'] 28 | self.show_progress = show_progress 29 | 30 | @abstractmethod 31 | def prepare_trading_env(self, config): 32 | # must register market_env, rl_agent, obs_generator here. 33 | ... 34 | 35 | def market_loop(self): 36 | observation = None 37 | for observation in self.obs_generator: 38 | if observation.agent.agent_id == self.rl_agent.agent_id and not observation.is_market_open_wakup and self.rl_state.initilized: 39 | break 40 | else: 41 | agent_to_act = observation.agent 42 | action = agent_to_act.get_action(observation) 43 | self.market_env.step(action) 44 | if observation is None: 45 | return None 46 | elif isinstance(observation.agent, RLAgent): 47 | self.env_obs = observation 48 | return observation.agent.convert_state() 49 | elif observation.time >= self.close_time: 50 | return None 51 | elif len(self.market_env.events) == 0: 52 | return None 53 | else: 54 | raise ValueError(f"observation not handled: {observation}") 55 | 56 | def step(self, action): 57 | 58 | action = self.rl_agent.get_action(self.env_obs, action) 59 | self.market_env.step(action) 60 | observation = self.market_loop() 61 | reward = self.rl_agent.get_step_pnl(self.reward_mode) 62 | 63 | terminated = True if observation is None else False 64 | truncated = False 65 | info = {} 66 | return observation, reward, terminated, truncated, info 67 | 68 | def reset(self, seed=None, options=None): 69 | info = {} 70 | observation = self.prepare_trading_env(self.config) 71 | return observation, info 72 | 73 | def render(self): 74 | ... 75 | 76 | def close(self): 77 | ... 
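Since `BaseMarketEnv` follows the standard Gymnasium contract (`reset`, then `step` until terminated), a plain rollout loop drives it directly. A sketch using the `CtrlMarketEnv` subclass defined in the next file — the pickle path is a placeholder, and the random policy is a stand-in for a trained agent:

```python
from copy import copy

from rltask.envs.ctrl_market_env import CtrlMarketEnv, base_ctrl_config

config = copy(base_ctrl_config)
config["ctrl_path"] = "path/to/ctrl_dict.pkl"  # pickled {name: ctrl} dict (placeholder)

env = CtrlMarketEnv(config, show_progress=False)
obs, info = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # random stand-in policy
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.close()
```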
78 | -------------------------------------------------------------------------------- /DiGA/rltask/envs/ctrl_market_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import pandas as pd 4 | from dateutil.parser import parse 5 | 6 | from mlib.core.env import Env 7 | from mlib.core.event import create_exchange_events 8 | from mlib.core.exchange import Exchange 9 | 10 | from agent.utils.chartist_state import ChartistState 11 | from agent.rl_agent import RLAgent 12 | from agent.utils.trade_info_state import TradeInfoState 13 | from agent.utils.rl_state import RLState 14 | from agent.meta_agent import get_oracle, parse_config, CtrlHeterogeneousAgent 15 | from utils.pkl_utils import load_pkl 16 | from rltask.envs.base_market_env import BaseMarketEnv 17 | from mlib.core.exchange_config import create_exchange_config_without_call_auction # create_a_stock_exchange_config 18 | 19 | base_ctrl_config = { 20 | "f_value": "g", 21 | "symbol": 'ma.sim', 22 | "init_price": 100000, 23 | "date": "20000101", 24 | "start_time": "09:30:00", 25 | "end_time": "15:00:00", 26 | "tick_size": 100, 27 | "num_agents": 400, 28 | "wakeup_prob": 0.01, 29 | "delta_time": 1, 30 | "freq": 1, 31 | "f": 10, 32 | "c": 1.5, 33 | "n": 1, 34 | "time_ref": 30, 35 | "init_cash": 25, 36 | "init_hold": 25, 37 | "a": 0.1, 38 | "risk_averse": 0.1, 39 | "min_num_returns": 20, 40 | "noise_var": 1e-4, 41 | "reward_mode": "step", 42 | "window": 30, 43 | "ctrl_path": "/data/container/rl_simulation/ctrl_env_train_dict.pkl", 44 | "shuffle": True 45 | } 46 | 47 | class CtrlMarketEnv(BaseMarketEnv): 48 | """Custom Environment that follows gym interface.""" 49 | 50 | def __init__(self, config, mode='normal', show_progress=True, discrete_action=True): 51 | super().__init__(config=config, mode=mode, show_progress=show_progress, discrete_action=discrete_action) 52 | self.current_ctrl_idx = 0 53 | self.all_ctrl_pairs = None 54 | self.prepare_ctrl_data() 55 | 56 | def prepare_ctrl_data(self): 57 | assert 'ctrl_path' in self.config, "Control path not specified!" 58 | self.all_ctrls = load_pkl(self.config['ctrl_path']) 59 | self.all_ctrl_names = list(self.all_ctrls.keys()) 60 | self.num_ctrls = len(self.all_ctrl_names) 61 | 62 | def get_next_ctrl_data(self): 63 | assert self.all_ctrls is not None, "Control data not prepared!" 
64 | if self.config["shuffle"]: 65 | self.select_ctrl = np.random.choice(self.all_ctrl_names) 66 | else: 67 | self.select_ctrl = self.all_ctrl_names[self.current_ctrl_idx] 68 | self.current_ctrl_idx += 1 69 | if self.current_ctrl_idx >= self.num_ctrls: 70 | self.current_ctrl_idx = 0 71 | return self.all_ctrls[self.select_ctrl] 72 | 73 | def prepare_trading_env(self, config): 74 | """Run a rollout and get trade info.""" 75 | next_ctrl = self.get_next_ctrl_data() 76 | date = pd.to_datetime(config["date"]) 77 | mkt_open: pd.Timestamp = date + pd.to_timedelta(parse(config["start_time"]).strftime("%H:%M:%S")) # type: ignore 78 | mkt_close: pd.Timestamp = date + pd.to_timedelta(parse(config["end_time"]).strftime("%H:%M:%S")) 79 | self.close_time = mkt_close 80 | symbol = config['symbol'] 81 | print( 82 | f"Reset exchange and environments...Symbol {symbol}, Data {date}" 83 | ) 84 | ex_config = create_exchange_config_without_call_auction(mkt_open, mkt_close, [symbol]) 85 | self.exchange = Exchange(ex_config) 86 | self.exchange.register_state(TradeInfoState()) 87 | self.rl_state = RLState(config['window'], config['init_price'], config['tick_size']) 88 | self.exchange.register_state(self.rl_state) 89 | self.chartist_state = ChartistState(config['delta_time'], config['min_num_returns'], config['init_price']) 90 | self.exchange.register_state(self.chartist_state) 91 | 92 | self.market_env = Env(self.exchange, "Market Env For RL", show_progress=self.show_progress) 93 | agent_config = parse_config(config, symbol, mkt_open, mkt_close) 94 | oracle, data_loader = get_oracle(agent_config=config, ctrls=next_ctrl, symbol=symbol, output_dir=None, mkt_open=mkt_open, mkt_close=mkt_close) 95 | background_agents = [ 96 | CtrlHeterogeneousAgent(0, agent_config, chartist_state=self.chartist_state, oracle=oracle, symbol=config['symbol']) 97 | ] 98 | for bg_agent in background_agents: 99 | self.market_env.register_agent(bg_agent) 100 | self.rl_agent = RLAgent(start_time=mkt_open, end_time=mkt_close, obs_state=self.rl_state, symbol=config['symbol']) 101 | self.market_env.register_agent(self.rl_agent) 102 | self.market_env.push_events(create_exchange_events(ex_config)) 103 | self.obs_generator = self.market_env.env() 104 | observation = self.market_loop() 105 | 106 | return observation 107 | -------------------------------------------------------------------------------- /DiGA/rltask/train_test_rl.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from stable_baselines3 import A2C 3 | from stable_baselines3.common.evaluation import evaluate_policy 4 | from stable_baselines3.common.monitor import Monitor 5 | import wandb 6 | from wandb.integration.sb3 import WandbCallback 7 | import argparse 8 | from copy import copy 9 | from pathlib import Path 10 | from pytorch_lightning import seed_everything 11 | 12 | from rltask.envs.ctrl_market_env import CtrlMarketEnv, base_ctrl_config 13 | from rltask.envs.replay_market_env import ReplayMarketEnv, base_replay_config 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='RL training') 17 | parser.add_argument('--market', type=str, default='DiGA', help='DiGA, Replay') 18 | parser.add_argument('--max_steps', type=int, default=288000, help='max_steps') 19 | parser.add_argument("--save_name", type=str, default="DiGA_env") 20 | parser.add_argument('--eval_eps', type=int, default=10, help='eval_eps') 21 | parser.add_argument('--data_path', type=str, default=None, help='path for diga/replay 
data')
22 |     parser.add_argument('--test_replay_path', type=str, default=None, help='path for replay test data')
23 |     parser.add_argument("--output_path", type=str, default=None)
24 |     parser.add_argument('--seed', type=int, default=0, help='seed')
25 |     args = parser.parse_args()
26 |     return args
27 | 
28 | if __name__ == "__main__":
29 | 
30 |     args = parse_args()
31 |     seed_everything(args.seed)
32 | 
33 |     if args.market == 'DiGA':
34 |         env_class = CtrlMarketEnv
35 |         env_config = copy(base_ctrl_config)
36 |         env_config['ctrl_path'] = args.data_path  # CtrlMarketEnv.prepare_ctrl_data reads self.config['ctrl_path']
37 |     elif args.market == 'Replay':
38 |         env_class = ReplayMarketEnv
39 |         env_config = copy(base_replay_config)
40 |         env_config['replay_path'] = args.data_path
41 |     else:
42 |         raise ValueError(f"Not included market type: {args.market}")
43 | 
44 |     model_class = A2C
45 | 
46 | 
47 |     now = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
48 |     run_name = f"{args.market}-disc-10s-w30-{args.max_steps}"
49 |     run_name += f"-{args.seed}"
50 |     save_path = Path(f"{args.output_path}/{args.save_name}/{run_name}")
51 |     save_path.mkdir(parents=True, exist_ok=True)
52 |     run = wandb.init(
53 |         project="Trade-Simulation-Market",
54 |         config=args,
55 |         sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
56 |         monitor_gym=False,
57 |         save_code=True,  # optional
58 |         name=run_name
59 |     )
60 |     env_config['window'] = 30
61 |     env = env_class(env_config, show_progress=False)
62 |     env = Monitor(env)
63 | 
64 |     model = model_class("MlpPolicy", env, tensorboard_log=f"runs/{run.id}")
65 |     model.learn(total_timesteps=args.max_steps, callback=WandbCallback(), progress_bar=True)
66 |     model.save(save_path / "model")
67 |     mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=args.eval_eps)
68 |     print(f"Train mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")
69 |     env.close()
70 | 
71 | 
72 |     env_config = copy(base_replay_config)
73 |     env_config['train'] = False
74 |     env_config['replay_path'] = args.test_replay_path
75 |     env_config['window'] = 30
76 |     env_config['test_pnl_path'] = save_path
77 |     test_env = ReplayMarketEnv(env_config, show_progress=False)
78 |     test_env = Monitor(test_env)
79 | 
80 |     model = model_class.load(save_path / "model", env=test_env)  # load the checkpoint saved above
81 |     mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=args.eval_eps)
82 |     print(f"{run_name} Test mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")
83 |     test_env.close()
84 | 
--------------------------------------------------------------------------------
/DiGA/train.py:
--------------------------------------------------------------------------------
1 | from diffusion.cont_ctrl_net import UnetCC, DataModuleCC
2 | from diffusion.disc_ctrl_net import UnetDC, DataModuleDC
3 | from diffusion.ddpm import GaussianDiffusion, PLModel
4 | from pytorch_lightning.callbacks import ModelCheckpoint
5 | from pytorch_lightning.loggers import WandbLogger
6 | from pytorch_lightning import Trainer, seed_everything
7 | from utils.metrics_utils import get_metrics_func
8 | from pathlib import Path
9 | import argparse
10 | 
11 | def get_args():
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument("--data_name", type=str, default='SZAMain')
14 |     parser.add_argument("--ctrl_type", type=str, default='continuous')
15 |     parser.add_argument("--ctrl_target", type=str, default='return')
16 | 
parser.add_argument("--n_bins", type=int, default=5) 17 | parser.add_argument("--diffsteps", type=int, default=200) 18 | parser.add_argument("--samsteps", type=int, default=20) 19 | parser.add_argument("--epochs", type=int, default=None) 20 | parser.add_argument("--maxsteps", type=int, default=None) 21 | parser.add_argument("--batch_size", type=int, default=256) 22 | parser.add_argument("--learning_rate", type=float, default=1e-5) 23 | parser.add_argument("--checkpoints", type=int, default=2) 24 | parser.add_argument("--data_path", type=str) # data/container 25 | parser.add_argument("--output_path", type=str) 26 | parser.add_argument("--seed", type=int, default=0) 27 | parser.add_argument("--num_workers", type=int, default=0) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | if __name__ == "__main__": 33 | args = get_args() 34 | 35 | seed_everything(args.seed) 36 | print(f"Set seed: {args.seed}") 37 | 38 | exp_name = f'DiGA_{args.data_name.split("_")[0]}_{args.ctrl_type}_{args.ctrl_target}_{args.seed}' # _{now}' 39 | 40 | results_folder = Path(args.output_path) / exp_name 41 | results_folder.mkdir(parents=True, exist_ok=True) 42 | print(f"Set output path: {results_folder}") 43 | 44 | train_path = Path(args.data_path) / f'{args.data_name}_train.npy' 45 | vali_path = Path(args.data_path) / f'{args.data_name}_vali.npy' 46 | selected_metrics = [get_metrics_func(args.ctrl_target)] 47 | 48 | if args.ctrl_type == 'continuous': 49 | pl_data = DataModuleCC(train_path, vali_path, selected_metrics=args.ctrl_target, batch_size=args.batch_size, n_bins=args.n_bins, num_workers=args.num_workers) 50 | model = UnetCC(dim = 64, n_ctrls = 1, dim_mults = (1, 4, 16), channels = 2) 51 | elif args.ctrl_type == 'discrete': 52 | pl_data = DataModuleDC(train_path, vali_path, selected_metrics=args.ctrl_target, batch_size=args.batch_size, n_bins=args.n_bins, num_workers=args.num_workers) 53 | model = UnetDC(dim = 64, num_classes=args.n_bins, dim_mults = (1, 4, 16), channels = 2) 54 | 55 | diffusion = GaussianDiffusion(model, seq_length = 236, timesteps = args.diffsteps, sampling_timesteps = args.samsteps, objective = 'pred_noise', auto_normalize = False) 56 | LitModel = PLModel(model = diffusion, train_lr = args.learning_rate, results_folder = results_folder) 57 | 58 | ckpt_callback = ModelCheckpoint(results_folder, filename='model-{epoch}', monitor='val/loss', save_top_k=args.checkpoints, save_last=True, mode='min') 59 | 60 | logger = WandbLogger(name=exp_name, project='DiGA-Meta-Controller', config=args, offline=True) 61 | 62 | trainer = Trainer(max_epochs=args.epochs, max_steps=args.maxsteps, callbacks=[ckpt_callback], logger=logger, default_root_dir=results_folder, accelerator="gpu") 63 | trainer.fit(LitModel, train_dataloaders=pl_data.train_dataloader(), val_dataloaders=pl_data.val_dataloader()) 64 | 65 | 66 | -------------------------------------------------------------------------------- /DiGA/utils/pkl_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | def load_pkl(path: Path): 5 | """Load pkl from path.""" 6 | with open(path, "rb") as infile: 7 | data = pickle.load(infile) 8 | return data 9 | 10 | 11 | def save_pkl(data: object, path: Path): 12 | """Save pkl to path.""" 13 | with open(path, "wb") as outfile: 14 | pickle.dump(data, outfile) -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # MICROSOFT SECURITY 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 | 
29 | This information will help us triage your report more quickly.
30 | 
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 | 
33 | ## Preferred Languages
34 | 
35 | We prefer all communications to be in English.
36 | 
37 | ## Policy
38 | 
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 | 
--------------------------------------------------------------------------------
/TarDiff/README.md:
--------------------------------------------------------------------------------
1 | # TarDiff: Target-Oriented Diffusion Guidance for Synthetic Electronic Health Record Time Series Generation
2 | 
3 | 
4 | 
5 | ## Introduction
6 | ![TarDiff OverView](./images/overview.png)
7 | 
8 | Synthetic Electronic Health Record (EHR) time-series generation is crucial for advancing clinical machine learning models, as it helps address data scarcity by providing more training data. However, most existing approaches focus primarily on replicating statistical distributions and temporal dependencies of real-world data. We argue that fidelity to observed data alone does not guarantee better model performance, as common patterns may dominate, limiting the representation of rare but important conditions. This highlights the need to generate synthetic samples that improve the performance of specific clinical models on their target outcomes. To address this, we propose TarDiff, a novel target-oriented diffusion framework that integrates task-specific influence guidance into the synthetic data generation process. Unlike conventional approaches that mimic training data distributions, TarDiff optimizes synthetic samples by quantifying their expected contribution to improving downstream model performance through influence functions. Specifically, we measure the reduction in task-specific loss induced by synthetic samples and embed this influence gradient into the reverse diffusion process, thereby steering the generation towards utility-optimized data. Evaluated on six publicly available EHR datasets, TarDiff achieves state-of-the-art performance, outperforming existing methods by up to 20.4% in AUPRC and 18.4% in AUROC. Our results demonstrate that TarDiff not only preserves temporal fidelity but also enhances downstream model performance, offering a robust solution to data scarcity and class imbalance in healthcare analytics.
9 | 
10 | 
11 | ## 1 · Environment
12 | 
13 | Prepare TarDiff's environment.
14 | ```
15 | conda env create -f environment.yaml
16 | conda activate tardiff
17 | ```
18 | 
19 | Preparing the environment for the downstream time-series task depends on the repo you use for that specific task.
20 | 21 | ## 2 · Data Pre-processing 22 | 23 | You can access the raw datasets at the following links: 24 | 25 | - [eICU Collaborative Research Database](https://eicu-crd.mit.edu/) 26 | - [MIMIC-III Clinical Database](https://physionet.org/content/mimiciii/1.4/) 27 | 28 | > **Note:** Both datasets require prior approval and credentialing before download. 29 | 30 | We focus exclusively on the multivariate time-series recordings available in these datasets. 31 | To assist with preprocessing, we provide high-level extraction scripts under **`data_preprocess/`**. 32 | 33 | 34 | 35 | ## 3 · Stage 1 · Train the *Base* Diffusion Model 36 | 37 | ```bash 38 | bash train.sh # trains TarDiff on MIMIC-III ICU-stay data 39 | ``` 40 | 41 | > **Edit tip:** open the example YAML in `configs/base/` and replace any placeholder data paths with your own before running. 42 | 43 | This step produces an unconditional diffusion model checkpoint—no guidance yet. 44 | 45 | --- 46 | 47 | ## 4 · Stage 2 · Train a Downstream Task Model (Guidance Source) 48 | 49 | An example RNN classifier is supplied in **`classifier/`**. 50 | 51 | ```bash 52 | cd classifier 53 | bash train.sh # saves weights to classifier/checkpoint/ 54 | cd .. 55 | ``` 56 | 57 | Feel free to swap in any architecture that suits your task. 58 | 59 | --- 60 | 61 | ## 5 · Stage 3 · Target-Guided Generation 62 | 63 | With **both** checkpoints ready—the diffusion backbone and the task model—start guided sampling: 64 | 65 | ```bash 66 | bash generation.sh # remember to update paths to both weights 67 | ``` 68 | 69 | The script creates a synthetic dataset tailored to the guidance task. 70 | 71 | --- 72 | 73 | ## 6 · Stage 4 · Utility Evaluation — *TSTR* and *TSRTR* 74 | 75 | After generation, you can assess the utility of the synthetic data for the **target task** using two complementary protocols: 76 | 77 | | Protocol | Training Set | Test Set | Question Answered | 78 | |:---------|:-------------|:---------|:------------------| 79 | | **TSTR** (Train-Synthetic, Test-Real) | **Synthetic only** | **Real** | “If I train a model purely on synthetic EHRs, how well does it generalize to real patients?” | 80 | | **TSRTR** (Train-Synthetic-Real, Test-Real) | **Synthetic + Real** (α ∈ {0.2, …, 1.0}) | **Real** | “If I augment the real training set with α× synthetic samples, does it improve model performance?” | 81 | 82 | --- 83 | 84 | ### How to run TSTR and TSRTR evaluations 85 | 86 | You can directly reuse the training script under `classifier/` to run both evaluations: 87 | 88 | ```bash 89 | cd classifier 90 | bash train.sh # Edit the training data path to point to either synthetic-only or mixed (real + synthetic) data 91 | ``` 92 | 93 | - For **TSTR**, set the training set path to the synthetic dataset. 94 | - For **TSRTR**, combine the real training data with synthetic samples according to your desired α ratio, and update the path accordingly. 95 | 96 | The downstream model will be trained and evaluated automatically on the real validation and test sets. 97 | 98 | --- 99 | 100 | Enjoy exploring target-oriented diffusion for healthcare ML! For issues or pull requests, please open a GitHub ticket. 101 | -------------------------------------------------------------------------------- /TarDiff/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
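To make the TSRTR setup in the README above concrete, here is a minimal mixing sketch. It assumes the `(data, label)` tuple-pickle format used by `classifier/train.sh` and a hypothetical `synthetic_tuple.pkl` written after `generation.sh`; it also assumes the synthetic pool holds at least α × the real sample count:

```python
import pickle

import numpy as np

alpha = 0.5  # synthetic-to-real ratio; TSRTR sweeps alpha over {0.2, ..., 1.0}

with open("data/mimic_icustay/train_tuple.pkl", "rb") as f:
    real_x, real_y = pickle.load(f)
with open("synthetic_tuple.pkl", "rb") as f:  # placeholder path
    syn_x, syn_y = pickle.load(f)

n_syn = int(alpha * len(real_x))  # add alpha x |real| synthetic samples
idx = np.random.choice(len(syn_x), size=n_syn, replace=False)

mixed_x = np.concatenate([np.asarray(real_x), np.asarray(syn_x)[idx]])
mixed_y = np.concatenate([np.asarray(real_y), np.asarray(syn_y)[idx]])

with open("data/mimic_icustay/train_tuple_tsrtr.pkl", "wb") as f:
    pickle.dump((mixed_x, mixed_y), f)  # point classifier/train.sh at this file
```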
-------------------------------------------------------------------------------- /TarDiff/classifier/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | 5 | from __future__ import annotations 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class RNNClassifier(nn.Module): 11 | """Bidirectional LSTM/GRU classifier for fixed‑length sequences.""" 12 | 13 | def __init__( 14 | self, 15 | input_dim: int, 16 | hidden_dim: int = 128, 17 | num_layers: int = 2, 18 | rnn_type: str = "lstm", 19 | num_classes: int = 2, 20 | dropout: float = 0.2, 21 | ) -> None: 22 | super().__init__() 23 | rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU}[rnn_type.lower()] 24 | self.rnn = rnn_cls( 25 | input_size=input_dim, 26 | hidden_size=hidden_dim, 27 | num_layers=num_layers, 28 | batch_first=True, 29 | bidirectional=True, 30 | dropout=dropout if num_layers > 1 else 0.0, 31 | ) 32 | self.fc = nn.Linear(hidden_dim * 2, num_classes) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, T, F) 35 | rnn_out, _ = self.rnn(x) # (B, T, 2*H) 36 | last_hidden = rnn_out[:, -1, :] # final time‑step representation 37 | logits = self.fc(last_hidden) # (B, C) or (B, 1) 38 | return logits.squeeze(-1) # binary → (B,) ; multi‑class stays (B, C) 39 | -------------------------------------------------------------------------------- /TarDiff/classifier/train.sh: -------------------------------------------------------------------------------- 1 | # prepare guidance model train_tuple.pkl : (data, label) 2 | 3 | python classifier_train.py --num_classes 1 --rnn_type gru --hidden_dim 256 --train_data data/mimic_icustay/train_tuple.pkl --val_data data/mimic_icustay/val_tuple.pkl 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /TarDiff/configs/base/mimic_icustay_base.yaml: -------------------------------------------------------------------------------- 1 | seq_length: &seqlen 24 2 | model: 3 | base_learning_rate: 5.e-5 # set to target_lr by starting main.py with '--scale_lr False' 4 | target: ldm.models.diffusion.ddpm.LatentDiffusion 5 | params: 6 | linear_start: 0.0005 7 | linear_end: 0.1 8 | num_timesteps_cond: 1 9 | log_every_t: 40 10 | timesteps: 200 11 | loss_type: l1 12 | first_stage_key: "context" 13 | cond_stage_key: "context" 14 | image_size: *seqlen 15 | channels: 7 16 | cond_stage_trainable: True 17 | concat_mode: False 18 | scale_by_std: False # True 19 | monitor: 'val/loss_simple_ema' 20 | conditioning_key: crossattn 21 | cond_part_drop: False 22 | dis_loss_flag: False 23 | pair_loss_flag: False 24 | pair_loss_type: l2 25 | pair_loss_weight: 1.0 26 | scheduler_config: # 10000 warmup steps 27 | target: ldm.lr_scheduler.LambdaLinearScheduler 28 | params: 29 | warm_up_steps: [1000] 30 | cycle_lengths: [10000000000000] 31 | f_start: [1.e-6] 32 | f_max: [1.] 33 | f_min: [ 1.] 
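  # Editorial note (not in the original config): LambdaLinearScheduler, defined in
  # ldm/lr_scheduler.py, multiplies base_learning_rate by a factor f(n). During the
  # first warm_up_steps it ramps linearly from f_start to f_max; afterwards it decays
  # linearly toward f_min as
  #   f(n) = f_min + (f_max - f_min) * (cycle_length - n) / cycle_length
  # so with the single very long cycle above, the rate effectively stays near f_max.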
34 | 35 | unet_config: 36 | target: ldm.modules.diffusionmodules.unet1d.UNetModel 37 | params: 38 | image_size: *seqlen 39 | dims: 1 40 | in_channels: 7 41 | out_channels: 7 42 | model_channels: 64 43 | attention_resolutions: [ 1, 2, 4] # 8, 4, 2 44 | num_res_blocks: 2 45 | channel_mult: [ 1,2,4,4 ] # 8,4,2,1 46 | num_heads: 8 47 | use_scale_shift_norm: True 48 | resblock_updown: True 49 | context_dim: 32 50 | repre_emb_channels: 32 51 | latent_unit: 1 52 | use_cfg: True 53 | use_spatial_transformer: True 54 | num_classes: 2 55 | 56 | first_stage_config: # no first stage model for ts data 57 | target: ldm.models.autoencoder.IdentityFirstStage # VQModelInterface 58 | 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.DomainUnifiedEncoder # SplitTSEqEncoder # SplitTSEqEncoder, SingleTSEncoder 61 | params: 62 | dim: 32 63 | window: *seqlen 64 | latent_dim: 32 # 32 * 3 65 | num_channels: 7 66 | use_prototype: False 67 | # use_cfg: True 68 | 69 | data: 70 | target: ldm.data.tsg_dataset.TSClassCondTrainDataModule 71 | params: 72 | data_path_dict: 73 | MIMIC_III_Readmission: data/mimic_icustay/train_tuple.pkl 74 | window: *seqlen 75 | val_portion: 0.1 76 | batch_size: 256 77 | num_workers: 8 78 | normalize: centered_pit 79 | drop_last: True 80 | reweight: False 81 | input_dim: 82 | lightning: 83 | callbacks: 84 | image_logger: 85 | target: utils.callback_utils.TSLogger 86 | params: 87 | batch_frequency: 5000 88 | max_images: 8 89 | increase_log_steps: false 90 | log_images_kwargs: 91 | inpaint: false 92 | plot_swapped_concepts: false 93 | 94 | 95 | trainer: 96 | benchmark: True 97 | max_steps: 20 98 | grad_watch: False -------------------------------------------------------------------------------- /TarDiff/data_preprocess/README.md: -------------------------------------------------------------------------------- 1 | We preprocess MIMIC-III by first querying the raw **vitals** and **admissions** tables, then isolating each ICU stay (`icustay_id`) as an independent sample. For every stay we extract seven routinely recorded signals—heart-rate, systolic/diastolic blood pressure, mean arterial pressure, respiratory rate, temperature, oxygen saturation (SpO₂), and urine output—resample them to an equal 1-hour grid, and truncate or zero-pad so every sample is a fixed **24 × 7** time-series matrix covering the first 24 hours in the unit. We attach a binary in-hospital mortality label from the admissions record, stack all samples into a single array, randomly shuffle, and split 80 % / 20 % into training and test sets while reporting the class balance. This yields a clean, length-aligned dataset ready for downstream modeling without exposing any protected health information. 
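A condensed sketch of the per-stay resampling step described above (not from the repo; column names are hypothetical, and it assumes a long-format vitals DataFrame with one row per measurement and a `charttime` timestamp column):

```python
import numpy as np
import pandas as pd

# Seven hypothetical channel names standing in for the signals listed above.
VITALS = ["heart_rate", "sys_bp", "dia_bp", "mean_bp", "resp_rate", "temperature", "spo2"]

def stay_to_matrix(stay: pd.DataFrame, intime: pd.Timestamp) -> np.ndarray:
    """Resample one ICU stay onto a fixed 24 x 7 matrix (first 24 hours, 1-hour grid)."""
    stay = stay.set_index("charttime").sort_index()
    hourly = stay[VITALS].resample("1h", origin=intime).mean()  # equal 1-hour grid
    grid = pd.date_range(intime, periods=24, freq="1h")         # first 24 hours in the unit
    return hourly.reindex(grid).fillna(0.0).to_numpy()          # truncate / zero-pad
```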
-------------------------------------------------------------------------------- /TarDiff/generation.sh: -------------------------------------------------------------------------------- 1 | python guidance_generation.py --base configs/base/mimic_icustay_base.yaml --gpus 0, --uncond --logdir MIMIC_ICUSTAY -sl 24 2 | -------------------------------------------------------------------------------- /TarDiff/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/TarDiff/images/overview.png -------------------------------------------------------------------------------- /TarDiff/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import numpy as np 4 | 5 | 6 | class LambdaWarmUpCosineScheduler: 7 | """ 8 | note: use with a base_lr of 1.0 9 | """ 10 | 11 | def __init__(self, 12 | warm_up_steps, 13 | lr_min, 14 | lr_max, 15 | lr_start, 16 | max_decay_steps, 17 | verbosity_interval=0): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0. 24 | self.verbosity_interval = verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print( 30 | f"current step: {n}, recent lr-multiplier: {self.last_lr}") 31 | if n < self.lr_warm_up_steps: 32 | lr = (self.lr_max - 33 | self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - 38 | self.lr_warm_up_steps) 39 | t = min(t, 1.0) 40 | lr = self.lr_min + 0.5 * (self.lr_max - 41 | self.lr_min) * (1 + np.cos(t * np.pi)) 42 | self.last_lr = lr 43 | return lr 44 | 45 | def __call__(self, n, **kwargs): 46 | return self.schedule(n, **kwargs) 47 | 48 | 49 | class LambdaWarmUpCosineScheduler2: 50 | """ 51 | supports repeated iterations, configurable via lists 52 | note: use with a base_lr of 1.0. 53 | """ 54 | 55 | def __init__(self, 56 | warm_up_steps, 57 | f_min, 58 | f_max, 59 | f_start, 60 | cycle_lengths, 61 | verbosity_interval=0): 62 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len( 63 | f_start) == len(cycle_lengths) 64 | self.lr_warm_up_steps = warm_up_steps 65 | self.f_start = f_start 66 | self.f_min = f_min 67 | self.f_max = f_max 68 | self.cycle_lengths = cycle_lengths 69 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 70 | self.last_f = 0. 
71 | self.verbosity_interval = verbosity_interval 72 | 73 | def find_in_interval(self, n): 74 | interval = 0 75 | for cl in self.cum_cycles[1:]: 76 | if n <= cl: 77 | return interval 78 | interval += 1 79 | 80 | def schedule(self, n, **kwargs): 81 | cycle = self.find_in_interval(n) 82 | n = n - self.cum_cycles[cycle] 83 | if self.verbosity_interval > 0: 84 | if n % self.verbosity_interval == 0: 85 | print( 86 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 87 | f"current cycle {cycle}") 88 | if n < self.lr_warm_up_steps[cycle]: 89 | f = (self.f_max[cycle] - self.f_start[cycle] 90 | ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 91 | self.last_f = f 92 | return f 93 | else: 94 | t = (n - self.lr_warm_up_steps[cycle]) / ( 95 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 96 | t = min(t, 1.0) 97 | f = self.f_min[cycle] + 0.5 * ( 98 | self.f_max[cycle] - self.f_min[cycle]) * (1 + 99 | np.cos(t * np.pi)) 100 | self.last_f = f 101 | return f 102 | 103 | def __call__(self, n, **kwargs): 104 | return self.schedule(n, **kwargs) 105 | 106 | 107 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 108 | 109 | def schedule(self, n, **kwargs): 110 | cycle = self.find_in_interval(n) 111 | n = n - self.cum_cycles[cycle] 112 | if self.verbosity_interval > 0: 113 | if n % self.verbosity_interval == 0: 114 | print( 115 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 116 | f"current cycle {cycle}") 117 | 118 | if n < self.lr_warm_up_steps[cycle]: 119 | f = (self.f_max[cycle] - self.f_start[cycle] 120 | ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 121 | self.last_f = f 122 | return f 123 | else: 124 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 125 | self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 126 | self.last_f = f 127 | return f 128 | -------------------------------------------------------------------------------- /TarDiff/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /TarDiff/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | -------------------------------------------------------------------------------- /TarDiff/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. -------------------------------------------------------------------------------- /TarDiff/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
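# Usage sketch (added editorial comment): DiagonalGaussianDistribution below expects
# `parameters` with the mean and log-variance concatenated along dim=1, e.g. the
# 2 * z_channels output of an encoder head on 4-D image-like latents:
#
#     params = torch.randn(4, 2 * 8, 16, 16)   # (B, 2*C, H, W) -> mean | logvar
#     post = DiagonalGaussianDistribution(params)
#     z = post.sample()                         # mean + std * eps, shape (B, 8, 16, 16)
#     kl = post.kl()                            # KL to a standard normal, summed over C,H,W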
3 | import torch 4 | import numpy as np 5 | 6 | 7 | class AbstractDistribution: 8 | 9 | def sample(self): 10 | raise NotImplementedError() 11 | 12 | def mode(self): 13 | raise NotImplementedError() 14 | 15 | 16 | class DiracDistribution(AbstractDistribution): 17 | 18 | def __init__(self, value): 19 | self.value = value 20 | 21 | def sample(self): 22 | return self.value 23 | 24 | def mode(self): 25 | return self.value 26 | 27 | 28 | class DiagonalGaussianDistribution(object): 29 | 30 | def __init__(self, parameters, deterministic=False): 31 | self.parameters = parameters 32 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 33 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 34 | self.deterministic = deterministic 35 | self.std = torch.exp(0.5 * self.logvar) 36 | self.var = torch.exp(self.logvar) 37 | if self.deterministic: 38 | self.var = self.std = torch.zeros_like( 39 | self.mean).to(device=self.parameters.device) 40 | 41 | def sample(self): 42 | x = self.mean + self.std * torch.randn( 43 | self.mean.shape).to(device=self.parameters.device) 44 | return x 45 | 46 | def kl(self, other=None): 47 | if self.deterministic: 48 | return torch.Tensor([0.]) 49 | else: 50 | if other is None: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 53 | dim=[1, 2, 3]) 54 | else: 55 | return 0.5 * torch.sum( 56 | torch.pow(self.mean - other.mean, 2) / other.var + 57 | self.var / other.var - 1.0 - self.logvar + other.logvar, 58 | dim=[1, 2, 3]) 59 | 60 | def kl_splits(self, latent_unit=6): 61 | mean_splits = self.mean.chunk(latent_unit, dim=-1) 62 | var_splits = self.var.chunk(latent_unit, dim=-1) 63 | logvar_splits = self.logvar.chunk(latent_unit, dim=-1) 64 | kl_loss = 0 65 | for mean, var, logvar in zip(mean_splits, var_splits, logvar_splits): 66 | kl_split = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, 67 | dim=-1) 68 | kl_loss += torch.sum(kl_split) / kl_split.shape[0] 69 | return kl_loss / latent_unit 70 | 71 | def nll(self, sample, dims=[1, 2, 3]): 72 | if self.deterministic: 73 | return torch.Tensor([0.]) 74 | logtwopi = np.log(2.0 * np.pi) 75 | return 0.5 * torch.sum(logtwopi + self.logvar + 76 | torch.pow(sample - self.mean, 2) / self.var, 77 | dim=dims) 78 | 79 | def mode(self): 80 | return self.mean 81 | 82 | 83 | def normal_kl(mean1, logvar1, mean2, logvar2): 84 | """ 85 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 86 | Compute the KL divergence between two gaussians. 87 | Shapes are automatically broadcasted, so batches can be compared to 88 | scalars, among other use cases. 89 | """ 90 | tensor = None 91 | for obj in (mean1, logvar1, mean2, logvar2): 92 | if isinstance(obj, torch.Tensor): 93 | tensor = obj 94 | break 95 | assert tensor is not None, "at least one argument must be a Tensor" 96 | 97 | # Force variances to be Tensors. Broadcasting helps convert scalars to 98 | # Tensors, but it does not work for torch.exp(). 99 | logvar1, logvar2 = [ 100 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 101 | for x in (logvar1, logvar2) 102 | ] 103 | 104 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + 105 | ((mean1 - mean2)**2) * torch.exp(-logvar2)) 106 | -------------------------------------------------------------------------------- /TarDiff/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT license. 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class LitEma(nn.Module): 8 | 9 | def __init__(self, model, decay=0.9999, use_num_upates=True): 10 | super().__init__() 11 | if decay < 0.0 or decay > 1.0: 12 | raise ValueError('Decay must be between 0 and 1') 13 | 14 | self.m_name2s_name = {} 15 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 16 | self.register_buffer( 17 | 'num_updates', 18 | torch.tensor(0, dtype=torch.int) 19 | if use_num_upates else torch.tensor(-1, dtype=torch.int)) 20 | 21 | for name, p in model.named_parameters(): 22 | if p.requires_grad: 23 | #remove as '.'-character is not allowed in buffers 24 | s_name = name.replace('.', '') 25 | self.m_name2s_name.update({name: s_name}) 26 | self.register_buffer(s_name, p.clone().detach().data) 27 | 28 | self.collected_params = [] 29 | 30 | def forward(self, model): 31 | decay = self.decay 32 | 33 | if self.num_updates >= 0: 34 | self.num_updates += 1 35 | decay = min(self.decay, 36 | (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as( 48 | m_param[key]) 49 | shadow_params[sname].sub_( 50 | one_minus_decay * 51 | (shadow_params[sname] - m_param[key])) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def copy_to(self, model): 56 | m_param = dict(model.named_parameters()) 57 | shadow_params = dict(self.named_buffers()) 58 | for key in m_param: 59 | if m_param[key].requires_grad: 60 | m_param[key].data.copy_( 61 | shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /TarDiff/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. -------------------------------------------------------------------------------- /TarDiff/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator 4 | from ldm.modules.losses.vqperceptual import VQLPIPSWithDiscriminator -------------------------------------------------------------------------------- /TarDiff/train.sh: -------------------------------------------------------------------------------- 1 | python train.py --base configs/base/mimic_icustay_base.yaml --gpus 0, --uncond --logdir ts_diff_uncond_testing/mimic_icustay_base -sl 24 --batch_size 128 --max_steps 20000 -lr 0.0001 -s 42 2 | 3 | -------------------------------------------------------------------------------- /TarDiff/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. -------------------------------------------------------------------------------- /TimeDP/.env: -------------------------------------------------------------------------------- 1 | DATA_ROOT=/mnt/storage/ts_data/newer -------------------------------------------------------------------------------- /TimeDP/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | *.pyc 3 | *.pyo 4 | *.pyd 5 | *.pyw 6 | *.pyz 7 | 8 | logs/ -------------------------------------------------------------------------------- /TimeDP/configs/multi_domain_timedp.yaml: -------------------------------------------------------------------------------- 1 | seq_length: &seqlen 96 2 | model: 3 | base_learning_rate: 0.001 4 | target: ldm.models.diffusion.ddpm_time.LatentDiffusion 5 | params: 6 | linear_start: 0.0005 7 | linear_end: 0.1 8 | num_timesteps_cond: 1 9 | log_every_t: 40 10 | timesteps: 200 11 | loss_type: l1 12 | first_stage_key: "context" 13 | cond_stage_key: "context" 14 | seq_len: *seqlen 15 | channels: 1 16 | cond_stage_trainable: True 17 | concat_mode: False 18 | scale_by_std: False # True 19 | monitor: 'val/loss_simple_ema' 20 | conditioning_key: crossattn 21 | cond_drop_prob: 0.5 22 | 23 | scheduler_config: # 10000 warmup steps 24 | target: ldm.lr_scheduler.LambdaLinearScheduler 25 | params: 26 | warm_up_steps: [1000] 27 | cycle_lengths: [10000000000000] 28 | f_start: [1.e-6] 29 | f_max: [1.] 30 | f_min: [ 1.] 
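    # Note: with f_min = f_max = 1.0 this schedule is warm-up only; the lr
    # multiplier rises linearly from f_start to 1.0 over the first 1000 steps,
    # and LambdaLinearScheduler then holds it flat for the (effectively
    # unbounded) cycle length.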
31 | 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.ts_unet.UNetModel 34 | params: 35 | seq_len: *seqlen 36 | dims: 1 37 | in_channels: 1 38 | out_channels: 1 39 | model_channels: 64 40 | attention_resolutions: [ 1, 2, 4] 41 | num_res_blocks: 2 42 | channel_mult: [ 1,2,4,4 ] 43 | num_heads: 8 44 | use_scale_shift_norm: True 45 | resblock_updown: True 46 | context_dim: 32 47 | repre_emb_channels: 32 48 | latent_unit: 1 49 | use_spatial_transformer: true 50 | use_pam: true 51 | 52 | first_stage_config: # no first stage model for ts data 53 | target: ldm.models.autoencoder.IdentityFirstStage 54 | 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.DomainUnifiedPrototyper 57 | params: 58 | dim: 32 59 | window: *seqlen 60 | latent_dim: 32 # 32 * 3 61 | num_latents: 16 62 | num_channels: 1 63 | 64 | data: 65 | target: ldm.data.tsg_dataset.TSGDataModule 66 | params: 67 | data_path_dict: 68 | solar: "{DATA_ROOT}/solar_{SEQ_LEN}_train.npy" 69 | electricity: "{DATA_ROOT}/electricity_{SEQ_LEN}_train.npy" 70 | traffic: "{DATA_ROOT}/traffic_{SEQ_LEN}_train.npy" 71 | kddcup: "{DATA_ROOT}/kddcup_{SEQ_LEN}_train.npy" 72 | taxi: "{DATA_ROOT}/taxi_{SEQ_LEN}_train.npy" 73 | exchange: "{DATA_ROOT}/exchange_{SEQ_LEN}_train.npy" 74 | fred_md: "{DATA_ROOT}/fred_md_{SEQ_LEN}_train.npy" 75 | nn5: "{DATA_ROOT}/nn5_{SEQ_LEN}_train.npy" 76 | temp: "{DATA_ROOT}/temp_{SEQ_LEN}_train.npy" 77 | rain: "{DATA_ROOT}/rain_{SEQ_LEN}_train.npy" 78 | pedestrian: "{DATA_ROOT}/pedestrian_{SEQ_LEN}_train.npy" 79 | wind_4_seconds: "{DATA_ROOT}/wind_4_seconds_{SEQ_LEN}_train.npy" 80 | window: *seqlen 81 | val_portion: 0.1 82 | batch_size: 256 83 | num_workers: 8 84 | normalize: centered_pit 85 | drop_last: True 86 | reweight: True 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: utils.callback_utils.TSLogger 92 | params: 93 | batch_frequency: 2000 94 | max_images: 8 95 | increase_log_steps: false 96 | 97 | 98 | trainer: 99 | benchmark: True 100 | max_steps: 50000 -------------------------------------------------------------------------------- /TimeDP/environment.yml: -------------------------------------------------------------------------------- 1 | name: timedp 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - python=3.8 8 | - pip=20.3 9 | - cudatoolkit=11.7 10 | - torchvision==0.14.0 11 | - numpy=1.19.2 12 | - scikit-learn 13 | - h5py 14 | - ca-certificates 15 | - openssl 16 | - torchaudio==0.13.0 17 | - pytorch==1.13.0 18 | - certifi 19 | - pytorch-cuda=11.7 20 | - packaging=21.3 21 | - setuptools=69.5.1 22 | - statsmodels 23 | - jupyter 24 | - matplotlib 25 | - wandb 26 | - seaborn 27 | - einops 28 | - tqdm 29 | - scipy 30 | - pandas 31 | - omegaconf 32 | - mkl=2023 33 | - pytorch-lightning=1.4.2 34 | - torchmetrics=0.7.3 35 | prefix: /opt/conda/envs/tsgen 36 | -------------------------------------------------------------------------------- /TimeDP/figure/TimeDP_Overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/TimeDP/figure/TimeDP_Overview.jpg -------------------------------------------------------------------------------- /TimeDP/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
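The data_path_dict entries in multi_domain_timedp.yaml above are templates; a sketch (illustrative only, the repo's actual loader code is not shown here) of how such "{DATA_ROOT}"/"{SEQ_LEN}" placeholders can be expanded with str.format before loading:

import os
import numpy as np

path_template = "{DATA_ROOT}/solar_{SEQ_LEN}_train.npy"
path = path_template.format(DATA_ROOT=os.environ["DATA_ROOT"], SEQ_LEN=96)
# with TimeDP/.env above: /mnt/storage/ts_data/newer/solar_96_train.npy
series = np.load(path)  # array of training windows of length 96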
3 | 4 | import numpy as np 5 | 6 | 7 | class LambdaWarmUpCosineScheduler: 8 | """ 9 | note: use with a base_lr of 1.0 10 | """ 11 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 12 | self.lr_warm_up_steps = warm_up_steps 13 | self.lr_start = lr_start 14 | self.lr_min = lr_min 15 | self.lr_max = lr_max 16 | self.lr_max_decay_steps = max_decay_steps 17 | self.last_lr = 0. 18 | self.verbosity_interval = verbosity_interval 19 | 20 | def schedule(self, n, **kwargs): 21 | if self.verbosity_interval > 0: 22 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 23 | if n < self.lr_warm_up_steps: 24 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 25 | self.last_lr = lr 26 | return lr 27 | else: 28 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 29 | t = min(t, 1.0) 30 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 31 | 1 + np.cos(t * np.pi)) 32 | self.last_lr = lr 33 | return lr 34 | 35 | def __call__(self, n, **kwargs): 36 | return self.schedule(n,**kwargs) 37 | 38 | 39 | class LambdaWarmUpCosineScheduler2: 40 | """ 41 | supports repeated iterations, configurable via lists 42 | note: use with a base_lr of 1.0. 43 | """ 44 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 45 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 46 | self.lr_warm_up_steps = warm_up_steps 47 | self.f_start = f_start 48 | self.f_min = f_min 49 | self.f_max = f_max 50 | self.cycle_lengths = cycle_lengths 51 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 52 | self.last_f = 0. 53 | self.verbosity_interval = verbosity_interval 54 | 55 | def find_in_interval(self, n): 56 | interval = 0 57 | for cl in self.cum_cycles[1:]: 58 | if n <= cl: 59 | return interval 60 | interval += 1 61 | 62 | def schedule(self, n, **kwargs): 63 | cycle = self.find_in_interval(n) 64 | n = n - self.cum_cycles[cycle] 65 | if self.verbosity_interval > 0: 66 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 67 | f"current cycle {cycle}") 68 | if n < self.lr_warm_up_steps[cycle]: 69 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 70 | self.last_f = f 71 | return f 72 | else: 73 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 74 | t = min(t, 1.0) 75 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 76 | 1 + np.cos(t * np.pi)) 77 | self.last_f = f 78 | return f 79 | 80 | def __call__(self, n, **kwargs): 81 | return self.schedule(n, **kwargs) 82 | 83 | 84 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 85 | 86 | def schedule(self, n, **kwargs): 87 | cycle = self.find_in_interval(n) 88 | n = n - self.cum_cycles[cycle] 89 | if self.verbosity_interval > 0: 90 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 91 | f"current cycle {cycle}") 92 | 93 | if n < self.lr_warm_up_steps[cycle]: 94 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 95 | self.last_f = f 96 | return f 97 | else: 98 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 99 | self.last_f = f 100 | return f 101 | 102 | 
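A usage sketch (assumed wiring, consistent with the "use with a base_lr of 1.0" notes above) that drives a torch optimizer with LambdaLinearScheduler, reusing the values from the scheduler_config shown earlier:

import torch

sched_fn = LambdaLinearScheduler(warm_up_steps=[1000], f_min=[1.0], f_max=[1.0],
                                 f_start=[1.e-6], cycle_lengths=[10000000000000])
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)
for step in range(2000):
    optimizer.step()
    scheduler.step()  # effective lr = 1e-3 * sched_fn(step): linear warm-up, then flat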
-------------------------------------------------------------------------------- /TimeDP/ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | 6 | class IdentityFirstStage(torch.nn.Module): 7 | def __init__(self, *args, vq_interface=False, **kwargs): 8 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 9 | super().__init__() 10 | 11 | def encode(self, x, *args, **kwargs): 12 | return x 13 | 14 | def decode(self, x, *args, **kwargs): 15 | return x 16 | 17 | def quantize(self, x, *args, **kwargs): 18 | return x 19 | 20 | def forward(self, x, *args, **kwargs): 21 | return x 22 | -------------------------------------------------------------------------------- /TimeDP/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class AbstractDistribution: 9 | def sample(self): 10 | raise NotImplementedError() 11 | 12 | def mode(self): 13 | raise NotImplementedError() 14 | 15 | 16 | class DiracDistribution(AbstractDistribution): 17 | def __init__(self, value): 18 | self.value = value 19 | 20 | def sample(self): 21 | return self.value 22 | 23 | def mode(self): 24 | return self.value 25 | 26 | 27 | class DiagonalGaussianDistribution(object): 28 | def __init__(self, parameters, deterministic=False): 29 | self.parameters = parameters 30 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 31 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 32 | self.deterministic = deterministic 33 | self.std = torch.exp(0.5 * self.logvar) 34 | self.var = torch.exp(self.logvar) 35 | if self.deterministic: 36 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 37 | 38 | def sample(self): 39 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def kl_splits(self, latent_unit=6): 57 | mean_splits = self.mean.chunk(latent_unit, dim=-1) 58 | var_splits = self.var.chunk(latent_unit, dim=-1) 59 | logvar_splits = self.logvar.chunk(latent_unit, dim=-1) 60 | kl_loss = 0 61 | for mean, var, logvar in zip(mean_splits, var_splits, logvar_splits): 62 | kl_split = 0.5 * torch.sum(torch.pow(mean, 2) 63 | + var - 1.0 - logvar, 64 | dim=-1) 65 | kl_loss += torch.sum(kl_split) / kl_split.shape[0] 66 | return kl_loss/latent_unit 67 | 68 | def nll(self, sample, dims=[1,2,3]): 69 | if self.deterministic: 70 | return torch.Tensor([0.]) 71 | logtwopi = np.log(2.0 * np.pi) 72 | return 0.5 * torch.sum( 73 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 74 | dim=dims) 75 | 76 | def mode(self): 77 | return self.mean 78 | 79 | 80 | def normal_kl(mean1, logvar1, mean2, logvar2): 81 | """ 82 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 83 | Compute the KL divergence between two gaussians. 84 | Shapes are automatically broadcasted, so batches can be compared to 85 | scalars, among other use cases. 86 | """ 87 | tensor = None 88 | for obj in (mean1, logvar1, mean2, logvar2): 89 | if isinstance(obj, torch.Tensor): 90 | tensor = obj 91 | break 92 | assert tensor is not None, "at least one argument must be a Tensor" 93 | 94 | # Force variances to be Tensors. Broadcasting helps convert scalars to 95 | # Tensors, but it does not work for torch.exp(). 96 | logvar1, logvar2 = [ 97 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 98 | for x in (logvar1, logvar2) 99 | ] 100 | 101 | return 0.5 * ( 102 | -1.0 103 | + logvar2 104 | - logvar1 105 | + torch.exp(logvar1 - logvar2) 106 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 107 | ) 108 | -------------------------------------------------------------------------------- /TimeDP/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class LitEma(nn.Module): 8 | def __init__(self, model, decay=0.9999, use_num_upates=True): 9 | super().__init__() 10 | if decay < 0.0 or decay > 1.0: 11 | raise ValueError('Decay must be between 0 and 1') 12 | 13 | self.m_name2s_name = {} 14 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 15 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 16 | else torch.tensor(-1,dtype=torch.int)) 17 | 18 | for name, p in model.named_parameters(): 19 | if p.requires_grad: 20 | #remove as '.'-character is not allowed in buffers 21 | s_name = name.replace('.','') 22 | self.m_name2s_name.update({name:s_name}) 23 | self.register_buffer(s_name,p.clone().detach().data) 24 | 25 | self.collected_params = [] 26 | 27 | def forward(self,model): 28 | decay = self.decay 29 | 30 | if self.num_updates >= 0: 31 | self.num_updates += 1 32 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 33 | 34 | one_minus_decay = 1.0 - decay 35 | 36 | with torch.no_grad(): 37 | m_param = dict(model.named_parameters()) 38 | shadow_params = dict(self.named_buffers()) 39 | 40 | for key in m_param: 41 | if m_param[key].requires_grad: 42 | sname = self.m_name2s_name[key] 43 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 44 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 45 | else: 46 | assert not key in self.m_name2s_name 47 | 48 | def copy_to(self, model): 49 | m_param = dict(model.named_parameters()) 50 | shadow_params = dict(self.named_buffers()) 51 | for key in m_param: 52 | if m_param[key].requires_grad: 53 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 54 | else: 55 | assert not key in self.m_name2s_name 56 | 57 | def store(self, parameters): 58 | """ 59 | Save the current parameters for restoring later. 60 | Args: 61 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 62 | temporarily stored. 63 | """ 64 | self.collected_params = [param.clone() for param in parameters] 65 | 66 | def restore(self, parameters): 67 | """ 68 | Restore the parameters stored with the `store` method. 69 | Useful to validate the model with EMA parameters without affecting the 70 | original optimization process. 
Store the parameters before the 71 | `copy_to` method. After validation (or model saving), use this to 72 | restore the former parameters. 73 | Args: 74 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 75 | updated with the stored parameters. 76 | """ 77 | for c_param, param in zip(self.collected_params, parameters): 78 | param.data.copy_(c_param.data) 79 | -------------------------------------------------------------------------------- /TimeDP/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import repeat 7 | import copy 8 | 9 | # helpers 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def default(val, d): 15 | return val if exists(val) else d 16 | 17 | class ResBlockTime(nn.Module): 18 | def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): 19 | super(ResBlockTime, self).__init__() 20 | 21 | if mid_channels is None: 22 | mid_channels = out_channels 23 | 24 | layers = [ 25 | nn.ReLU(), 26 | nn.Conv1d(in_channels, mid_channels, 27 | kernel_size=3, stride=1, padding=1), 28 | nn.ReLU(), 29 | nn.Conv1d(mid_channels, out_channels, 30 | kernel_size=1, stride=1, padding=0) 31 | ] 32 | if bn: 33 | layers.insert(2, nn.BatchNorm1d(out_channels)) 34 | self.convs = nn.Sequential(*layers) 35 | 36 | def forward(self, x): 37 | return x + self.convs(x) 38 | 39 | class View(nn.Module): 40 | def __init__(self, size): 41 | super(View, self).__init__() 42 | self.size = size 43 | 44 | def forward(self, tensor): 45 | return tensor.view(self.size) 46 | 47 | class DomainUnifiedEncoder(nn.Module): 48 | ''' 49 | The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. 50 | The length of the two part are equal in this implementation. 51 | ''' 52 | def __init__(self, dim, window, num_channels=3, latent_dim=32, bn=True, **kwargs): 53 | super().__init__() 54 | dim_out = latent_dim 55 | flatten_dim = int(dim * window / 4) 56 | self.in_encoder = nn.Sequential( 57 | nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), 58 | nn.BatchNorm1d(dim), 59 | nn.ReLU(inplace=True), 60 | nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), 61 | nn.BatchNorm1d(dim), 62 | nn.ReLU(inplace=True) 63 | ) 64 | 65 | self.out_encoder = nn.Sequential( 66 | ResBlockTime(dim, dim, bn=bn), 67 | nn.BatchNorm1d(dim), 68 | nn.ReLU(inplace=True), 69 | ResBlockTime(dim, dim, bn=bn), 70 | View((-1, flatten_dim)), # batch_size x 2048 71 | nn.Linear(flatten_dim, dim_out) 72 | ) 73 | 74 | def forward(self, x): 75 | h = self.in_encoder(x) 76 | mask = None 77 | 78 | out = self.out_encoder(h)[:,None] # b, 1, d 79 | return out, mask 80 | 81 | class DomainUnifiedPrototyper(nn.Module): 82 | ''' 83 | The input are encoded into two parts, invariant part and specific part. The specific part is generated attending to a random initialized latent vector pool. 84 | The length of the two part are equal in this implementation. 
85 | ''' 86 | def __init__(self, dim, window, num_latents=16, num_channels=3, latent_dim=32, bn=True, **kwargs): 87 | super().__init__() 88 | self.num_latents = num_latents 89 | self.latent_dim = latent_dim 90 | flatten_dim = int(dim * window / 4) 91 | self.share_encoder = nn.Sequential( 92 | nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), 93 | nn.BatchNorm1d(dim), 94 | nn.ReLU(inplace=True), 95 | nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), 96 | nn.BatchNorm1d(dim), 97 | nn.ReLU(inplace=True) 98 | ) 99 | self.latents = nn.Parameter(torch.empty(num_latents, self.latent_dim), requires_grad=False) 100 | nn.init.orthogonal_(self.latents) 101 | self.init_latents = copy.deepcopy(self.latents.detach()) 102 | self.mask_ffn = nn.Sequential( 103 | ResBlockTime(dim, dim, bn=bn), 104 | View((-1, flatten_dim)), # batch_size x 2048 105 | nn.Linear(flatten_dim, self.num_latents), 106 | ) 107 | self.sigmoid = nn.Sigmoid() 108 | 109 | def forward(self, x): 110 | b = x.shape[0] 111 | h = self.share_encoder(x) 112 | mask = None 113 | 114 | latents = repeat(self.latents, 'n d -> b n d', b = b) 115 | mask_logit = self.mask_ffn(h) 116 | mask = mask_logit # soft assign 117 | 118 | out = latents # mask 119 | return out, mask 120 | 121 | -------------------------------------------------------------------------------- /TimeDP/ldm/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import importlib 5 | 6 | import torch 7 | 8 | from inspect import isfunction 9 | 10 | 11 | def ismap(x): 12 | if not isinstance(x, torch.Tensor): 13 | return False 14 | return (len(x.shape) == 4) and (x.shape[1] > 3) 15 | 16 | 17 | def isimage(x): 18 | if not isinstance(x, torch.Tensor): 19 | return False 20 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 21 | 22 | 23 | def exists(x): 24 | return x is not None 25 | 26 | 27 | def default(val, d): 28 | if exists(val): 29 | return val 30 | return d() if isfunction(d) else d 31 | 32 | 33 | def mean_flat(tensor): 34 | """ 35 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 36 | Take the mean over all non-batch dimensions. 37 | """ 38 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 39 | 40 | 41 | def count_params(model, verbose=False): 42 | total_params = sum(p.numel() for p in model.parameters()) 43 | if verbose: 44 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 45 | return total_params 46 | 47 | 48 | def instantiate_from_config(config): 49 | if not "target" in config: 50 | if config == '__is_first_stage__': 51 | return None 52 | elif config == "__is_unconditional__": 53 | return None 54 | raise KeyError("Expected key `target` to instantiate.") 55 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 56 | 57 | 58 | def get_obj_from_str(string, reload=False): 59 | module, cls = string.rsplit(".", 1) 60 | if reload: 61 | module_imp = importlib.import_module(module) 62 | importlib.reload(module_imp) 63 | return getattr(importlib.import_module(module, package=None), cls) 64 | 65 | -------------------------------------------------------------------------------- /TimeDP/main_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
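A short sketch (shapes assumed from the multi_domain_timedp.yaml config above) combining instantiate_from_config from ldm/util.py with the DomainUnifiedPrototyper: the dotted "target" string is resolved to a class, which is constructed with "params":

import torch

config = {
    "target": "ldm.modules.encoders.modules.DomainUnifiedPrototyper",
    "params": {"dim": 32, "window": 96, "latent_dim": 32,
               "num_latents": 16, "num_channels": 1},
}
prototyper = instantiate_from_config(config)
x = torch.randn(4, 1, 96)      # (batch, channels, window)
latents, mask = prototyper(x)  # latents: (4, 16, 32) prototype pool; mask: (4, 16) soft assignment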
3 | 4 | 5 | import os, sys 6 | from pytorch_lightning.trainer import Trainer 7 | from utils.cli_utils import get_parser 8 | from utils.init_utils import init_model_data_trainer 9 | from utils.test_utils import test_model_with_dp, test_model_uncond, test_model_unseen 10 | 11 | 12 | if __name__ == "__main__": 13 | 14 | data_root = os.environ['DATA_ROOT'] 15 | 16 | parser = get_parser() 17 | parser = Trainer.add_argparse_args(parser) 18 | 19 | model, data, trainer, opt, logdir, melk = init_model_data_trainer(parser) 20 | 21 | # run 22 | if opt.train: 23 | try: 24 | trainer.logger.experiment.config.update(opt) 25 | trainer.fit(model, data) 26 | except Exception: 27 | melk() 28 | raise 29 | if not opt.no_test and not trainer.interrupted: 30 | if opt.uncond: 31 | test_model_uncond(model, data, trainer, opt, logdir) 32 | else: 33 | test_model_with_dp(model, data, trainer, opt, logdir) 34 | test_model_unseen(model, data, trainer, opt, logdir) 35 | 36 | -------------------------------------------------------------------------------- /TimeDP/metrics/metrics_sets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import numpy as np 6 | from utils.data_utils import test_data_loading 7 | from metrics.feature_distance_eval import get_mdd_eval, mmd_metric, get_flat_distance 8 | 9 | 10 | data_root = os.environ['DATA_ROOT'] 11 | 12 | 13 | def calculate_one(gen_data, scaled_ori, model_name, repeat, data_name, seq_len, uni_data_sub, uni_data_div, n_samples): 14 | this_metrics = {} 15 | print(model_name, gen_data.shape) 16 | scaled_gen = (gen_data - uni_data_sub) / uni_data_div 17 | scaled_gen = scaled_gen[:n_samples, :, None] 18 | this_metrics = update_metrics_dict(this_metrics, model_name, data_name, seq_len, scaled_ori, scaled_gen, repeat_id=repeat) 19 | return this_metrics 20 | 21 | def update_metrics_dict(the_dict, key, data_name, seq_len, ori_data, gen_data, repeat_id=0): 22 | if (key, data_name, seq_len, repeat_id) in the_dict: 23 | print(f'{key} {data_name} {seq_len} {repeat_id} already in the dict, skip!') 24 | return the_dict 25 | 26 | mdd = get_mdd_eval(ori_data, gen_data) 27 | the_dict[(key, data_name, seq_len, repeat_id)] = { 28 | 'mdd': mdd, 29 | } 30 | flat_sk_result = get_flat_distance(ori_data, gen_data) 31 | the_dict[(key, data_name, seq_len, repeat_id)].update(flat_sk_result) 32 | the_dict[(key, data_name, seq_len, repeat_id)].update(mmd_metric(ori_data, gen_data)) 33 | return the_dict 34 | 35 | def run_metrics(data_name, seq_len, model_name, gen_data, scale='zscore', exist_dict={}, repeat_id=0): 36 | extend_metrics = exist_dict 37 | 38 | uni_ori_data = test_data_loading(data_name, seq_len, stride=seq_len, univar=True) 39 | uni_data_min, uni_data_max = np.min(uni_ori_data), np.max(uni_ori_data) 40 | uni_data_mean, uni_data_std = np.mean(uni_ori_data), np.std(uni_ori_data) 41 | if scale == 'minmax': 42 | uni_data_sub, uni_data_div = uni_data_min, uni_data_max - uni_data_min + 1e-7 43 | elif scale == 'zscore': 44 | uni_data_sub, uni_data_div = uni_data_mean, uni_data_std + 1e-7 45 | elif scale == 'raw': 46 | uni_data_sub, uni_data_div = 0, 1 47 | elif scale == 'robust_zscore': 48 | median = np.median(uni_ori_data) 49 | mad = np.median(np.abs(uni_ori_data - median)) 50 | uni_data_sub, uni_data_div = median, 1.4826 * mad + 1e-7 51 | uni_scaled_ori = (uni_ori_data - uni_data_sub) / uni_data_div 52 | print(data_name, 'univar', uni_scaled_ori.shape) 53 | 
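        # Note: the generated data is scaled below with statistics computed on the
        # real data, so the distance metrics compare both in the same normalized space.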
scaled_ori = uni_scaled_ori 54 | scaled_gen = (gen_data - uni_data_sub) / uni_data_div 55 | extend_metrics = update_metrics_dict(extend_metrics, model_name, data_name, seq_len, scaled_ori, scaled_gen, repeat_id=repeat_id) 56 | return extend_metrics 57 | -------------------------------------------------------------------------------- /TimeDP/train.sh: -------------------------------------------------------------------------------- 1 | python main_train.py --base configs/multi_domain_timedp.yaml --gpus 0, --logdir ./logs/dpdiff_12new -sl 168 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0 --debug 2 | -------------------------------------------------------------------------------- /TimeDP/utils/cli_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import argparse 5 | from pytorch_lightning import Trainer 6 | 7 | def get_parser(**parser_kwargs): 8 | def str2bool(v): 9 | if isinstance(v, bool): 10 | return v 11 | if v.lower() in ("yes", "true", "t", "y", "1"): 12 | return True 13 | elif v.lower() in ("no", "false", "f", "n", "0"): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError("Boolean value expected.") 17 | 18 | parser = argparse.ArgumentParser(**parser_kwargs) 19 | parser.add_argument("-n","--name",type=str,const=True,default="",nargs="?",help="postfix for logdir") 20 | parser.add_argument("-b","--base",nargs="*",metavar="base_config.yaml",help="paths to base configs. Loaded from left-to-right.", default=list(),) 21 | parser.add_argument("-t","--train",type=str2bool,const=True,default=True,nargs="?",help="train",) 22 | parser.add_argument("-r","--resume",type=str2bool,const=True,default=False,nargs="?",help="resume and test",) 23 | parser.add_argument("--no-test",type=str2bool,const=True,default=False,nargs="?",help="disable test",) 24 | parser.add_argument("-d","--debug",type=str2bool,nargs="?",const=True,default=False,help="debug mode",) 25 | parser.add_argument("-s","--seed",type=int,default=23,help="seed for seed_everything",) 26 | parser.add_argument("-f","--postfix",type=str,default="",help="post-postfix for default name",) 27 | parser.add_argument("-l","--logdir",type=str,default="./logs",help="directory for logs",) 28 | parser.add_argument("--scale_lr",type=str2bool,nargs="?",const=True,default=False,help="scale base-lr by ngpu * batch_size * n_accumulate",) 29 | parser.add_argument("--ckpt_name",type=str,default="last",help="ckpt name to resume",) 30 | parser.add_argument("-sl","--seq_len", type=int, const=True, default=24,nargs="?", help="sequence length") 31 | parser.add_argument("-uc","--uncond", action='store_true', help="unconditional generation") 32 | parser.add_argument("-up","--use_pam", action='store_true', help="use prototype") 33 | parser.add_argument("-bs","--batch_size", type=int, const=True, default=128,nargs="?", help="batch_size") 34 | parser.add_argument("-nl","--num_latents", type=int, const=True, default=16,nargs="?", help="number of prototypes") 35 | parser.add_argument("-lr","--overwrite_learning_rate", type=float, const=True, default=None, nargs="?", help="learning rate") 36 | 37 | return parser 38 | 39 | def nondefault_trainer_args(opt): 40 | parser = argparse.ArgumentParser() 41 | parser = Trainer.add_argparse_args(parser) 42 | args = parser.parse_args([]) 43 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 44 |
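A quick sanity-check sketch (hypothetical, not repo code) for the parser defined above, fed the flags used by train.sh; parse_known_args is used because the Trainer flags are added separately:

parser = get_parser()
opt, unknown = parser.parse_known_args(
    "--base configs/multi_domain_timedp.yaml -sl 168 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0".split())
print(opt.seq_len, opt.use_pam, opt.num_latents, opt.overwrite_learning_rate)  # 168 True 16 0.0001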
-------------------------------------------------------------------------------- /TimeDP/utils/pkl_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import pickle 5 | from pathlib import Path 6 | 7 | 8 | def load_pkl(path: Path): 9 | """Load pkl from path.""" 10 | with open(path, "rb") as infile: 11 | data = pickle.load(infile) 12 | return data 13 | 14 | 15 | def save_pkl(data: object, path: Path): 16 | """Save pkl to path.""" 17 | with open(path, "wb") as outfile: 18 | pickle.dump(data, outfile) 19 | -------------------------------------------------------------------------------- /TimeDP/visualize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import numpy as np 6 | import torch 7 | from pytorch_lightning.trainer import Trainer 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | from sklearn.manifold import TSNE 11 | from ldm.data.tsg_dataset import TSGDataset 12 | from utils.init_utils import load_model_data 13 | from utils.cli_utils import get_parser 14 | 15 | data_root = os.environ['DATA_ROOT'] 16 | 17 | mix_dataset = [ 18 | 'electricity', 'solar', 'wind_4_seconds', 'traffic', 'taxi', 'pedestrian', 19 | 'kddcup', 'temp', 'rain', 'nn5', 'fred_md', 'exchange' 20 | ] 21 | 22 | dataset_name_map = { 23 | 'electricity': 'Electricity', 24 | 'solar': 'Solar', 25 | 'wind_4_seconds': 'Wind', 26 | 'traffic': 'Traffic', 27 | 'taxi': 'Taxi', 28 | 'pedestrian': 'Pedestrian', 29 | 'kddcup': 'Air Quality', 30 | 'temp': 'Temperature', 31 | 'rain': 'Rain', 32 | 'nn5': 'NN5', 33 | 'fred_md': 'Fred-MD', 34 | 'exchange': 'Exchange' 35 | } 36 | dataset_color_map = { 37 | 'electricity': 'tab:blue', 38 | 'solar': 'tab:blue', 39 | 'wind_4_seconds': 'tab:blue', 40 | 'traffic': 'tab:green', 41 | 'taxi': 'tab:green', 42 | 'pedestrian': 'tab:green', 43 | 'kddcup': 'tab:orange', 44 | 'temp': 'tab:orange', 45 | 'rain': 'tab:orange', 46 | 'nn5': 'tab:purple', 47 | 'fred_md': 'tab:purple', 48 | 'exchange': 'tab:purple' 49 | } 50 | dataset_domain_map = { 51 | 'electricity': 'Energy', 52 | 'solar': 'Energy', 53 | 'wind_4_seconds': 'Energy', 54 | 'traffic': 'Transport', 55 | 'taxi': 'Transport', 56 | 'pedestrian': 'Transport', 57 | 'kddcup': 'Nature', 58 | 'temp': 'Nature', 59 | 'rain': 'Nature', 60 | 'nn5': 'Econ', 61 | 'fred_md': 'Econ', 62 | 'exchange': 'Econ' 63 | } 64 | 65 | def draw_domain_tsne_on_ax(ax1: plt.Axes, ax2: plt.Axes, all_repeat, num_dp=100): 66 | all_data = [x.cpu().numpy() for x in all_repeat] 67 | concat_data = np.concatenate(all_data, axis=0) 68 | 69 | # t-SNE analysis 70 | tsne = TSNE(n_components=2, perplexity=40, n_iter=500) 71 | transformed_data = tsne.fit_transform(concat_data) 72 | for i, data_name in enumerate(mix_dataset): 73 | tsne_results = transformed_data[i*num_dp:(i+1)*num_dp] 74 | ax1.scatter(tsne_results[:, 0], tsne_results[:, 1], alpha=0.4, label=dataset_name_map[data_name]) 75 | ax2.scatter(tsne_results[:, 0], tsne_results[:, 1], alpha=0.4, color=dataset_color_map[data_name], label=dataset_domain_map[data_name]) 76 | ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5)) 77 | handles, labels = ax2.get_legend_handles_labels() 78 | by_label = dict(zip(labels, handles)) 79 | ax2.legend(by_label.values(), by_label.keys()) 80 | 81 | 82 | #%% 83 | if __name__ == "__main__": 84 | parser = get_parser() 85 |
parser = Trainer.add_argparse_args(parser) 86 | model, data, opt, logdir = load_model_data(parser) 87 | 88 | seq_len = opt.seq_len 89 | seed = opt.seed 90 | 91 | model = model.to('cuda') 92 | model.eval() 93 | nu = opt.num_latents 94 | 95 | num_dp = 100 96 | 97 | all_mask = [] 98 | with torch.no_grad(): 99 | for i, dataset in enumerate(mix_dataset): 100 | dataset_data = TSGDataset({dataset: data.norm_train_dict[dataset]}) 101 | 102 | dataset_samples = [] 103 | for idx in np.random.randint(dataset_data.__len__(), size=num_dp): 104 | dataset_samples.append(dataset_data.__getitem__(idx)['context']) 105 | dataset_samples = np.vstack(dataset_samples) 106 | 107 | x = torch.tensor(dataset_samples).to('cuda').float().unsqueeze(1)[:num_dp] 108 | z = model.get_first_stage_encoding(model.encode_first_stage(x)) 109 | c, mask = model.get_learned_conditioning(x, return_mask=True) 110 | all_mask.append(mask) 111 | 112 | # t-SNE visualization 113 | sns.set_palette("Paired") 114 | fig1, ax1 = plt.subplots(1, 1, figsize=(7, 4), dpi=200) 115 | fig2, ax2 = plt.subplots(1, 1, figsize=(6, 4), dpi=200) 116 | 117 | draw_domain_tsne_on_ax(ax1, ax2, all_mask, num_dp=num_dp) 118 | fig1.tight_layout() 119 | fig1.savefig(logdir / "dataset_tsne_plot.pdf") 120 | fig2.tight_layout() 121 | fig2.savefig(logdir / "domain_tsne_plot.pdf") 122 | 123 | fig, axs = plt.subplots(2, 6, figsize=(18, 6), dpi=200) 124 | for i, dataset_name in enumerate(mix_dataset): 125 | sns.heatmap(all_mask[i].cpu().numpy(), cmap='coolwarm',center=0,ax=axs[i//6, i%6]) 126 | axs[i//6, i%6].set_title(dataset_name_map[dataset_name]) 127 | axs[i//6, i%6].set_yticks([]) 128 | plt.tight_layout() 129 | plt.savefig(logdir / "pam_heatmap.pdf") 130 | 131 | fig, axs = plt.subplots(4, 4, figsize=(8, 8), dpi=200) 132 | latents = model.cond_stage_model.latents.detach().repeat(10, 1, 1) 133 | for i_proto in range(16): 134 | mask = torch.zeros(latents.shape[0], latents.shape[1]).to('cuda') - 1 # 16 dims 135 | mask[:, i_proto] = 1 136 | with torch.no_grad(): 137 | samples, z_denoise_row = model.sample_log(cond=latents, batch_size=latents.shape[0], ddim=False, cfg_scale=1, mask=mask) 138 | draw_samples = samples.detach().cpu().squeeze(1).numpy().mean(0) 139 | axs[i_proto // 4, i_proto % 4].plot(draw_samples) 140 | axs[i_proto // 4, i_proto % 4].set_title(f'Prototype No.{i_proto}') 141 | 142 | plt.tight_layout() 143 | plt.subplots_adjust(top=0.9) 144 | plt.savefig(logdir / "prototype_semantic.pdf") 145 | plt.show() 146 | 147 | # %% 148 | -------------------------------------------------------------------------------- /diffusion/classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/diffusion/classifier/__init__.py -------------------------------------------------------------------------------- /diffusion/classifier/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
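For the RNNClassifier defined just below, a usage sketch (hypothetical shapes; input_dim=7 mirrors the 7-channel MIMIC config later in this directory):

import torch

clf = RNNClassifier(input_dim=7, hidden_dim=256, rnn_type="gru", num_classes=1)
x = torch.randn(32, 24, 7)     # (batch, time steps, features)
logits = clf(x)                # num_classes=1, so the output is squeezed to shape (32,)
probs = torch.sigmoid(logits)  # per-sequence probability, e.g. for diffusion guidance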
3 | 4 | from __future__ import annotations 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class RNNClassifier(nn.Module): 10 | """Bidirectional LSTM/GRU classifier for fixed‑length sequences.""" 11 | 12 | def __init__( 13 | self, 14 | input_dim: int, 15 | hidden_dim: int = 128, 16 | num_layers: int = 2, 17 | rnn_type: str = "lstm", 18 | num_classes: int = 2, 19 | dropout: float = 0.2, 20 | ) -> None: 21 | super().__init__() 22 | rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU}[rnn_type.lower()] 23 | self.rnn = rnn_cls( 24 | input_size=input_dim, 25 | hidden_size=hidden_dim, 26 | num_layers=num_layers, 27 | batch_first=True, 28 | bidirectional=True, 29 | dropout=dropout if num_layers > 1 else 0.0, 30 | ) 31 | self.fc = nn.Linear(hidden_dim * 2, num_classes) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, T, F) 34 | rnn_out, _ = self.rnn(x) # (B, T, 2*H) 35 | last_hidden = rnn_out[:, -1, :] # final time‑step representation 36 | logits = self.fc(last_hidden) # (B, C) or (B, 1) 37 | return logits.squeeze(-1) # binary → (B,) ; multi‑class stays (B, C) 38 | -------------------------------------------------------------------------------- /diffusion/classifier/train.sh: -------------------------------------------------------------------------------- 1 | # prepare guidance model train_tuple.pkl : (data, label) 2 | 3 | python classifier/classifier_train.py --num_classes 1 --rnn_type gru --hidden_dim 256 --train_data /data/train_tuple.pkl --val_data /data/val_tuple.pkl 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /diffusion/configs/mimic_icustay_base.yaml: -------------------------------------------------------------------------------- 1 | seq_length: &seqlen 24 2 | model: 3 | base_learning_rate: 5.e-5 # set to target_lr by starting main.py with '--scale_lr False' 4 | target: ldm.models.diffusion.ddpm_time.LatentDiffusion 5 | params: 6 | linear_start: 0.0005 7 | linear_end: 0.1 8 | num_timesteps_cond: 1 9 | log_every_t: 40 10 | timesteps: 1000 11 | loss_type: l1 12 | first_stage_key: "context" 13 | cond_stage_key: "context" 14 | seq_len: *seqlen 15 | channels: 7 16 | cond_stage_trainable: True 17 | concat_mode: False 18 | scale_by_std: False # True 19 | monitor: 'val/loss_simple_ema' 20 | conditioning_key: crossattn 21 | cond_drop_prob: 0.5 22 | 23 | unet_config: 24 | target: ldm.modules.diffusionmodules.ts_unet.UNetModel 25 | params: 26 | seq_len: *seqlen 27 | dims: 1 28 | in_channels: 7 29 | out_channels: 7 30 | model_channels: 64 31 | attention_resolutions: [ 1, 2, 4] 32 | num_res_blocks: 2 33 | channel_mult: [ 1,2,4,4 ] 34 | num_heads: 8 35 | use_scale_shift_norm: True 36 | resblock_updown: True 37 | context_dim: 64 38 | repre_emb_channels: 32 39 | latent_unit: 1 40 | use_spatial_transformer: True 41 | num_classes: 2 42 | 43 | first_stage_config: # no first stage model for ts data 44 | target: ldm.models.autoencoder.IdentityFirstStage # VQModelInterface 45 | 46 | cond_stage_config: 47 | target: ldm.modules.encoders.modules.DomainUnifiedEncoder # SplitTSEqEncoder # SplitTSEqEncoder, SingleTSEncoder 48 | params: 49 | dim: 32 50 | window: *seqlen 51 | latent_dim: 32 # 32 * 3 52 | num_channels: 7 53 | use_prototype: False 54 | # use_cfg: True 55 | 56 | data: 57 | target: ldm.data.tsg_dataset.TSClassCondTrainDataModule 58 | params: 59 | data_path_dict: 60 | MIMIC_III_Readmission: icustay/train_tuple.pkl 61 | window: *seqlen 62 | val_portion: 0.1 63 | batch_size: 256 64 | num_workers: 8 65 | normalize: centered_pit 66 | drop_last: 
True 67 | reweight: False 68 | input_dim: 69 | lightning: 70 | callbacks: 71 | image_logger: 72 | target: utils.callback_utils.TSLogger 73 | params: 74 | # batch_frequency: 10 75 | batch_frequency: 5000 76 | max_images: 8 77 | increase_log_steps: false 78 | 79 | trainer: 80 | benchmark: True 81 | max_steps: 20000 -------------------------------------------------------------------------------- /diffusion/configs/text_control.yaml: -------------------------------------------------------------------------------- 1 | seq_length: &seqlen 168 2 | model: 3 | base_learning_rate: 5.e-5 # set to target_lr by starting main.py with '--scale_lr False' 4 | target: ldm.models.diffusion.ddpm_time.LatentDiffusion 5 | params: 6 | linear_start: 0.0005 7 | linear_end: 0.1 8 | num_timesteps_cond: 1 9 | log_every_t: 40 10 | timesteps: 1000 11 | loss_type: l1 12 | first_stage_key: "context" 13 | cond_stage_key: "context" 14 | seq_len: *seqlen 15 | channels: 1 16 | cond_stage_trainable: True 17 | concat_mode: False 18 | scale_by_std: False # True 19 | monitor: 'val/loss_simple_ema' 20 | conditioning_key: crossattn 21 | cond_drop_prob: 0.5 22 | 23 | 24 | scheduler_config: # 10000 warmup steps 25 | target: ldm.lr_scheduler.LambdaLinearScheduler 26 | params: 27 | warm_up_steps: [1000] 28 | cycle_lengths: [10000000000000] 29 | f_start: [1.e-6] 30 | f_max: [1.] 31 | f_min: [ 1.] 32 | 33 | unet_config: 34 | target: ldm.modules.diffusionmodules.ts_unet.UNetModel 35 | params: 36 | seq_len: *seqlen 37 | dims: 1 38 | in_channels: 1 39 | out_channels: 1 40 | model_channels: 64 41 | attention_resolutions: [ 1, 2, 4] 42 | num_res_blocks: 2 43 | channel_mult: [ 1,2,4,4 ] 44 | num_heads: 8 45 | use_scale_shift_norm: True 46 | resblock_updown: True 47 | context_dim: 64 48 | repre_emb_channels: 32 49 | latent_unit: 1 50 | use_spatial_transformer: True 51 | 52 | first_stage_config: # no first stage model for ts data 53 | target: ldm.models.autoencoder.IdentityFirstStage # VQModelInterface 54 | 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.DomainUnifiedPrototyper 57 | params: 58 | dim: 32 59 | window: *seqlen 60 | latent_dim: 32 # 32 * 3 61 | num_latents: 16 62 | num_channels: 1 63 | 64 | 65 | data: 66 | target: ldm.data.tsg_dataset.TSGtextDataModule 67 | params: 68 | data_path_dict: 69 | nn5: "{DATA_ROOT}/nn5_with_descriptions_168.csv" 70 | window: *seqlen 71 | val_portion: 0.1 72 | batch_size: 256 73 | num_workers: 8 74 | normalize: centered_pit 75 | drop_last: True 76 | reweight: True 77 | 78 | lightning: 79 | callbacks: 80 | image_logger: 81 | target: utils.callback_utils.TSLogger 82 | params: 83 | # batch_frequency: 10 84 | batch_frequency: 2000 85 | max_images: 8 86 | increase_log_steps: false 87 | 88 | 89 | 90 | trainer: 91 | benchmark: True 92 | max_steps: 20000 93 | -------------------------------------------------------------------------------- /diffusion/environment.yml: -------------------------------------------------------------------------------- 1 | name: timedp 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - python=3.8 8 | - pip=20.3 9 | - cudatoolkit=11.7 10 | - torchvision==0.14.0 11 | - numpy=1.19.2 12 | - scikit-learn 13 | - h5py 14 | - ca-certificates 15 | - openssl 16 | - torchaudio==0.13.0 17 | - pytorch==1.13.0 18 | - certifi 19 | - pytorch-cuda=11.7 20 | - packaging=21.3 21 | - setuptools=69.5.1 22 | - statsmodels 23 | - jupyter 24 | - matplotlib 25 | - wandb 26 | - seaborn 27 | - einops 28 | - tqdm 29 | - scipy 30 | - pandas 31 | - omegaconf 32 | - 
mkl=2023 33 | - pytorch-lightning=1.4.2 34 | - torchmetrics=0.7.3 35 | prefix: /opt/conda/envs/tsgen 36 | -------------------------------------------------------------------------------- /diffusion/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | from pytorch_lightning.trainer import Trainer 6 | from utils.cli_utils import get_parser 7 | from utils.init_utils import init_model_data_trainer 8 | from utils.test_utils import test_model_with_dp, test_model_uncond, test_model_unseen, test_model_guidance 9 | 10 | if __name__ == "__main__": 11 | data_root = os.environ.get('DATA_ROOT', './') 12 | 13 | parser = get_parser() 14 | parser = Trainer.add_argparse_args(parser) 15 | 16 | parser.set_defaults(train=False) 17 | parser.set_defaults(no_test=False) 18 | 19 | model, data, trainer, opt, logdir, melk = init_model_data_trainer(parser) 20 | 21 | if opt.ckpt_path is not None: 22 | print(f"Loading checkpoint from {opt.ckpt_path}") 23 | model.init_from_ckpt(opt.ckpt_path) 24 | elif trainer.callbacks[-1].best_model_path: 25 | best_ckpt_path = trainer.callbacks[-1].best_model_path 26 | print(f"Loading best model from {best_ckpt_path}") 27 | model.init_from_ckpt(best_ckpt_path) 28 | else: 29 | print("⚠️ No checkpoint path provided and no best_model_path found! Proceeding without loading weights...") 30 | 31 | model.cuda() 32 | model.eval() 33 | 34 | if opt.use_text: 35 | print("Inference with text input enabled.") 36 | else: 37 | print("Inference without text input.") 38 | 39 | if not opt.no_test: 40 | if opt.uncond and not opt.use_guidance: 41 | test_model_uncond(model, data, trainer, opt, logdir) 42 | elif opt.use_guidance: 43 | test_model_guidance(model, data, trainer, opt, logdir) 44 | 45 | else: 46 | test_model_with_dp(model, data, trainer, opt, logdir, use_pam=opt.use_pam, use_text=opt.use_text, text_emb_dir=opt.text_emb_dir) 47 | test_model_unseen(model, data, trainer, opt, logdir,use_text=opt.use_text, text_emb_dir=opt.text_emb_dir) 48 | -------------------------------------------------------------------------------- /diffusion/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import numpy as np 5 | 6 | 7 | class LambdaWarmUpCosineScheduler: 8 | """ 9 | note: use with a base_lr of 1.0 10 | """ 11 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 12 | self.lr_warm_up_steps = warm_up_steps 13 | self.lr_start = lr_start 14 | self.lr_min = lr_min 15 | self.lr_max = lr_max 16 | self.lr_max_decay_steps = max_decay_steps 17 | self.last_lr = 0. 
18 | self.verbosity_interval = verbosity_interval 19 | 20 | def schedule(self, n, **kwargs): 21 | if self.verbosity_interval > 0: 22 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 23 | if n < self.lr_warm_up_steps: 24 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 25 | self.last_lr = lr 26 | return lr 27 | else: 28 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 29 | t = min(t, 1.0) 30 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 31 | 1 + np.cos(t * np.pi)) 32 | self.last_lr = lr 33 | return lr 34 | 35 | def __call__(self, n, **kwargs): 36 | return self.schedule(n,**kwargs) 37 | 38 | 39 | class LambdaWarmUpCosineScheduler2: 40 | """ 41 | supports repeated iterations, configurable via lists 42 | note: use with a base_lr of 1.0. 43 | """ 44 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 45 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 46 | self.lr_warm_up_steps = warm_up_steps 47 | self.f_start = f_start 48 | self.f_min = f_min 49 | self.f_max = f_max 50 | self.cycle_lengths = cycle_lengths 51 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 52 | self.last_f = 0. 53 | self.verbosity_interval = verbosity_interval 54 | 55 | def find_in_interval(self, n): 56 | interval = 0 57 | for cl in self.cum_cycles[1:]: 58 | if n <= cl: 59 | return interval 60 | interval += 1 61 | 62 | def schedule(self, n, **kwargs): 63 | cycle = self.find_in_interval(n) 64 | n = n - self.cum_cycles[cycle] 65 | if self.verbosity_interval > 0: 66 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 67 | f"current cycle {cycle}") 68 | if n < self.lr_warm_up_steps[cycle]: 69 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 70 | self.last_f = f 71 | return f 72 | else: 73 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 74 | t = min(t, 1.0) 75 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 76 | 1 + np.cos(t * np.pi)) 77 | self.last_f = f 78 | return f 79 | 80 | def __call__(self, n, **kwargs): 81 | return self.schedule(n, **kwargs) 82 | 83 | 84 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 85 | 86 | def schedule(self, n, **kwargs): 87 | cycle = self.find_in_interval(n) 88 | n = n - self.cum_cycles[cycle] 89 | if self.verbosity_interval > 0: 90 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 91 | f"current cycle {cycle}") 92 | 93 | if n < self.lr_warm_up_steps[cycle]: 94 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 95 | self.last_f = f 96 | return f 97 | else: 98 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 99 | self.last_f = f 100 | return f 101 | 102 | -------------------------------------------------------------------------------- /diffusion/ldm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import torch 5 | 6 | class IdentityFirstStage(torch.nn.Module): 7 | def __init__(self, *args, vq_interface=False, **kwargs): 8 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 9 | super().__init__() 10 | 11 | def encode(self, x, *args, **kwargs): 12 | return x 13 | 14 | def decode(self, x, *args, **kwargs): 15 | return x 16 | 17 | def quantize(self, x, *args, **kwargs): 18 | return x 19 | 20 | def forward(self, x, *args, **kwargs): 21 | return x 22 | -------------------------------------------------------------------------------- /diffusion/ldm/models/diffusion/conditioning_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class ConditioningMLP(nn.Module): 8 | """ 9 | Fusion module that integrates text embedding into the latent c representation. 10 | Supports 'add', 'concat', and 'gated_add' fusion modes. 11 | """ 12 | 13 | def __init__(self, c_dim, text_dim, fusion_type='gated_add', pool_type='mean'): 14 | """ 15 | Args: 16 | c_dim (int): The channel dimension of `c` (e.g., 64) 17 | text_dim (int): The dimension of text embeddings (e.g., 4096) 18 | fusion_type (str): One of ['add', 'concat', 'gated_add'] 19 | """ 20 | super().__init__() 21 | assert fusion_type in ['add', 'concat', 'gated_add'], \ 22 | f"fusion_type must be one of ['add', 'concat', 'gated_add'], got {fusion_type}" 23 | assert pool_type in ['mean', 'max'] 24 | 25 | self.fusion_type = fusion_type 26 | self.pool_type = pool_type 27 | self.c_dim = c_dim 28 | self.text_dim = text_dim 29 | 30 | # Text embedding MLP projection: text_dim -> c_dim 31 | self.text_mlp = nn.Sequential( 32 | nn.Linear(text_dim, c_dim), 33 | nn.ReLU(inplace=True), 34 | nn.Linear(c_dim, c_dim) 35 | ) 36 | 37 | # Optional gate for 'gated_add' 38 | if fusion_type == 'gated_add': 39 | self.gate_layer = nn.Sequential( 40 | nn.Linear(c_dim * 2, c_dim), 41 | nn.Sigmoid() 42 | ) 43 | 44 | def forward(self, c, text_embedding): 45 | """ 46 | Args: 47 | c (torch.Tensor): Conditioning input of shape [B, C, T] 48 | text_embedding (torch.Tensor): Text embeddings of shape [B, text_dim] 49 | Returns: 50 | torch.Tensor: Fused conditioning [B, C, T] (or [B, 2C, T] if 'concat') 51 | """ 52 | B, C, T = c.shape 53 | assert C == self.c_dim, f"Expected c_dim={self.c_dim}, got {C}" 54 | 55 | # Project text embeddings from [B, text_dim] -> [B, c_dim] 56 | text_proj = self.text_mlp(text_embedding) # [B, C] 57 | 58 | # Expand to match c's temporal dimension: [B, C, T] 59 | text_proj_expanded = text_proj.unsqueeze(-1).expand(-1, -1, T) 60 | 61 | # Fusion 62 | if self.fusion_type == 'add': 63 | fused = c + text_proj_expanded 64 | elif self.fusion_type == 'concat': 65 | fused = torch.cat([c, text_proj_expanded], dim=1) # [B, 2C, T] 66 | elif self.fusion_type == 'gated_add': 67 | if self.pool_type == 'mean': 68 | pooled_c = c.mean(-1) 69 | elif self.pool_type == 'max': 70 | pooled_c, _ = c.max(-1) 71 | 72 | gate_input = torch.cat([pooled_c, text_proj], dim=-1) # [B, 2C] 73 | gate = self.gate_layer(gate_input).unsqueeze(-1) # [B, C, 1] 74 | fused = gate * text_proj_expanded + (1 - gate) * c 75 | else: 76 | raise ValueError(f"Unsupported fusion_type: {self.fusion_type}") 77 | 78 | return fused 79 | -------------------------------------------------------------------------------- /diffusion/ldm/modules/distributions/distributions.py: 
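A minimal sketch (assumed shapes) exercising the ConditioningMLP defined above in its default gated_add mode:

import torch

fuser = ConditioningMLP(c_dim=64, text_dim=4096, fusion_type="gated_add")
c = torch.randn(8, 64, 24)       # conditioning tensor [B, C, T]
text_emb = torch.randn(8, 4096)  # one text embedding per sample
fused = fuser(c, text_emb)       # gate blends projected text into c; shape stays [8, 64, 24]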
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class AbstractDistribution: 9 | def sample(self): 10 | raise NotImplementedError() 11 | 12 | def mode(self): 13 | raise NotImplementedError() 14 | 15 | 16 | class DiracDistribution(AbstractDistribution): 17 | def __init__(self, value): 18 | self.value = value 19 | 20 | def sample(self): 21 | return self.value 22 | 23 | def mode(self): 24 | return self.value 25 | 26 | 27 | class DiagonalGaussianDistribution(object): 28 | def __init__(self, parameters, deterministic=False): 29 | self.parameters = parameters 30 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 31 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 32 | self.deterministic = deterministic 33 | self.std = torch.exp(0.5 * self.logvar) 34 | self.var = torch.exp(self.logvar) 35 | if self.deterministic: 36 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 37 | 38 | def sample(self): 39 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def kl_splits(self, latent_unit=6): 57 | mean_splits = self.mean.chunk(latent_unit, dim=-1) 58 | var_splits = self.var.chunk(latent_unit, dim=-1) 59 | logvar_splits = self.logvar.chunk(latent_unit, dim=-1) 60 | kl_loss = 0 61 | for mean, var, logvar in zip(mean_splits, var_splits, logvar_splits): 62 | kl_split = 0.5 * torch.sum(torch.pow(mean, 2) 63 | + var - 1.0 - logvar, 64 | dim=-1) 65 | kl_loss += torch.sum(kl_split) / kl_split.shape[0] 66 | return kl_loss/latent_unit 67 | 68 | def nll(self, sample, dims=[1,2,3]): 69 | if self.deterministic: 70 | return torch.Tensor([0.]) 71 | logtwopi = np.log(2.0 * np.pi) 72 | return 0.5 * torch.sum( 73 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 74 | dim=dims) 75 | 76 | def mode(self): 77 | return self.mean 78 | 79 | 80 | def normal_kl(mean1, logvar1, mean2, logvar2): 81 | """ 82 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 83 | Compute the KL divergence between two gaussians. 84 | Shapes are automatically broadcasted, so batches can be compared to 85 | scalars, among other use cases. 86 | """ 87 | tensor = None 88 | for obj in (mean1, logvar1, mean2, logvar2): 89 | if isinstance(obj, torch.Tensor): 90 | tensor = obj 91 | break 92 | assert tensor is not None, "at least one argument must be a Tensor" 93 | 94 | # Force variances to be Tensors. Broadcasting helps convert scalars to 95 | # Tensors, but it does not work for torch.exp(). 
96 | logvar1, logvar2 = [ 97 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 98 | for x in (logvar1, logvar2) 99 | ] 100 | 101 | return 0.5 * ( 102 | -1.0 103 | + logvar2 104 | - logvar1 105 | + torch.exp(logvar1 - logvar2) 106 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 107 | ) 108 | -------------------------------------------------------------------------------- /diffusion/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class LitEma(nn.Module): 9 | def __init__(self, model, decay=0.9999, use_num_upates=True): 10 | super().__init__() 11 | if decay < 0.0 or decay > 1.0: 12 | raise ValueError('Decay must be between 0 and 1') 13 | 14 | self.m_name2s_name = {} 15 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 16 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 17 | else torch.tensor(-1,dtype=torch.int)) 18 | 19 | for name, p in model.named_parameters(): 20 | if p.requires_grad: 21 | #remove as '.'-character is not allowed in buffers 22 | s_name = name.replace('.','') 23 | self.m_name2s_name.update({name:s_name}) 24 | self.register_buffer(s_name,p.clone().detach().data) 25 | 26 | self.collected_params = [] 27 | 28 | def forward(self,model): 29 | decay = self.decay 30 | 31 | if self.num_updates >= 0: 32 | self.num_updates += 1 33 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 34 | 35 | one_minus_decay = 1.0 - decay 36 | 37 | with torch.no_grad(): 38 | m_param = dict(model.named_parameters()) 39 | shadow_params = dict(self.named_buffers()) 40 | 41 | for key in m_param: 42 | if m_param[key].requires_grad: 43 | sname = self.m_name2s_name[key] 44 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 45 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 46 | else: 47 | assert not key in self.m_name2s_name 48 | 49 | def copy_to(self, model): 50 | m_param = dict(model.named_parameters()) 51 | shadow_params = dict(self.named_buffers()) 52 | for key in m_param: 53 | if m_param[key].requires_grad: 54 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 55 | else: 56 | assert not key in self.m_name2s_name 57 | 58 | def store(self, parameters): 59 | """ 60 | Save the current parameters for restoring later. 61 | Args: 62 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 63 | temporarily stored. 64 | """ 65 | self.collected_params = [param.clone() for param in parameters] 66 | 67 | def restore(self, parameters): 68 | """ 69 | Restore the parameters stored with the `store` method. 70 | Useful to validate the model with EMA parameters without affecting the 71 | original optimization process. Store the parameters before the 72 | `copy_to` method. After validation (or model saving), use this to 73 | restore the former parameters. 74 | Args: 75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 76 | updated with the stored parameters. 77 | """ 78 | for c_param, param in zip(self.collected_params, parameters): 79 | param.data.copy_(c_param.data) 80 | -------------------------------------------------------------------------------- /diffusion/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
-------------------------------------------------------------------------------- /diffusion/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class LitEma(nn.Module): 9 | def __init__(self, model, decay=0.9999, use_num_updates=True): 10 | super().__init__() 11 | if decay < 0.0 or decay > 1.0: 12 | raise ValueError('Decay must be between 0 and 1') 13 | 14 | self.m_name2s_name = {} 15 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 16 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates 17 | else torch.tensor(-1, dtype=torch.int)) 18 | 19 | for name, p in model.named_parameters(): 20 | if p.requires_grad: 21 | # remove '.' since that character is not allowed in buffer names 22 | s_name = name.replace('.', '') 23 | self.m_name2s_name.update({name: s_name}) 24 | self.register_buffer(s_name, p.clone().detach().data) 25 | 26 | self.collected_params = [] 27 | 28 | def forward(self, model): 29 | decay = self.decay 30 | 31 | if self.num_updates >= 0: 32 | self.num_updates += 1 33 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 34 | 35 | one_minus_decay = 1.0 - decay 36 | 37 | with torch.no_grad(): 38 | m_param = dict(model.named_parameters()) 39 | shadow_params = dict(self.named_buffers()) 40 | 41 | for key in m_param: 42 | if m_param[key].requires_grad: 43 | sname = self.m_name2s_name[key] 44 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 45 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 46 | else: 47 | assert key not in self.m_name2s_name 48 | 49 | def copy_to(self, model): 50 | m_param = dict(model.named_parameters()) 51 | shadow_params = dict(self.named_buffers()) 52 | for key in m_param: 53 | if m_param[key].requires_grad: 54 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 55 | else: 56 | assert key not in self.m_name2s_name 57 | 58 | def store(self, parameters): 59 | """ 60 | Save the current parameters for restoring later. 61 | Args: 62 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 63 | temporarily stored. 64 | """ 65 | self.collected_params = [param.clone() for param in parameters] 66 | 67 | def restore(self, parameters): 68 | """ 69 | Restore the parameters stored with the `store` method. 70 | Useful to validate the model with EMA parameters without affecting the 71 | original optimization process. Store the parameters before the 72 | `copy_to` method. After validation (or model saving), use this to 73 | restore the former parameters. 74 | Args: 75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 76 | updated with the stored parameters. 77 | """ 78 | for c_param, param in zip(self.collected_params, parameters): 79 | param.data.copy_(c_param.data) 80 | -------------------------------------------------------------------------------- /diffusion/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import repeat 7 | import copy 8 | 9 | # helpers 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def default(val, d): 15 | return val if exists(val) else d 16 | 17 | class ResBlockTime(nn.Module): 18 | def __init__(self, in_channels, out_channels, mid_channels=None, bn=False): 19 | super(ResBlockTime, self).__init__() 20 | 21 | if mid_channels is None: 22 | mid_channels = out_channels 23 | 24 | layers = [ 25 | nn.ReLU(), 26 | nn.Conv1d(in_channels, mid_channels, 27 | kernel_size=3, stride=1, padding=1), 28 | nn.ReLU(), 29 | nn.Conv1d(mid_channels, out_channels, 30 | kernel_size=1, stride=1, padding=0) 31 | ] 32 | if bn: 33 | layers.insert(2, nn.BatchNorm1d(mid_channels)) # normalizes the first conv's output, which has mid_channels channels 34 | self.convs = nn.Sequential(*layers) 35 | 36 | def forward(self, x): 37 | return x + self.convs(x) 38 | 39 | class View(nn.Module): 40 | def __init__(self, size): 41 | super(View, self).__init__() 42 | self.size = size 43 | 44 | def forward(self, tensor): 45 | return tensor.view(self.size) 46 | 47 | class DomainUnifiedEncoder(nn.Module): 48 | ''' 49 | The input is encoded into two parts, an invariant part and a specific part. The specific part is generated by attending to a randomly initialized latent vector pool. 50 | The lengths of the two parts are equal in this implementation. 51 | ''' 52 | def __init__(self, dim, window, num_channels=3, latent_dim=32, bn=True, **kwargs): 53 | super().__init__() 54 | dim_out = latent_dim 55 | flatten_dim = int(dim * window / 4) 56 | self.in_encoder = nn.Sequential( 57 | nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), 58 | nn.BatchNorm1d(dim), 59 | nn.ReLU(inplace=True), 60 | nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), 61 | nn.BatchNorm1d(dim), 62 | nn.ReLU(inplace=True) 63 | ) 64 | 65 | self.out_encoder = nn.Sequential( 66 | ResBlockTime(dim, dim, bn=bn), 67 | nn.BatchNorm1d(dim), 68 | nn.ReLU(inplace=True), 69 | ResBlockTime(dim, dim, bn=bn), 70 | View((-1, flatten_dim)), # flatten to (batch_size, dim * window / 4) 71 | nn.Linear(flatten_dim, dim_out) 72 | ) 73 | 74 | def forward(self, x): 75 | h = self.in_encoder(x) 76 | mask = None 77 | 78 | out = self.out_encoder(h)[:, None] # b, 1, d 79 | return out, mask 80 |
81 | class DomainUnifiedPrototyper(nn.Module): 82 | ''' 83 | The input is encoded and softly assigned to a pool of randomly initialized prototype vectors (the domain prompts). 84 | The prototype pool is orthogonally initialized and kept frozen; only the shared encoder and the assignment network are trained. 85 | ''' 86 | def __init__(self, dim, window, num_latents=16, num_channels=3, latent_dim=32, bn=True, **kwargs): 87 | super().__init__() 88 | self.num_latents = num_latents 89 | self.latent_dim = latent_dim 90 | flatten_dim = int(dim * window / 4) 91 | self.share_encoder = nn.Sequential( 92 | nn.Conv1d(num_channels, dim, kernel_size=4, stride=2, padding=1), 93 | nn.BatchNorm1d(dim), 94 | nn.ReLU(inplace=True), 95 | nn.Conv1d(dim, dim, kernel_size=4, stride=2, padding=1), 96 | nn.BatchNorm1d(dim), 97 | nn.ReLU(inplace=True) 98 | ) 99 | self.latents = nn.Parameter(torch.empty(num_latents, self.latent_dim), requires_grad=False) 100 | nn.init.orthogonal_(self.latents) 101 | self.init_latents = copy.deepcopy(self.latents.detach()) 102 | self.mask_ffn = nn.Sequential( 103 | ResBlockTime(dim, dim, bn=bn), 104 | View((-1, flatten_dim)), # flatten to (batch_size, dim * window / 4) 105 | nn.Linear(flatten_dim, self.num_latents), 106 | ) 107 | self.sigmoid = nn.Sigmoid() 108 | 109 | def forward(self, x): 110 | b = x.shape[0] 111 | h = self.share_encoder(x) 112 | 113 | latents = repeat(self.latents, 'n d -> b n d', b = b) 114 | mask_logit = self.mask_ffn(h) 115 | mask = mask_logit # soft assignment logits over the prototype pool 116 | 117 | out = latents # prototypes are returned as-is; the mask carries the per-sample assignment 118 | return out, mask 119 | 120 |
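As a quick shape walkthrough for the two modules above, under assumed hyperparameters (`dim=64`, `window=168`, univariate input, matching the `-sl 168` scripts later in this repo): the two stride-2 convolutions shrink the window by 4x, which is where `flatten_dim = dim * window / 4` comes from. This is a sketch by way of illustration, not part of the original file:

```python
import torch
from ldm.modules.encoders.modules import DomainUnifiedEncoder, DomainUnifiedPrototyper

enc = DomainUnifiedEncoder(dim=64, window=168, num_channels=1, latent_dim=32)
proto = DomainUnifiedPrototyper(dim=64, window=168, num_channels=1)  # 16 prototypes by default

x = torch.randn(8, 1, 168)   # (batch, channels, window); 168 -> 84 -> 42 inside the conv stacks
out, mask = enc(x)           # out: (8, 1, 32); mask is None for the encoder
latents, assign = proto(x)   # latents: (8, 16, 32); assign: (8, 16) soft assignment logits
```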
-------------------------------------------------------------------------------- /diffusion/ldm/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import importlib 4 | 5 | import torch 6 | 7 | from inspect import isfunction 8 | 9 | 10 | def ismap(x): 11 | if not isinstance(x, torch.Tensor): 12 | return False 13 | return (len(x.shape) == 4) and (x.shape[1] > 3) 14 | 15 | 16 | def isimage(x): 17 | if not isinstance(x, torch.Tensor): 18 | return False 19 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 20 | 21 | 22 | def exists(x): 23 | return x is not None 24 | 25 | 26 | def default(val, d): 27 | if exists(val): 28 | return val 29 | return d() if isfunction(d) else d 30 | 31 | 32 | def mean_flat(tensor): 33 | """ 34 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 35 | Take the mean over all non-batch dimensions. 36 | """ 37 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 38 | 39 | 40 | def count_params(model, verbose=False): 41 | total_params = sum(p.numel() for p in model.parameters()) 42 | if verbose: 43 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 44 | return total_params 45 | 46 | 47 | def instantiate_from_config(config): 48 | if "target" not in config: 49 | if config == '__is_first_stage__': 50 | return None 51 | elif config == "__is_unconditional__": 52 | return None 53 | raise KeyError("Expected key `target` to instantiate.") 54 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 55 | 56 | 57 | def get_obj_from_str(string, reload=False): 58 | module, cls = string.rsplit(".", 1) 59 | if reload: 60 | module_imp = importlib.import_module(module) 61 | importlib.reload(module_imp) 62 | return getattr(importlib.import_module(module, package=None), cls) 63 | 64 |
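`instantiate_from_config` is the glue between the YAML configs and the classes in this repo: a config dict names a `target` class and its constructor `params`. A self-contained illustration (the `torch.nn.Linear` target is arbitrary and chosen only for the demo; the real configs point at `ldm` classes):

```python
from ldm.util import instantiate_from_config

config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 16, "out_features": 4},
}
layer = instantiate_from_config(config)  # equivalent to torch.nn.Linear(16, 4)
print(layer)
```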
-------------------------------------------------------------------------------- /diffusion/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os, sys 5 | from pytorch_lightning.trainer import Trainer 6 | from utils.cli_utils import get_parser 7 | from utils.init_utils import init_model_data_trainer 8 | from utils.test_utils import test_model_with_dp, test_model_uncond, test_model_unseen 9 | 10 | 11 | if __name__ == "__main__": 12 | 13 | data_root = os.environ['DATA_ROOT'] # fails fast if DATA_ROOT is unset 14 | 15 | parser = get_parser() 16 | parser = Trainer.add_argparse_args(parser) 17 | 18 | model, data, trainer, opt, logdir, melk = init_model_data_trainer(parser) 19 | 20 | # run 21 | if opt.train: 22 | try: 23 | trainer.logger.experiment.config.update(opt) 24 | trainer.fit(model, data) 25 | except Exception: 26 | melk() # save a checkpoint before re-raising 27 | raise 28 | 29 | -------------------------------------------------------------------------------- /diffusion/metrics/metrics_sets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import numpy as np 6 | from diffusion.utils.data_utils import test_data_loading 7 | from diffusion.metrics.feature_distance_eval import get_mdd_eval, mmd_metric, get_flat_distance 8 | 9 | 10 | data_root = os.environ['DATA_ROOT'] # read at import time; fails fast if unset 11 | 12 | 13 | def calculate_one(gen_data, scaled_ori, model_name, repeat, data_name, seq_len, uni_data_sub, uni_data_div, n_samples): 14 | this_metrics = {} 15 | print(model_name, gen_data.shape) 16 | scaled_gen = (gen_data - uni_data_sub) / uni_data_div 17 | scaled_gen = scaled_gen[:n_samples, :, None] 18 | this_metrics = update_metrics_dict(this_metrics, model_name, data_name, seq_len, scaled_ori, scaled_gen, repeat_id=repeat) 19 | return this_metrics 20 | 21 | def update_metrics_dict(the_dict, key, data_name, seq_len, ori_data, gen_data, repeat_id=0): 22 | if (key, data_name, seq_len, repeat_id) in the_dict: 23 | print(f'{key} {data_name} {seq_len} {repeat_id} already in the dict, skip!') 24 | return the_dict 25 | 26 | mdd = get_mdd_eval(ori_data, gen_data) 27 | the_dict[(key, data_name, seq_len, repeat_id)] = { 28 | 'mdd': mdd, 29 | } 30 | flat_sk_result = get_flat_distance(ori_data, gen_data) 31 | the_dict[(key, data_name, seq_len, repeat_id)].update(flat_sk_result) 32 | the_dict[(key, data_name, seq_len, repeat_id)].update(mmd_metric(ori_data, gen_data)) 33 | return the_dict 34 | 35 | def run_metrics(data_name, seq_len, model_name, gen_data, scale='zscore', exist_dict=None, repeat_id=0): 36 | extend_metrics = exist_dict if exist_dict is not None else {} # a None default avoids sharing one mutable dict across calls 37 | 38 | uni_ori_data, _ = test_data_loading(data_name, seq_len, stride=seq_len, univar=True) 39 | uni_data_min, uni_data_max = np.min(uni_ori_data), np.max(uni_ori_data) 40 | uni_data_mean, uni_data_std = np.mean(uni_ori_data), np.std(uni_ori_data) 41 | if scale == 'minmax': 42 | uni_data_sub, uni_data_div = uni_data_min, uni_data_max - uni_data_min + 1e-7 43 | elif scale == 'zscore': 44 | uni_data_sub, uni_data_div = uni_data_mean, uni_data_std + 1e-7 45 | elif scale == 'raw': 46 | uni_data_sub, uni_data_div = 0, 1 47 | elif scale == 'robust_zscore': 48 | median = np.median(uni_ori_data) 49 | mad = np.median(np.abs(uni_ori_data - median)) 50 | uni_data_sub, uni_data_div = median, 1.4826 * mad + 1e-7 51 | else: 52 | raise ValueError(f'Unknown scale option: {scale}') 53 | uni_scaled_ori = (uni_ori_data - uni_data_sub) / uni_data_div 54 | print(data_name, 'univar', uni_scaled_ori.shape) 55 | scaled_ori = uni_scaled_ori 56 | scaled_gen = (gen_data - uni_data_sub) / uni_data_div 57 | extend_metrics = update_metrics_dict(extend_metrics, model_name, data_name, seq_len, scaled_ori, scaled_gen, repeat_id=repeat_id) 58 | return extend_metrics 59 |
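A hedged sketch of calling `run_metrics` directly; the dataset name, sequence length, and sample file below are illustrative, and `DATA_ROOT` must be set because the module reads it at import time:

```python
import numpy as np
from diffusion.metrics.metrics_sets import run_metrics

gen_data = np.load("samples/electricity_168_gen.npy")  # assumed (N, 168) array of generated series
results = run_metrics("electricity", 168, "timegen", gen_data, scale="zscore")
for key, vals in results.items():
    print(key, vals["mdd"])  # plus the flat-distance and MMD entries
```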
-------------------------------------------------------------------------------- /diffusion/train.sh: -------------------------------------------------------------------------------- 1 | python main.py --base configs/multi_domain_tsgen.yaml --gpus 0, --logdir ./logs/dpdiff_12new -sl 168 -up -nl 16 --batch_size 128 -lr 0.0001 -s 0 --debug 2 | -------------------------------------------------------------------------------- /diffusion/utils/cli_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import argparse 5 | from pytorch_lightning import Trainer 6 | 7 | def get_parser(**parser_kwargs): 8 | def str2bool(v): 9 | if isinstance(v, bool): 10 | return v 11 | if v.lower() in ("yes", "true", "t", "y", "1"): 12 | return True 13 | elif v.lower() in ("no", "false", "f", "n", "0"): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError("Boolean value expected.") 17 | 18 | parser = argparse.ArgumentParser(**parser_kwargs) 19 | parser.add_argument("-n","--name",type=str,const=True,default="",nargs="?",help="postfix for logdir") 20 | parser.add_argument("-b","--base",nargs="*",metavar="base_config.yaml",help="paths to base configs. Loaded from left-to-right.", default=list(),) 21 | parser.add_argument("-t","--train",type=str2bool,const=True,default=True,nargs="?",help="train",) 22 | parser.add_argument("-r","--resume",type=str2bool,const=True,default=False,nargs="?",help="resume and test",) 23 | parser.add_argument("--no-test",type=str2bool,const=True,default=False,nargs="?",help="disable test",) 24 | parser.add_argument("-d","--debug",type=str2bool,nargs="?",const=True,default=False,help="debug mode",) 25 | parser.add_argument("-s","--seed",type=int,default=23,help="seed for seed_everything",) 26 | parser.add_argument("-f","--postfix",type=str,default="",help="post-postfix for default name",) 27 | parser.add_argument("-l","--logdir",type=str,default="./logs",help="directory for logging data",) 28 | parser.add_argument("--scale_lr",type=str2bool,nargs="?",const=True,default=False,help="scale base-lr by ngpu * batch_size * n_accumulate",) 29 | parser.add_argument("--ckpt_name",type=str,default="last",help="ckpt name to resume",) 30 | parser.add_argument("-sl","--seq_len", type=int, const=True, default=24,nargs="?", help="sequence length") 31 | parser.add_argument("-uc","--uncond", action='store_true', help="unconditional generation") 32 | parser.add_argument("-up","--use_pam", action='store_true', help="use prototype") 33 | parser.add_argument("-bs","--batch_size", type=int, const=True, default=128,nargs="?", help="batch_size") 34 | parser.add_argument("-nl","--num_latents", type=int, const=True, default=16,nargs="?", help="number of prototypes") 35 | parser.add_argument("-lr","--overwrite_learning_rate", type=float, const=True, default=None, nargs="?", help="learning rate (overwrites the config value)") 36 | parser.add_argument("--use_text", action='store_true', help="whether to use text input") 37 | parser.add_argument("--use_guidance", action='store_true', help="whether to use influence-guided generation") 38 | parser.add_argument("--guidance_path", type=str, default=None, help="path to the guidance set (e.g., .pkl file)") 39 | parser.add_argument("--downstream_pth_path", type=str, default=None, help="path to pretrained downstream model for guidance") 40 | parser.add_argument("--guidance_scale", type=float, default=0.0001, help="scaling factor for influence-guided gradients") 41 | 42 | return parser 43 | 44 | def nondefault_trainer_args(opt): 45 | parser = argparse.ArgumentParser() 46 | parser = Trainer.add_argparse_args(parser) 47 | args = parser.parse_args([]) 48 | return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) 49 |
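For reference, the parser can be exercised on its own; several short flags combine `nargs="?"` with a typed value, so they accept an explicit argument or fall back to their defaults. A small sketch (the import path assumes running from inside `diffusion/`):

```python
from utils.cli_utils import get_parser

opt = get_parser().parse_args(["-sl", "168", "-up", "-nl", "16"])
print(opt.seq_len, opt.use_pam, opt.num_latents, opt.batch_size)  # 168 True 16 128
```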
50 | """ 51 | embeddings = [] 52 | print(f"Encoding {len(text_list)} texts...") 53 | 54 | for i in range(0, len(text_list), batch_size): 55 | batch_texts = text_list[i:i + batch_size] 56 | 57 | tokens = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device) 58 | outputs = self.model(**tokens) 59 | 60 | # Pooling 61 | if pooling == "cls": 62 | # Use CLS token from last hidden state 63 | batch_embeddings = outputs.last_hidden_state[:, 0, :] # (batch_size, hidden_size) 64 | elif pooling == "mean": 65 | attention_mask = tokens['attention_mask'].unsqueeze(-1) 66 | summed = (outputs.last_hidden_state * attention_mask).sum(1) 67 | count = attention_mask.sum(1) 68 | batch_embeddings = summed / count 69 | else: 70 | raise ValueError(f"Unsupported pooling type: {pooling}") 71 | 72 | embeddings.append(batch_embeddings.cpu().numpy()) 73 | 74 | # Stack all batches 75 | embeddings = np.vstack(embeddings) 76 | print(f"Encoding complete! Shape: {embeddings.shape}") 77 | return embeddings 78 | 79 | def encode_single(self, text, pooling="cls"): 80 | """ 81 | Encode a single text string. 82 | """ 83 | return self.encode([text], pooling=pooling)[0] 84 | 85 | -------------------------------------------------------------------------------- /env.txt: -------------------------------------------------------------------------------- 1 | DATA_ROOT=/mnt/storage/ts_data/newer -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: timecraft 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - python=3.8 8 | - pip=20.3 9 | - cudatoolkit=11.7 10 | - torchvision==0.14.0 11 | - numpy=1.19.2 12 | - scikit-learn 13 | - h5py 14 | - ca-certificates 15 | - openssl 16 | - torchaudio==0.13.0 17 | - pytorch==1.13.0 18 | - certifi 19 | - pytorch-cuda=11.7 20 | - packaging=21.3 21 | - setuptools=69.5.1 22 | - statsmodels 23 | - jupyter 24 | - matplotlib 25 | - wandb 26 | - seaborn 27 | - einops 28 | - tqdm 29 | - scipy 30 | - pandas 31 | - omegaconf 32 | - mkl=2023 33 | - pytorch-lightning=1.4.2 34 | - torchmetrics=0.7.3 35 | prefix: /opt/conda/envs/tsgen 36 | -------------------------------------------------------------------------------- /figures/BRIDGE.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/BRIDGE.jpeg -------------------------------------------------------------------------------- /figures/TarDiff_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/TarDiff_result.png -------------------------------------------------------------------------------- /figures/TextPreparation.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/TextPreparation.jpeg -------------------------------------------------------------------------------- /figures/TimeCraft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/TimeCraft.png 
-------------------------------------------------------------------------------- /env.txt: -------------------------------------------------------------------------------- 1 | DATA_ROOT=/mnt/storage/ts_data/newer -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: timecraft 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - python=3.8 8 | - pip=20.3 9 | - cudatoolkit=11.7 10 | - torchvision==0.14.0 11 | - numpy=1.19.2 12 | - scikit-learn 13 | - h5py 14 | - ca-certificates 15 | - openssl 16 | - torchaudio==0.13.0 17 | - pytorch==1.13.0 18 | - certifi 19 | - pytorch-cuda=11.7 20 | - packaging=21.3 21 | - setuptools=69.5.1 22 | - statsmodels 23 | - jupyter 24 | - matplotlib 25 | - wandb 26 | - seaborn 27 | - einops 28 | - tqdm 29 | - scipy 30 | - pandas 31 | - omegaconf 32 | - mkl=2023 33 | - pytorch-lightning=1.4.2 34 | - torchmetrics=0.7.3 35 | prefix: /opt/conda/envs/timecraft 36 | -------------------------------------------------------------------------------- /figures/BRIDGE.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/BRIDGE.jpeg -------------------------------------------------------------------------------- /figures/TarDiff_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/TarDiff_result.png -------------------------------------------------------------------------------- /figures/TextPreparation.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/TextPreparation.jpeg -------------------------------------------------------------------------------- /figures/TimeCraft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/TimeCraft.png -------------------------------------------------------------------------------- /figures/TimeCraft2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/TimeCraft2.png -------------------------------------------------------------------------------- /figures/overview_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/overview_2.png -------------------------------------------------------------------------------- /figures/prototype_like_words.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/prototype_like_words.jpeg -------------------------------------------------------------------------------- /figures/pt_like_word_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/pt_like_word_small.png -------------------------------------------------------------------------------- /figures/timedp_indomain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/TimeCraft/156bf4ff7c09c206b2423c39f8206c87f3a8337f/figures/timedp_indomain.png
-------------------------------------------------------------------------------- /process/dataset_split.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import pandas as pd 6 | from sklearn.model_selection import train_test_split 7 | import argparse 8 | 9 | def split_dataset(file_path: str, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1): 10 | total = train_ratio + val_ratio + test_ratio 11 | if abs(total - 1.0) > 1e-6: 12 | raise ValueError("Train, val and test ratios must sum to 1.0") 13 | 14 | # Load the dataset 15 | data = pd.read_csv(file_path) 16 | 17 | # First split off the train set 18 | train_data, temp_data = train_test_split(data, train_size=train_ratio, random_state=42) 19 | 20 | # Then split the remaining data into val and test 21 | val_size_adjusted = val_ratio / (val_ratio + test_ratio) 22 | val_data, test_data = train_test_split(temp_data, train_size=val_size_adjusted, random_state=42) 23 | 24 | return train_data, val_data, test_data 25 | 26 | def main(input_dir: str, output_dir: str): 27 | os.makedirs(output_dir, exist_ok=True) # make sure the output directory exists before writing 28 | for file_name in os.listdir(input_dir): 29 | if file_name.endswith(".csv") and "train" not in file_name and "val" not in file_name and "test" not in file_name: 30 | file_path = os.path.join(input_dir, file_name) 31 | print(f"Processing file: {file_name}") 32 | 33 | train_data, val_data, test_data = split_dataset(file_path) 34 | 35 | base_name = os.path.splitext(file_name)[0] 36 | train_file_name = f"{base_name}_train.csv" 37 | val_file_name = f"{base_name}_val.csv" 38 | test_file_name = f"{base_name}_test.csv" 39 | 40 | train_data.to_csv(os.path.join(output_dir, train_file_name), index=False) 41 | val_data.to_csv(os.path.join(output_dir, val_file_name), index=False) 42 | test_data.to_csv(os.path.join(output_dir, test_file_name), index=False) 43 | 44 | print(f"Saved {train_file_name}, {val_file_name}, and {test_file_name} to {output_dir}") 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser(description='Split dataset into train, val, and test sets.') 48 | parser.add_argument('--input_dir', type=str, required=True, help='Directory containing the input files.') 49 | parser.add_argument('--output_dir', type=str, required=True, help='Directory to save the split files.') 50 | 51 | args = parser.parse_args() 52 | main(args.input_dir, args.output_dir) 53 |
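With the default ratios, the second `train_test_split` takes `0.1 / (0.1 + 0.1) = 0.5` of the held-out 20%, which yields the overall 80/10/10 split. A quick check on a throwaway CSV, assuming `split_dataset` from the script above is importable:

```python
import pandas as pd
from dataset_split import split_dataset  # run from inside process/

pd.DataFrame({"x": range(100)}).to_csv("toy.csv", index=False)
train, val, test = split_dataset("toy.csv")
print(len(train), len(val), len(test))  # 80 10 10
```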
", 8 | "taxi":"The Taxi Dataset includes millions of individual trip records spanning a period of several years, from 2009 to the present. Each entry includes timestamps for pickup and drop-off, coordinates of pickup/drop-off locations, fare amount, tip amount, payment type, and distance traveled. ", 9 | "pedestrian":"This dataset contains hourly pedestrian counts captured from 66 sensors in Melbourne city starting from May 2009.", 10 | "air":"The Air Dataset contains hourly measurements of air pollutant concentrations and meteorological data collected from multiple monitoring stations within a city or region. This archive contains millions of records collected continuously over a span of more than 20 years, from 2000 to the present. ", 11 | "temperature":"The Temperature Dataset contains daily temperature measurements recorded across multiple weather stations in a specific region. This archive consists of millions of observations collected over a period of more than 50 years, spanning from 1970 to the present. Each record includes the date, location of the weather station, and the respective temperature readings. ", 12 | "rain":"This dataset contains 32072 daily time series showing the temperature observations and rain forecasts, gathered by the Australian Bureau of Meteorology for 422 weather stations across Australia, between 02/05/2015 and 26/04/2017.", 13 | "nn5":"This dataset was used in the NN5 forecasting competition. It contains 111 time series from the banking domain. The goal is predicting the daily cash withdrawals from ATMs in UK.", 14 | "fred":"The FRED-MD dataset contains monthly macroeconomic indicators from the United States, spanning a period of over 60 years. A wide range of economic categories is represented, including real output, income, labor markets, housing, inflation, interest rates, and stock markets. The dataset includes 637,965 observations, collected from various government and financial institutions across the U.S., and is updated on a monthly basis.", 15 | "exchange":"This dataset contains official exchange rates of various currencies against a base currency, as reported by the respective central banks or other official sources. The dataset includes daily exchange rates for a period of 2020 - 2023.", 16 | "m4":"The M4 dataset is a collection of 100,000 time series used for the fourth edition of the Makridakis forecasting Competition. The M4 dataset consists of time series of yearly, quarterly, monthly and other (weekly, daily and hourly) data, which are divided into training and test sets. The minimum numbers of observations in the training test are 13 for yearly, 16 for quarterly, 42 for monthly, 80 for weekly, 93 for daily and 700 for hourly series. The participants were asked to produce the following numbers of forecasts beyond the available data that they had been given: six for yearly, eight for quarterly, 18 for monthly series, 13 for weekly series and 14 and 48 forecasts respectively for the daily and hourly ones.", 17 | "illiness":"The Illness Dataset contains daily records of reported illness cases across multiple healthcare facilities in a specific region. This archive comprises hundreds of thousands of entries collected over a period of 10 years, from 2010 to the present. 
", 18 | }; 19 | return datasetDescriptions[datasetName] || "the prompt description for each of dataset."; 20 | } -------------------------------------------------------------------------------- /supplementary/dataset_split.md: -------------------------------------------------------------------------------- 1 | ### Train, Validation, and Test set Split for Time-Series Dataset with Text Descriptions 2 | 3 | After generating textual descriptions for the time-series data, we split the output files into training, validation, and test sets. This ensures the datasets are ready for model training and evaluation. It splits the data into **train**, **validation**, and **test** sets according to a predefined ratio (default: 80% train, 10% val, 10% test). 4 | 5 | The splitting process is implemented in the following script: [Dataset Split Code](https://github.com/chang-xu/TimeGen/blob/main/process/dataset_split.py) 6 | 7 | ### Example Command 8 | ```bash 9 | python dataset_split.py --input_dir ./output_files --output_dir ./split_files 10 | ``` 11 | -------------------------------------------------------------------------------- /supplementary/examples.md: -------------------------------------------------------------------------------- 1 | 2 | ## Example Settings and Expected Results 3 | 4 | ### Demo training and inference of TimeGen with prototypes and text 5 | 6 | ```bash 7 | python training_inference.py \ 8 | --base text_control.yaml \ 9 | --gpus 0, \ 10 | --logdir ./logs/ \ 11 | -sl 168 \ 12 | -up \ 13 | -nl 16 \ 14 | --batch_size 128 \ 15 | -lr 0.0001 \ 16 | --use_text \ 17 | ``` 18 | 19 | ### Demo training and inference of TimeGen with prototypes 20 | 21 | ```bash 22 | python training_inference.py \ 23 | --base multi_domain_timedp.yaml \ 24 | --gpus 0, \ 25 | --logdir ./logs/ \ 26 | -sl 168 \ 27 | -up \ 28 | -nl 16 \ 29 | --batch_size 128 \ 30 | -lr 0.0001 \ 31 | ``` 32 | 33 | ### Demo training and inference of TimeGen with text 34 | 35 | ```bash 36 | python train_inference.py \ 37 | --base text_control.yaml\ 38 | --gpus 0, \ 39 | --logdir ./logs/ \ 40 | -sl 168 \ 41 | -nl 16 \ 42 | --batch_size 128 \ 43 | -lr 0.0001 \ 44 | -use_text \ 45 | ``` 46 | 47 | ### Demo training and inference of TimeGen without text and prototype 48 | 49 | ```bash 50 | python train_inference.py \ 51 | --base multi_domain_timedp.yaml\ 52 | --gpus 0, \ 53 | --logdir ./logs/ \ 54 | -sl 168 \ 55 | -nl 16 \ 56 | --batch_size 128 \ 57 | -lr 0.0001 \ 58 | ``` 59 | 60 | ### Dataset 61 | 62 | The [Electricity dataset](https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014) is a public multivariate time series dataset widely used for forecasting, anomaly detection, and energy consumption analysis. It contains 15-minute interval electricity consumption records (in kWh) from 370 industrial and residential clients of a Portuguese energy provider, collected between 2011 and 2014. The diverse consumption patterns make it ideal for evaluating machine learning models in multivariate time series forecasting and classification. 63 | ### Example Output 64 | 65 | We evaluate **TimeGen** and baseline models on time series generation tasks. The metrics used are Maximum Mean Discrepancy (MDD) and Kullback-Leibler divergence (K-L), both measuring the similarity between the generated and real data distributions—lower values indicate better performance. **TimeGen** consistently outperforms existing baselines across both metrics. 
-------------------------------------------------------------------------------- /supplementary/inference_guidance.md: -------------------------------------------------------------------------------- 1 | # 5.3 Target-Aware Generation for Specific Downstream Tasks 2 | 3 | This advanced mode enables **target-aware generation**, where the model produces time-series data that is **optimized to improve performance on a specific downstream task** (e.g., classification, detection). It integrates **gradient-based guidance** from a pre-trained classifier into the generation process, steering synthetic data toward task-relevant attributes. 4 | 5 | This setup is useful when: 6 | - You want synthetic data to enhance downstream task models 7 | - You need to generate hard or rare samples for classifier robustness 8 | - Controllability is required based on task-specific feedback 9 | 10 | **Example Command:** 11 | 12 | ```bash 13 | python inference.py \ 14 | --base config.yaml \ 15 | --resume true \ 16 | --ckpt_name ./checkpoints/ \ 17 | --use_guidance \ 18 | --uncond \ 19 | --downstream_pth_path ./classifier/checkpoints/best_model.pt \ 20 | --guidance_path ./classifier/data/guidance_tuple.pkl 21 | ``` 22 | 23 | > `--use_guidance` enables classifier-informed generation 24 | > `--guidance_path` must point to a `.pkl` containing the guidance tuples 25 | > This assumes the classifier is already trained and stored at `downstream_pth_path` 26 | 27 | [🔙 Back to Main README](https://github.com/microsoft/TimeCraft) 28 |
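The guidance mechanics live inside the repo's sampler; in schematic form (the names below are illustrative, not the actual API), each denoising step is nudged by the gradient of a downstream loss evaluated on the current sample:

```python
import torch
import torch.nn.functional as F

def guided_noise(x_t, eps_pred, classifier, target, guidance_scale=1e-4):
    """One schematic guidance step: steer the noise prediction with classifier gradients."""
    with torch.enable_grad():
        x_in = x_t.detach().requires_grad_(True)
        loss = F.cross_entropy(classifier(x_in), target)
        grad = torch.autograd.grad(loss, x_in)[0]
    return eps_pred + guidance_scale * grad  # nudges the sample toward lower task loss
```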
-------------------------------------------------------------------------------- /supplementary/inference_prototype.md: -------------------------------------------------------------------------------- 1 | # 5.1 Controllable Generation with Domain Prompts 2 | 3 | In this mode, TimeCraft leverages learned **semantic prototypes** (also referred to as "domain prompts") to control the generation of synthetic time-series data. These prototypes encode structural or categorical properties of specific domains (e.g., medical, financial, climate), enabling the model to generate data that conforms to domain-specific characteristics without relying on explicit textual input. 4 | 5 | **Example Command:** 6 | 7 | ```bash 8 | python inference.py \ 9 | --base config.yaml \ 10 | --resume true \ 11 | --ckpt_name ./checkpoints/ \ 12 | --use_pam 13 | ``` 14 | 15 | > `--use_pam` enables prototype-based control 16 | 17 | For unconditional inference (no prototype conditioning), simply omit the flag: 18 | 19 | **Example Command:** 20 | 21 | ```bash 22 | python inference.py \ 23 | --base config.yaml \ 24 | --resume true \ 25 | --ckpt_name ./checkpoints/ 26 | ``` 27 | 28 | [🔙 Back to Main README](https://github.com/microsoft/TimeCraft) 29 | -------------------------------------------------------------------------------- /supplementary/inference_prototype_text.md: -------------------------------------------------------------------------------- 1 | # 5.2 Controllable Generation with Domain Prompts and Text 2 | 3 | This mode combines **textual conditioning** with semantic **prototypes** to offer more expressive and fine-grained control over the generated time series. By including natural language prompts, users can specify high-level trends (e.g., "a rising curve with seasonal dips") or domain-specific features (e.g., "heartbeat pattern after exercise"). 4 | 5 | This setting is especially powerful when: 6 | - Users want to guide generation using natural language descriptions 7 | - Additional domain knowledge needs to be injected beyond prototypes 8 | 9 | **Example Command:** 10 | 11 | ```bash 12 | python inference.py \ 13 | --base config.yaml \ 14 | --resume true \ 15 | --ckpt_name ./checkpoints/ \ 16 | --use_pam \ 17 | --use_text \ 18 | --text_emb_dir ./your_text_embedding_dir/ 19 | ``` 20 | 21 | > `--use_pam` + `--use_text` = joint prototype-text conditioning 22 | > `--text_emb_dir` points to pre-computed text embeddings 23 | 24 | [🔙 Back to Main README](https://github.com/microsoft/TimeCraft) 25 | -------------------------------------------------------------------------------- /supplementary/mimiciii_prepare.md: -------------------------------------------------------------------------------- 1 | We preprocess MIMIC-III by first querying the raw **vitals** and **admissions** tables, then isolating each ICU stay (`icustay_id`) as an independent sample. For every stay we extract seven routinely recorded signals (heart rate, systolic/diastolic blood pressure, mean arterial pressure, respiratory rate, temperature, oxygen saturation (SpO₂), and urine output), resample them to an equal 1-hour grid, and truncate or zero-pad so every sample is a fixed **24 × 7** time-series matrix covering the first 24 hours in the unit. We attach a binary in-hospital mortality label from the admissions record, stack all samples into a single array, randomly shuffle, and split 80%/20% into training and test sets while reporting the class balance. This yields a clean, length-aligned dataset ready for downstream modeling without exposing any protected health information.
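A schematic of the per-stay gridding step described above; the column names and table layout are illustrative, since the real MIMIC-III extraction requires credentialed database access:

```python
import numpy as np
import pandas as pd

HOURS, N_SIGNALS = 24, 7

def to_fixed_grid(stay_df: pd.DataFrame) -> np.ndarray:
    """Resample one ICU stay (datetime 'charttime' index plus 7 signal columns) to a 24 x 7 matrix."""
    hourly = stay_df.set_index("charttime").resample("1h").mean()  # equal 1-hour grid
    values = hourly.iloc[:HOURS].to_numpy()       # truncate to the first 24 hours
    out = np.zeros((HOURS, N_SIGNALS))
    out[: len(values)] = np.nan_to_num(values)    # zero-pad shorter stays
    return out
```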
-------------------------------------------------------------------------------- /supplementary/training_details.md: -------------------------------------------------------------------------------- 1 | ## Train TimeGen Framework 2 | 3 | Use `main.py` for model training, `inference.py` for model inference, and `visualize.py` for domain prompt visualization. 4 | 5 | Detailed descriptions of the command-line arguments are as follows: 6 | | Parameter Name | Description | 7 | | --------------------------------- | -------------------------------------------------------------------------- | 8 | | `base` (`-b`) | Paths to base configuration files. | 9 | | `train` (`-t`) | Boolean flag to enable training. (default: true) | 10 | | `debug` (`-d`) | Boolean flag to enter debug mode. (default: false) | 11 | | `seed` (`-s`) | Seed for initializing random number generators. (default: 23) | 12 | | `logdir` (`-l`) | Directory for logging data. (default: ./logs) | 13 | | `seq_len` (`-sl`) | Sequence length for the model. (default: 24) | 14 | | `uncond` (`-uc`) | Boolean flag for unconditional generation. | 15 | | `use_pam` (`-up`) | Boolean flag to use the prototype assignment module. | 16 | | `batch_size` (`-bs`) | Batch size for training. (default: 128) | 17 | | `num_latents` (`-nl`) | Number of latent variables. (default: 16) | 18 | | `overwrite_learning_rate` (`-lr`) | Learning rate to overwrite the config file. (default: None) | 19 | | `gpus` | Comma-separated list of GPU ids to use for training. | 20 | | `ckpt_name` | Checkpoint name to resume from for test or visualization. (default: last) | 21 | | `use_text` | Use text as a condition. | 22 | 23 | 24 | ### Training and inference together 25 | We provide end-to-end scripts that can be used for both training and inference. 26 | 27 | ```bash 28 | python train_inference.py \ 29 | --base config.yaml \ 30 | --gpus 0, \ 31 | --logdir ./logs/Your_Logdir \ 32 | -sl 168 \ 33 | -up \ 34 | -nl 16 \ 35 | --batch_size 128 \ 36 | -lr 0.0001 37 | ``` 38 | 39 | ### Training with Prototypes and Text 40 | 41 | ```bash 42 | python main.py \ 43 | --base config.yaml \ 44 | --gpus 0, \ 45 | --logdir ./logs/Your_Logdir \ 46 | -sl 168 \ 47 | -up \ 48 | -nl 16 \ 49 | --batch_size 128 \ 50 | -lr 0.0001 \ 51 | --use_text 52 | ``` 53 | 54 | ### Training with Prototypes 55 | ```bash 56 | python main.py \ 57 | --base config.yaml \ 58 | --gpus 0, \ 59 | --logdir ./logs/Your_Logdir \ 60 | -sl 168 \ 61 | -up \ 62 | -nl 16 \ 63 | --batch_size 128 \ 64 | -lr 0.0001 65 | ``` 66 | 67 | ### Training with neither Prototypes nor Text 68 | 69 | ```bash 70 | python main.py \ 71 | --base config.yaml \ 72 | --gpus 0, \ 73 | --logdir ./logs/Your_Logdir \ 74 | -sl 168 \ 75 | -nl 16 \ 76 | --batch_size 128 \ 77 | -lr 0.0001 78 | ``` 79 | -------------------------------------------------------------------------------- /train_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
3 | 4 | 5 | import os 6 | import sys 7 | import traceback 8 | 9 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'diffusion')) 11 | 12 | from pytorch_lightning.trainer import Trainer 13 | from diffusion.utils.cli_utils import get_parser 14 | from diffusion.utils.init_utils import init_model_data_trainer 15 | from diffusion.utils.test_utils import test_model_with_dp, test_model_uncond, test_model_unseen, test_model_guidance 16 | 17 | if __name__ == "__main__": 18 | 19 | # data_root = os.environ.get('DATA_ROOT', None) 20 | # if not data_root or not os.path.exists(data_root): 21 | # raise ValueError("DATA_ROOT is not defined or does not exist!") 22 | 23 | parser = get_parser() 24 | parser = Trainer.add_argparse_args(parser) 25 | 26 | model, data, trainer, opt, logdir, melk = init_model_data_trainer(parser) 27 | 28 | if opt.train: 29 | try: 30 | trainer.logger.experiment.config.update(opt) 31 | trainer.fit(model, data) 32 | except Exception: 33 | print("Exception occurred during training!") 34 | print(traceback.format_exc()) 35 | 36 | if trainer is not None and trainer.lightning_module is not None: 37 | print("Attempting to save checkpoint in exception handler via melk() ...") 38 | melk() # save a checkpoint before re-raising 39 | else: 40 | print("Skipped calling melk() because trainer.lightning_module is None") 41 | 42 | raise # re-raise the original exception with its traceback 43 | 44 | if not opt.no_test and not getattr(trainer, "interrupted", False): 45 | # the three test modes are mutually exclusive 46 | if opt.uncond and not opt.use_guidance: 47 | test_model_uncond(model, data, trainer, opt, logdir) 48 | elif opt.use_guidance: 49 | test_model_guidance(model, data, trainer, opt, logdir) 50 | else: 51 | test_model_with_dp(model, data, trainer, opt, logdir, use_pam=opt.use_pam, use_text=opt.use_text) 52 | test_model_unseen(model, data, trainer, opt, logdir, use_pam=opt.use_pam, use_text=opt.use_text) 53 | --------------------------------------------------------------------------------