├── tensorflow_chatbots
├── __init__.py
├── tsb
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ └── callback.cpython-37.pyc
│ └── callback.py
├── ttb
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ └── callback.cpython-37.pyc
│ └── callback.py
├── variableholder
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ └── variableholder.cpython-37.pyc
│ └── variableholder.py
└── __pycache__
│ └── __init__.cpython-37.pyc
├── Tensorflow_ChatBots.egg-info
├── dependency_links.txt
├── top_level.txt
├── requires.txt
├── PKG-INFO
└── SOURCES.txt
├── .idea
├── .gitignore
├── vcs.xml
├── inspectionProfiles
│ └── profiles_settings.xml
├── other.xml
├── misc.xml
├── modules.xml
└── Tensorflow-Telegram-Bot.iml
├── .DS_Store
├── examples
├── mnist
│ ├── plot.png
│ └── main.py
└── ppo
│ ├── plot.png
│ ├── save_model
│ ├── policy_net.h5
│ └── value_net.h5
│ ├── __pycache__
│ ├── ppo.cpython-37.pyc
│ ├── train.cpython-37.pyc
│ └── utils.cpython-37.pyc
│ ├── main.py
│ ├── train.py
│ ├── utils.py
│ └── ppo.py
├── images
├── telegram_example_1.png
├── telegram_example_2.png
├── telegram_example_3.png
└── telegram_example_4.png
├── documents
└── RELEASE.md
├── setup.py
└── README.md
/tensorflow_chatbots/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorflow_chatbots/tsb/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorflow_chatbots/ttb/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorflow_chatbots/variableholder/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Tensorflow_ChatBots.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 |
--------------------------------------------------------------------------------
/Tensorflow_ChatBots.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | tensorflow_chatbots
2 |
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/.DS_Store
--------------------------------------------------------------------------------
/examples/mnist/plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/examples/mnist/plot.png
--------------------------------------------------------------------------------
/examples/ppo/plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/examples/ppo/plot.png
--------------------------------------------------------------------------------
/images/telegram_example_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/images/telegram_example_1.png
--------------------------------------------------------------------------------
/images/telegram_example_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/images/telegram_example_2.png
--------------------------------------------------------------------------------
/images/telegram_example_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/images/telegram_example_3.png
--------------------------------------------------------------------------------
/images/telegram_example_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/images/telegram_example_4.png
--------------------------------------------------------------------------------
/Tensorflow_ChatBots.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | tensorflow
2 | gym
3 | numpy
4 | numba
5 | matplotlib
6 | slacker
7 | python-telegram-bot
8 |
--------------------------------------------------------------------------------
/examples/ppo/save_model/policy_net.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/examples/ppo/save_model/policy_net.h5
--------------------------------------------------------------------------------
/examples/ppo/save_model/value_net.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/examples/ppo/save_model/value_net.h5
--------------------------------------------------------------------------------
/examples/ppo/__pycache__/ppo.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/examples/ppo/__pycache__/ppo.cpython-37.pyc
--------------------------------------------------------------------------------
/examples/ppo/__pycache__/train.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/examples/ppo/__pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/examples/ppo/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/examples/ppo/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/tensorflow_chatbots/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/tensorflow_chatbots/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/tensorflow_chatbots/tsb/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/tensorflow_chatbots/tsb/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/tensorflow_chatbots/tsb/__pycache__/callback.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/tensorflow_chatbots/tsb/__pycache__/callback.cpython-37.pyc
--------------------------------------------------------------------------------
/tensorflow_chatbots/ttb/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/tensorflow_chatbots/ttb/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/tensorflow_chatbots/ttb/__pycache__/callback.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/tensorflow_chatbots/ttb/__pycache__/callback.cpython-37.pyc
--------------------------------------------------------------------------------
/tensorflow_chatbots/variableholder/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/tensorflow_chatbots/variableholder/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/tensorflow_chatbots/variableholder/__pycache__/variableholder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yeachan-Heo/Tensorflow-ChatBots/HEAD/tensorflow_chatbots/variableholder/__pycache__/variableholder.cpython-37.pyc
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/Tensorflow_ChatBots.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: Tensorflow-ChatBots
3 | Version: 0.0.10
4 | Summary: ChatBots supporting TensorFlow
5 | Home-page: https://github.com/Yeachan-Heo/Tensorflow-ChatBots
6 | Author: Yeachan-Heo
7 | Author-email: hurrc04@gmail.com
8 | License: MIT
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/documents/RELEASE.md:
--------------------------------------------------------------------------------
1 | # 0.0.0
2 | Initial release.
3 | # 0.0.1
4 | Bug fixes.
5 | # 0.0.2
6 | ## Added command: /get
7 | The /get command fetches files from the local machine (TensorFlow models, logs, etc.).
8 | # 0.0.3
9 | Bug fixes.
10 | # 0.0.4
11 | ## Added command: /bash
12 | The /bash command executes bash commands and returns their output.
13 | # 0.0.5
14 | Bug fixes.
15 | # 0.0.6
16 | Added to_csv.
17 | # 0.0.7
18 | Added load_csv.
19 | # 0.0.8
20 | Bug fixes.
21 | # 0.0.9
22 | Added thread support.
23 | # 0.0.10
24 | Bug fixes.
--------------------------------------------------------------------------------
/examples/ppo/main.py:
--------------------------------------------------------------------------------
import os

from ppo import PPO

# Never hard-code the bot token: anything committed to the repository is
# public. NOTE: a real token was previously committed in this file -- revoke
# it via @BotFather. Read it from the environment, or prompt for it.
token = os.environ.get("TELEGRAM_BOT_TOKEN") or input("please enter your telegram bot token")

# PPO stores its kwargs on a VariableHolder, which may evaluate string values
# as Python source; passing repr(token) (a quoted literal) makes the token come
# back as a plain string either way.
ppo = PPO(token=repr(token),
          is_continous=True,
          state_size=3,
          action_size=1,
          lr_value_net=0.0003,
          lr_policy_net=0.0003,
          updates_n=15,
          sample_size=200,
          batch_size=16,
          total_episodes=2000,
          epsilon=0.2,
          lamda=0.9,
          gamma=0.99
          )
if __name__ == '__main__':
    ppo.set_env("Pendulum-v0")
    ppo()
20 |
--------------------------------------------------------------------------------
/.idea/Tensorflow-Telegram-Bot.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/Tensorflow_ChatBots.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | setup.py
3 | Tensorflow_ChatBots.egg-info/PKG-INFO
4 | Tensorflow_ChatBots.egg-info/SOURCES.txt
5 | Tensorflow_ChatBots.egg-info/dependency_links.txt
6 | Tensorflow_ChatBots.egg-info/requires.txt
7 | Tensorflow_ChatBots.egg-info/top_level.txt
8 | tensorflow_chatbots/__init__.py
9 | tensorflow_chatbots/tsb/__init__.py
10 | tensorflow_chatbots/tsb/callback.py
11 | tensorflow_chatbots/ttb/__init__.py
12 | tensorflow_chatbots/ttb/callback.py
13 | tensorflow_chatbots/variableholder/__init__.py
14 | tensorflow_chatbots/variableholder/variableholder.py
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(
    name='Tensorflow-ChatBots',
    version='0.0.10',
    packages=['tensorflow_chatbots', 'tensorflow_chatbots.tsb', 'tensorflow_chatbots.ttb',
              'tensorflow_chatbots.variableholder'],
    url='https://github.com/Yeachan-Heo/Tensorflow-ChatBots',
    license='MIT',
    author='Yeachan-Heo',
    author_email='hurrc04@gmail.com',
    description='ChatBots supporting TensorFlow',
    # The package uses f-strings, which need Python 3.6+.
    python_requires='>=3.6',
    install_requires=[
        'tensorflow',
        'gym',
        'numpy',
        'numba',
        'matplotlib',
        'pandas',  # imported by tensorflow_chatbots/tsb/callback.py
        'slacker',
        'python-telegram-bot',
        'requests',  # imported by tensorflow_chatbots/ttb/callback.py
    ],
)
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tensorflow-ChatBots (version 0.0.10)
2 | Telegram/Slack chatbot classes that can be used as Keras custom callbacks.
3 |
4 | # How to Install
5 |
6 | ## From source (recommended)
7 | 1. `git clone https://github.com/Yeachan-Heo/Tensorflow-ChatBots.git`
8 | 2. `cd Tensorflow-ChatBots` and run `python setup.py build`
9 | 3. `python setup.py install`
10 |
11 | ## Via pip
12 | `pip install Tensorflow-ChatBots`
13 |
14 | # Telegram Examples
15 | 
16 | 
17 | 
18 | 
19 | # Slack Examples
20 | Coming soon!
21 |
--------------------------------------------------------------------------------
/tensorflow_chatbots/variableholder/variableholder.py:
--------------------------------------------------------------------------------
import ast

import numpy as np

# Module-level alias for np.nan.
nan = np.nan


class VariableHolder:
    """Mutable namespace of named values (hyper-parameters, training status, ...).

    Values are stored as plain instance attributes. String values that are
    valid Python literals (e.g. ``"'abc'"`` or ``"0.3"``) are converted to the
    literal's value, so callers that pass a quoted string get the unquoted
    string back; any other string is stored unchanged.
    """

    def __init__(self, **kwargs):
        self.add_variables(**kwargs)

    @staticmethod
    def _coerce(value):
        """Return ``value``, decoding it first if it is a string holding a literal."""
        if not isinstance(value, str):
            return value
        try:
            # literal_eval only accepts literals, so a value can never run code.
            return ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return value

    def add_variables(self, **kwargs):
        """Set (or overwrite) one attribute per keyword argument."""
        # setattr also accepts names that are not identifiers, such as the
        # "Unnamed: 0" index column pandas produces when re-reading a saved CSV.
        for name, value in kwargs.items():
            setattr(self, name, self._coerce(value))

    def set_value(self, variable_name: str, value: float):
        """Overwrite a variable that was previously added.

        Returns:
            ``(True, previous_value)`` on success. ``(False, None)`` if
            ``variable_name`` is not a stored variable, in which case nothing
            changes; methods and class attributes are never overwritten.
        """
        if variable_name not in vars(self):
            return False, None
        prev_value = getattr(self, variable_name)
        setattr(self, variable_name, self._coerce(value))
        return True, prev_value
24 |
--------------------------------------------------------------------------------
/examples/ppo/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import *
3 |
4 |
5 | # @tf.function
6 | def train_value_net(model, optimizer, states, rewards):
7 | with tf.GradientTape() as tape:
8 | values = model(states)
9 | loss = tf.losses.mean_squared_error(values, rewards)
10 | grads = tape.gradient(loss, model.trainable_variables)
11 | optimizer.apply_gradients(zip(grads, model.trainable_variables))
12 | return values
13 |
14 |
# @tf.function
def train_policy_net(model, optimizer, epsilon, states, old_actions, old_probs, advantages):
    """Take one PPO clipped-surrogate gradient step on the policy network.

    Minimizes ``-mean(min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A))``,
    where ``r = new_prob / old_prob`` and ``A`` is the advantage.

    Args:
        model: Keras model returning ``(mu, std)`` for a batch of states.
        optimizer: Keras optimizer that applies the gradients.
        epsilon: clipping range for the probability ratio.
        states: batch of states.
        old_actions: actions taken when the data was collected; reshaped to
            the shape of ``mu``.
        old_probs: densities of ``old_actions`` under the data-collecting
            policy; reshaped to the shape of the new densities.
        advantages: one advantage per state.
    """
    advantages = tf.reshape(tf.cast(advantages, tf.float32), (-1, 1))
    with tf.GradientTape() as tape:
        mu, std = model(states)
        actions = tf.reshape(tf.cast(old_actions, mu.dtype), tf.shape(mu))
        new_probs = get_prob(mu, std, actions)
        old = tf.reshape(tf.cast(old_probs, new_probs.dtype), tf.shape(new_probs))
        # Importance ratio of the current policy to the one that collected the
        # data; for multi-dimensional actions it is taken per dimension,
        # matching get_prob's elementwise density.
        ratio = new_probs / old
        unclipped = ratio * advantages
        clipped = tf.clip_by_value(ratio, 1 - epsilon, 1 + epsilon) * advantages
        # The min makes the objective a pessimistic bound, so the policy gains
        # nothing from pushing the ratio outside [1 - epsilon, 1 + epsilon].
        loss = -tf.reduce_mean(tf.minimum(unclipped, clipped))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
24 |
--------------------------------------------------------------------------------
/examples/mnist/main.py:
--------------------------------------------------------------------------------
# Fashion-MNIST classifier adapted from the tensorflow.org classification
# tutorial, trained with a TelegramBotCallback attached.

import os

from tensorflow_chatbots.ttb.callback import TelegramBotCallback
import tensorflow as tf
from tensorflow import keras

# Prefer the TELEGRAM_BOT_TOKEN environment variable so no token has to be
# written into the source; otherwise fall back to the placeholder / prompt.
token = os.environ.get("TELEGRAM_BOT_TOKEN") or "your token"
if token == "your token":
    token = input("please enter your telegram bot token")

print(tf.__version__)

fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale pixel values down by 255.
train_images = train_images / 255.0

test_images = test_images / 255.0

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

if __name__ == "__main__":
    model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=1000, callbacks=[TelegramBotCallback(token)])
34 |
--------------------------------------------------------------------------------
/tensorflow_chatbots/ttb/callback.py:
--------------------------------------------------------------------------------
1 | from tensorflow_chatbots.tsb.callback import SlackBotCallback
2 | import telegram
3 | import requests
4 | import warnings
5 | import os
6 |
7 |
class TelegramBotCallback(SlackBotCallback):
    """Keras callback that takes commands from, and reports to, a Telegram chat.

    Command parsing and dispatch are inherited from SlackBotCallback; this
    class only swaps the transport methods (receive/send message, send plot,
    send file) for Telegram Bot API calls.

    Args:
        token: Telegram bot token.
        chat_id: chat to talk to. When omitted, the chat of the newest message
            the bot has received is used, so message the bot first.
    """

    # Seconds to wait on Telegram's HTTP endpoints before giving up, so a
    # stalled upload cannot block the training loop forever.
    _REQUEST_TIMEOUT = 60

    def __init__(self, token: str, chat_id=None):
        # NOTE(review): this silences every warning in the process, not only
        # those raised by the telegram library; consider a narrower filter.
        warnings.filterwarnings("ignore")
        super().__init__(None, None)
        # Replace the Slack client created by the parent with a Telegram bot.
        self._bot = telegram.Bot(token=token)
        self._token = token
        self._chat_id = self._get_chat_id() if not chat_id else chat_id

    def _send_message(self, **kwargs):
        """Send ``title`` and ``text`` (each optional) as a single message."""
        # Some callers pass only ``text``; a missing title must not raise KeyError.
        parts = [str(kwargs[key]) for key in ("title", "text") if kwargs.get(key) is not None]
        self._bot.sendMessage(chat_id=self._chat_id, text="\n".join(parts))

    def _latest_update(self):
        """Return the newest pending update, or None when there is none."""
        # offset=-1 returns only the newest update and marks all older ones as
        # confirmed. Without an offset Telegram returns the *oldest* (up to 100)
        # unconfirmed updates, so updates[-1] stops being the newest message
        # once more than 100 have accumulated.
        updates = self._bot.getUpdates(offset=-1)
        return updates[-1] if updates else None

    def _receive_message(self):
        """Return the text of the newest message from the configured chat, or None."""
        update = self._latest_update()
        message = update.message if update is not None else None
        if message is None:
            # Updates that are not new messages carry no .message.
            return None
        # Only obey the configured chat: anyone can message a bot, and the
        # inherited command set includes /bash.
        chat = message.chat
        allowed = {str(chat.id)}
        if getattr(chat, "username", None):
            allowed.add("@" + chat.username)
        if str(self._chat_id) not in allowed:
            return None
        return message.text

    def _get_chat_id(self) -> int:
        """Return the chat id of the newest message the bot has received.

        Raises:
            RuntimeError: if the bot has no pending message to take it from.
        """
        update = self._latest_update()
        if update is None or update.message is None:
            raise RuntimeError(
                "No message found for this bot: send it a message first or pass chat_id explicitly.")
        return update.message.chat.id

    def _send_plot(self, title):
        """Upload ./plot.png as a photo captioned with ``title``, if the file exists."""
        if not os.path.exists("./plot.png"):
            return
        url = f"https://api.telegram.org/bot{self._token}/sendPhoto"
        data = {'chat_id': self._chat_id}
        if title:
            data['caption'] = title
        with open("./plot.png", "rb") as photo:
            requests.post(url, files={'photo': photo}, data=data, timeout=self._REQUEST_TIMEOUT)

    def _send_file(self, file_path=None, **kwargs):
        """Upload ``file_path`` as a document, or tell the chat why it could not be sent."""
        try:
            document = open(file_path, "rb")
        except (OSError, TypeError) as error:
            # Posting without a document would only fail on Telegram's side,
            # leaving the user without feedback.
            self._send_message(title="Could not send file", text=str(error))
            return
        url = f"https://api.telegram.org/bot{self._token}/sendDocument"
        with document:
            requests.post(url, files={'document': document}, data={'chat_id': self._chat_id},
                          timeout=self._REQUEST_TIMEOUT)
48 |
--------------------------------------------------------------------------------
/examples/ppo/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 |
class Transition(object):
    """Record of a single environment step used as PPO training data.

    Stores the state, the action taken, the reward received, the resulting
    state, the done flag and the action's probability under the policy.
    """

    def __init__(self, state, action, reward, next_state, done, prob):
        self.state = state
        self.action = action
        self.reward = reward
        self.next_state = next_state
        self.done = done
        self.prob = prob
10 |
11 |
class Timer(object):
    """Stopwatch that keeps every measured duration.

    Call ``initialize()`` to start, then ``time()`` to take a reading. Each
    reading is appended to ``hist``, and ``mean_time`` averages them.
    """

    def __init__(self):
        self.s, self.e = 0, 0  # start / last-stop timestamps (seconds)
        self.hist = []  # every duration returned by time()

    def initialize(self):
        """(Re)start the stopwatch."""
        self.s = time.time()

    def time(self):
        """Stop the clock, record the elapsed seconds, and return them."""
        self.e = time.time()
        self.hist.append(self.e - self.s)
        return self.hist[-1]

    @property
    def mean_time(self):
        """Average of all recorded durations (NaN before the first reading)."""
        return np.mean(self.hist)
30 |
31 |
def split_transitions(transitions: [Transition]):
    """Turn a list of transitions into one NumPy array per field.

    States and next states are squeezed individually and probabilities are
    squeezed after stacking; done flags become int32.

    Returns:
        ``(states, actions, rewards, next_states, dones, probs)``
    """
    states = np.array([tr.state.squeeze() for tr in transitions])
    actions = np.array([tr.action for tr in transitions])
    rewards = np.array([tr.reward for tr in transitions])
    next_states = np.array([tr.next_state.squeeze() for tr in transitions])
    done_flags = np.array([tr.done for tr in transitions])
    dones = done_flags.astype(np.int32)
    probs = np.array([tr.prob for tr in transitions]).squeeze()
    return states, actions, rewards, next_states, dones, probs
40 |
41 |
def get_prob(mu, std, action):
    """Return the Gaussian density N(action; mu, std**2), elementwise.

    Args:
        mu: mean of the distribution.
        std: standard deviation (must be non-zero).
        action: point(s) at which to evaluate the density; cast to float32.

    Returns:
        A tensor of densities, broadcast over the three inputs.
    """
    var = std ** 2
    action = tf.cast(action, tf.float32)
    # log N = -(a - mu)^2 / (2 var) - log(2 pi var) / 2, then exponentiate.
    # Use the exact value of pi rather than a truncated literal.
    log_prob = -0.5 * (action - mu) ** 2 / var - 0.5 * tf.math.log(var * 2 * np.pi)
    return tf.math.exp(log_prob)
47 |
48 |
def to_batch(transitions, batch_size):
    """Split ``transitions`` into consecutive slices of at most ``batch_size`` items.

    The final slice holds the remainder. An empty input yields an empty list.
    """
    return [transitions[start:start + batch_size]
            for start in range(0, len(transitions), batch_size)]
54 |
55 |
def get_gae(vh, rewards, masks, values):
    """Compute standardized Generalized Advantage Estimates (GAE).

    Args:
        vh: object exposing ``gamma`` (discount) and ``lamda`` (GAE lambda).
        rewards: per-step rewards, in time order.
        masks: per-step continuation masks (0 where an episode ended, else 1).
        values: per-step value estimates (NumPy array or eager tensor).

    Returns:
        Array shaped like ``rewards`` with the advantages shifted to zero mean
        and, when they are not all equal, scaled to unit standard deviation.
    """
    values = np.asarray(values)
    # Keep a float rewards dtype, but never accumulate into an integer buffer,
    # which would truncate fractional advantages.
    dtype = np.result_type(np.asarray(rewards).dtype, np.float32)
    advants = np.zeros(np.shape(rewards), dtype=dtype)

    previous_value = 0
    running_advants = 0

    for t in reversed(range(len(rewards))):
        running_tderror = rewards[t] + vh.gamma * previous_value * masks[t] - values[t]
        running_advants = running_tderror + vh.gamma * vh.lamda * running_advants * masks[t]

        previous_value = values[t]
        advants[t] = running_advants

    advants = advants - np.mean(advants)
    std = np.std(advants)
    # Identical advantages (e.g. a single-transition batch) have std 0;
    # dividing would produce NaNs that poison the policy update.
    if std > 0:
        advants = advants / std
    return advants
71 |
--------------------------------------------------------------------------------
/examples/ppo/ppo.py:
--------------------------------------------------------------------------------
1 | from examples.ppo.utils import get_prob
2 | from tensorflow_chatbots.variableholder.variableholder import VariableHolder
3 | from tensorflow_chatbots.ttb.callback import TelegramBotCallback
4 | from tensorflow.keras import *
5 | from utils import *
6 | from train import *
7 | import random
8 | import psutil
9 | import numpy as np
10 | import gym
11 | import os
12 |
13 |
class PPO(object):
    """PPO agent for continuous actions, monitored through a Telegram bot.

    Every keyword argument is stored on a VariableHolder (``self._vh``) that is
    also handed to the bot. The methods below read ``token``, ``state_size``,
    ``action_size``, ``lr_value_net``, ``lr_policy_net``, ``updates_n``,
    ``sample_size``, ``batch_size``, ``total_episodes``, ``epsilon``,
    ``lamda`` and ``gamma`` from it.
    """

    def __init__(self, **kwargs):
        self._env = None
        self._vh = VariableHolder(**kwargs)
        self._policy_net = self._build_policy_net()
        self._value_net = self._build_value_net()
        # Resume from checkpoints written by save_model() when they exist.
        if os.path.exists("./save_model/policy_net.h5"):
            self._policy_net = models.load_model("./save_model/policy_net.h5")
        if os.path.exists("./save_model/value_net.h5"):
            self._value_net = models.load_model("./save_model/value_net.h5")
        self._bot = TelegramBotCallback(token=self._vh.token)
        self._bot.set_variable_holder(self._vh)
        # Optimizer class used as a factory. _get_optimizer keeps one instance
        # per network so Adam's moment estimates survive between batches.
        self.optimizer = optimizers.Adam
        self._optimizers = {}
        self.episode_timer = Timer()
        self.train_timer = Timer()

    def set_env(self, env: "gym.Env | str"):
        """Use ``env`` directly if it is a gym.Env, otherwise create it with gym.make."""
        self._env = env if isinstance(env, gym.Env) else gym.make(env)

    def _build_policy_net(self):
        """Build the policy network: state -> (mu, std) per action dimension."""
        # NOTE(review): softplus keeps mu > 0, so negative actions can only be
        # reached through noise; confirm this suits the target environment.
        state = layers.Input(shape=(self._vh.state_size,))
        h1 = layers.Dense(32, activation="relu")(state)
        h2 = layers.Dense(16, activation="relu")(h1)
        mu = layers.Dense(self._vh.action_size, activation="softplus")(h2)
        std = layers.Dense(self._vh.action_size, activation="sigmoid")(h2)
        outputs = [mu, std]
        return models.Model(inputs=state, outputs=outputs)

    def _build_value_net(self):
        """Build the value network: state -> one scalar value estimate."""
        return models.Sequential([
            layers.Dense(32, input_shape=(self._vh.state_size,), activation="relu"),
            layers.Dense(16, activation="relu"),
            layers.Dense(1, activation="linear")
        ])

    def get_action(self, state):
        """Sample an action for one state from the policy's Gaussian.

        Returns:
            ``(action, prob)``: the sampled action and its density under the
            current policy.
        """
        # int(): values changed at runtime through the VariableHolder may be floats.
        state = state.reshape(1, int(self._vh.state_size))
        mu, std = self._policy_net.predict(state)
        action = np.random.normal(mu, std)
        prob = get_prob(mu, std, action)
        return action, prob

    def _get_optimizer(self, key, learning_rate):
        """Return the persistent optimizer for ``key``, synced to ``learning_rate``.

        The learning rate is re-applied on every call, so values changed at
        runtime (presumably via the bot's /set command) still take effect.
        """
        optimizer = self._optimizers.get(key)
        if optimizer is None:
            optimizer = self.optimizer(learning_rate)
            self._optimizers[key] = optimizer
        else:
            optimizer.learning_rate = learning_rate
        return optimizer

    def train(self, transitions):
        """Run the PPO updates once at least ``sample_size`` transitions are buffered.

        ``transitions`` is shuffled in place and cleared after training.
        """
        if len(transitions) < self._vh.sample_size:
            return
        self.train_timer.initialize()
        for _ in range(int(self._vh.updates_n)):
            random.shuffle(transitions)
            for transition_batch in to_batch(transitions, int(self._vh.batch_size)):
                states, old_actions, rewards, next_states, dones, old_probs = split_transitions(transition_batch)
                value_optimizer = self._get_optimizer("value", self._vh.lr_value_net)
                values = train_value_net(self._value_net, value_optimizer, states, rewards)
                advantages = get_gae(self._vh, rewards, (1 - dones), values)
                policy_optimizer = self._get_optimizer("policy", self._vh.lr_policy_net)
                train_policy_net(self._policy_net, policy_optimizer, self._vh.epsilon,
                                 states, old_actions, old_probs, advantages)
                self._bot.step()
        self.train_timer.time()
        transitions.clear()

    def save_model(self):
        """Write both networks to ./save_model/ (read back by __init__)."""
        os.makedirs("./save_model", exist_ok=True)
        self._policy_net.save("./save_model/policy_net.h5")
        self._value_net.save("./save_model/value_net.h5")

    def __call__(self):
        """Run ``total_episodes`` episodes, training and reporting status to the bot.

        Raises:
            RuntimeError: if set_env() has not been called.
        """
        if self._env is None:
            raise RuntimeError("No environment set; call set_env() before running PPO.")
        transitions = []
        for episode in range(int(self._vh.total_episodes)):
            self.episode_timer.initialize()
            state = self._env.reset()
            score = 0
            done = False
            probs = []
            timestep = 0
            while not done:
                self._bot.step()
                action, prob = self.get_action(state)
                next_state, reward, done, _ = self._env.step(action)
                score += reward
                probs.append(prob)
                transitions.append(Transition(state, action, reward, next_state, done, prob))
                state = next_state
                self.train(transitions)
                timestep += 1
            self.episode_timer.time()
            self._bot.add_status(
                {"episode": episode,
                 # The reward may be a scalar or a length-1 array, depending on the env.
                 "score": float(np.ravel(score)[0]),
                 "avg_prob": np.mean(probs),
                 "cpu_percent": psutil.cpu_percent(),
                 "timestep": timestep,
                 "train_time_avg": self.train_timer.mean_time,
                 # mean_time rather than time(): calling time() again would
                 # record a bogus extra sample instead of reporting the average.
                 "episode_time_avg": self.episode_timer.mean_time})
            print(f"episode:{episode}")
            self.save_model()
108 |
--------------------------------------------------------------------------------
/tensorflow_chatbots/tsb/callback.py:
--------------------------------------------------------------------------------
1 | from tensorflow_chatbots.variableholder.variableholder import VariableHolder
2 | from tensorflow.keras import *
3 | from slacker import Slacker
4 | from functools import reduce
5 | from threading import Thread
6 | import matplotlib.pyplot as plt
7 | import subprocess as sp
8 | import pandas as pd
9 | import numpy as np
10 | import os
11 | import re
12 |
13 |
14 | class SlackBotCallback(callbacks.Callback):
15 | def __init__(self, token="", channel="#general"):
16 | super().__init__()
17 | self._bot = Slacker(token)
18 | self._channel = channel
19 | self._status_list = []
20 | self.load_csv()
21 | self._variable_holder: VariableHolder or None = None
22 | self._previous_message = None
23 |
24 | def _thread_target(self):
25 | while True:
26 | try:
27 | self.step()
28 | except:
29 | pass
30 |
31 |
32 | def get_thread(self):
33 | return Thread(target=self._thread_target)
34 |
35 | def to_csv(self, path="./log.csv"):
36 | datas = self._get_plot_datas(list(self._current_status.keys()))
37 | df_dict = dict(zip(list(self._current_status.keys()), datas))
38 | df = pd.DataFrame(df_dict)
39 | df.to_csv(path)
40 |
41 | def load_csv(self, path="./log.csv"):
42 | if os.path.exists(path):
43 | df = pd.read_csv(path)
44 | data = df.to_numpy()
45 | dicts = list(map(lambda x: dict(zip(df.columns, x)), data))
46 | else:
47 | dicts = []
48 | self._status_list = dicts
49 |
50 | def set_variable_holder(self, vh: VariableHolder):
51 | self._variable_holder = vh
52 | if self._current_status:
53 | self._variable_holder.add_variables(**self._current_status)
54 |
55 | def add_status(self, status: dict):
56 | self._status_list.append(status)
57 |
58 | def step(self):
59 | try:
60 | message = self._receive_message()
61 | if not self._is_updaten(message):
62 | return
63 | if re.match("/status ", message):
64 | self._command_status(message)
65 | elif re.match("/plot ", message):
66 | self._command_plot(message)
67 | elif re.match("/set ", message):
68 | self._command_set(message)
69 | elif re.match("/get ", message):
70 | self._command_get(message)
71 | elif re.match("/bash ", message):
72 | self._command_bash(message)
73 | elif re.match("/start", message):
74 | self._command_start()
75 | elif re.match("/help", message):
76 | self._command_help()
77 | else:
78 | self._command_invalid(message)
79 | except Exception as e:
80 | self._send_message(title="An Error Occured", text=str(e))
81 |
82 | @property
83 | def _current_status(self):
84 | try:
85 | return self._status_list[-1]
86 | except:
87 | self._send_message(title="status not prepared", text="be patient")
88 |
89 | def _is_updaten(self, message):
90 | updaten = False \
91 | if (message == self._previous_message or message is None) \
92 | else True
93 | self._previous_message = message
94 | return updaten
95 |
96 | def _receive_message(self) -> str:
97 | return self._bot.conversations.history(channel=self._channel)[-1]["text"]
98 |
99 | def _send_message(self, **attachments) -> None:
100 | self._bot.chat.post_message(channel=self._channel, text=None, attachments=[attachments], as_user=True)
101 |
102 | def _send_invalid_argument_message(self, arguments):
103 | text = self._generate_invalid_argument_text(arguments)
104 | title = self._generate_invalid_argument_title()
105 | self._send_message(text=text, title=title)
106 |
107 | def _send_variable_changed_message(self, variable_name, prev_value, current_value):
108 | text = self._generate_variable_changed_text(prev_value, current_value)
109 | title = self._generate_variable_changed_title(variable_name)
110 | self._send_message(text=text, title=title)
111 |
112 | def _send_file(self, file_path=None, **kwargs):
113 | # chaneels, file, title, filetype
114 | # reference: https://talkingaboutme.tistory.com/entry/TIP-message-sending-file-uploading-with-slackclient
115 | self._bot.files.upload(channels=self._channel, file=file_path)
116 |
117 | def _generate_invalid_argument_text(self, arguments):
118 | invalid_argument_text = f"{reduce(lambda x, y: x + ' ' + y, arguments)}"
119 | return invalid_argument_text
120 |
121 | def _generate_variable_changed_text(self, prev_value, current_value):
122 | variable_changed_text = f"({round(current_value, 4)})"
123 | return variable_changed_text
124 |
125 | def _generate_invalid_argument_title(self):
126 | invalid_argument_title = f"Invalid Arguments Occured"
127 | return invalid_argument_title
128 |
129 | def _generate_variable_changed_title(self, variable_name):
130 | variable_changed_title = f"{variable_name} has changed to"
131 | return variable_changed_title
132 |
133 | def _is_valid_arguments(self, arguments):
134 | return np.array(list(map(lambda x: x in self._current_status.keys(), arguments))).all()
135 |
136 | def _command_status(self, message):
137 | if re.match("/status all", message):
138 | arguments = list(self._current_status.keys())
139 | else:
140 | arguments = message[re.match("/status ", message).end():].split(" ")
141 | self._send_status_message(arguments) \
142 | if self._is_valid_arguments(arguments) \
143 | else self._send_invalid_argument_message(arguments)
144 |
145 | def _send_status_message(self, arguments):
146 | text = self._generate_status_text(arguments)
147 | title = self._generate_status_title(arguments)
148 | self._send_message(text=text, title=title)
149 |
150 | def _generate_status_text(self, arguments):
151 | status_string = reduce(lambda x, y: x + y, map(lambda x: f"{x}:{self._current_status[x]}\n", arguments))
152 | return status_string
153 |
def _generate_status_title(self, arguments):
    """Return the title for a status reply.

    The title is "Status All" when ``arguments`` equals the keys of
    ``self._current_status`` in order. Otherwise it is "Status About "
    followed by the space-joined argument names.
    """
    if arguments == list(self._current_status.keys()):
        return "Status All"
    return "Status About " + " ".join(arguments)
159 |
def _command_plot(self, message):
    """Handle ``/plot``: draw the recorded history of the requested keys.

    ``/plot all`` plots every key in ``self._current_status``, and
    ``/plot a b`` plots the named keys in one figure. Unknown keys get the
    invalid-argument reply. A bare ``/plot``, or ``all`` while nothing is
    tracked yet, gets the generic invalid-command reply instead of raising.
    """
    # split() ignores repeated whitespace and cannot fail like a missed
    # re.match(...).end() on a bare "/plot".
    arguments = message.split()[1:]
    if arguments[:1] == ["all"]:
        arguments = list(self._current_status.keys())
    if not arguments:
        self._command_invalid(message)
    elif self._is_valid_arguments(arguments):
        self._draw_and_send_plot(
            data_list=self._get_plot_datas(arguments),
            labels=arguments,
            title=self._generate_plot_title(arguments),
        )
    else:
        self._send_invalid_argument_message(arguments)
168 |
def _send_plot(self, title):
    """Upload ``plot.png`` (written by _draw_and_save_plot) with ``title``."""
    # Pass the path as ``file_path``. The former ``file_=`` keyword was
    # swallowed by _send_file's **kwargs, so file_path stayed None.
    self._send_file(file_path="plot.png", filetype="image", title=title)
171 |
def _generate_plot_title(self, arguments):
    """Return the title for a plot reply.

    The title is "All Plots" when ``arguments`` equals the keys of
    ``self._current_status`` in order. Otherwise it is "Plots About "
    followed by the space-joined argument names.
    """
    if arguments == list(self._current_status):
        return "All Plots"
    return "Plots About " + " ".join(arguments)
177 |
def _draw_and_save_plot(self, data_list, labels):
    """Plot each series with its label, save the figure, then clear it.

    Series and labels are paired positionally; extra items in the longer
    sequence are ignored. The figure is written to ``plot.png`` in the
    current working directory.
    """
    paired = zip(data_list, labels)
    for series, series_label in paired:
        plt.plot(series, label=series_label)
    plt.legend()
    plt.savefig("plot.png")
    # Clear the current figure so the next plot starts empty.
    plt.clf()
184 |
def _draw_and_send_plot(self, data_list, labels, title):
    """Render the series to ``plot.png`` and upload it with ``title``."""
    self._draw_and_save_plot(data_list=data_list, labels=labels)
    self._send_plot(title=title)
188 |
def _get_plot_datas(self, arguments):
    """Return, for each key in ``arguments``, its value in every entry of ``self._status_list``.

    The result is a list of lists, ordered like ``arguments``.
    """
    history = self._status_list
    return [[record[key] for record in history] for key in arguments]
191 |
def _command_set(self, message):
    """Handle ``/set NAME VALUE``: update the learning rate or a held variable.

    ``/set lr VALUE`` changes the optimizer learning rate, and any other
    NAME goes to the variable holder. A wrong argument count or a VALUE that
    is not a float gets the usage reply instead of raising.
    """
    usage_text = "Invalid message. use /help to see description"
    # split() tolerates repeated spaces and a bare "/set" (no re.match None).
    arguments = message.split()[1:]
    # Compare with !=; "is not 2" tested object identity, not value.
    if len(arguments) != 2:
        self._send_message(text=usage_text)
        return
    name, raw_value = arguments
    try:
        value = float(raw_value)
    except ValueError:
        self._send_message(text=usage_text)
        return
    if name == "lr":
        self._set_lr(value)
    else:
        self._set_variable(variable_name=name, value=value)
200 |
def _set_lr(self, value):
    """Set the model optimizer's ``lr`` to ``value`` and announce old vs new."""
    optimizer = self.model.optimizer
    previous_lr = optimizer.lr.numpy()
    backend.set_value(optimizer.lr, value)
    self._send_variable_changed_message("learning rate", previous_lr, value)
205 |
def _set_variable(self, variable_name, value):
    """Store ``value`` under ``variable_name`` in the variable holder and report.

    ``self._variable_holder.set_value`` returns ``(success, previous_value)``.
    On success a variable-changed message is sent; otherwise the name is
    reported as an invalid argument.
    """
    succeeded, previous_value = self._variable_holder.set_value(variable_name, value)
    if succeeded:
        self._send_variable_changed_message(variable_name, previous_value, value)
    else:
        self._send_invalid_argument_message([variable_name])
212 |
def _command_start(self):
    """Handle the start command by sending the greeting message."""
    self._send_start_message()
215 |
def _send_start_message(self):
    """Send the welcome text under the greeting title."""
    self._send_message(
        text=self._generate_start_text(),
        title=self._generate_start_title(),
    )
220 |
def _generate_start_text(self):
    """Return the welcome text sent when training starts."""
    return "Welcome To TensorFlow Slack Bot! type /help to know about commands!"
224 |
def _generate_start_title(self):
    """Return the title of the welcome message."""
    return "Greetings!"
228 |
def _command_help(self):
    """Handle ``/help`` by sending the command reference."""
    self._send_help_message()
231 |
def _send_help_message(self):
    """Send the command reference text under the help title."""
    self._send_message(
        text=self._generate_help_text(),
        title=self._generate_help_title(),
    )
236 |
def _generate_help_text(self):
    """Return the command reference shown by ``/help``.

    Upper-case words (VALUE, NAME, PATH, COMMAND) are placeholders the user
    replaces. The old text listed ``/set``, ``/get`` and ``/bash`` with no
    arguments at all, which did not match how the handlers parse them.
    Placeholders avoid angle brackets, which chat markup may treat specially.
    """
    help_text = """
/help: shows this helpful message :D
/status:
    usage:
        /status arg1 arg2 arg3....: prints last value of arg1, arg2, arg3
        /status all: prints last value of all arguments
/plot:
    usage:
        /plot arg1 arg2 arg3....: plots all values of arg1, arg2, arg3 in one figure
        /plot all: plots all values of all arguments in one figure
/set:
    usage:
        /set lr VALUE: sets learning rate to VALUE
        /set NAME VALUE: sets variable NAME in the variable holder to VALUE
/get:
    usage:
        /get PATH1 PATH2 ...: sends each file
/bash:
    usage:
        /bash COMMAND: executes COMMAND in a shell and sends its output
if you can't get it, why don't you just try?
"""
    return help_text
261 |
def _generate_help_title(self):
    """Return the title of the help message."""
    return "Help Message"
265 |
def _command_invalid(self, message):
    """Handle an unrecognised or malformed command by echoing it back as invalid."""
    self._send_invalid_message(message)
268 |
def _send_invalid_message(self, message):
    """Send the invalid-command reply that quotes ``message``."""
    self._send_message(
        text=self._generate_invalid_text(message),
        title=self._generate_invalid_title(),
    )
273 |
def _generate_invalid_text(self, message):
    """Return the body of the invalid-command reply, quoting ``message``."""
    return "command usage {} is invalid. use command /help to see help message".format(message)
277 |
def _generate_invalid_title(self):
    """Return the title of the invalid-command reply."""
    return "Invalid Command Usage"
281 |
def _command_get(self, message):
    """Handle ``/get PATH [PATH ...]``: upload each named regular file.

    Paths that are not regular files (missing, or a directory) get the
    invalid-file reply. A bare ``/get`` gets the invalid-command reply
    instead of raising. Paths containing spaces are not supported, because
    arguments are split on whitespace.
    """
    # NOTE(review): any file readable by this process can be sent to the
    # chat by whoever can issue commands to the bot.
    file_paths = message.split()[1:]
    if not file_paths:
        self._command_invalid(message)
        return
    for file_path in file_paths:
        # isfile (not exists) so directories are reported, not "uploaded".
        if os.path.isfile(file_path):
            self._send_file(file_path)
        else:
            self._send_invalid_file_message(file_path)
288 |
def _send_invalid_file_message(self, file_path):
    """Tell the channel that ``file_path`` could not be sent."""
    self._send_message(
        text=self._generate_invalid_file_text(file_path),
        title=self._generate_invalid_file_title(),
    )
293 |
def _generate_invalid_file_text(self, file_path):
    """Return the body of the invalid-file reply for ``file_path``."""
    return "{} doesn't exists".format(file_path)
297 |
def _generate_invalid_file_title(self):
    """Return the title of the invalid-file reply."""
    return "Invalid File"
301 |
def _command_bash(self, message):
    """Handle ``/bash COMMAND``: run COMMAND and reply with its output.

    A bare ``/bash``, or one followed only by whitespace, gets the
    invalid-command reply. Previously a bare ``/bash`` raised AttributeError
    from calling ``.end()`` on a failed match.
    """
    # SECURITY NOTE(review): _generate_bash_text runs COMMAND with
    # shell=True, so whoever can issue commands to the bot can execute
    # arbitrary code on this host.
    match = re.match("/bash ", message)
    command = message[match.end():].strip() if match else ""
    if not command:
        self._command_invalid(message)
        return
    self._send_bash_message(command)
305 |
def _send_bash_message(self, command):
    """Run ``command`` (via _generate_bash_text) and send its output."""
    self._send_message(
        text=self._generate_bash_text(command),
        title=self._generate_bash_title(command),
    )
310 |
def _generate_bash_text(self, command):
    """Run ``command`` through the shell and return its decoded stdout.

    Stdout is decoded with the locale's preferred encoding, and undecodable
    bytes are replaced. Previously ``cp949`` was hard-coded, which only
    suits Korean-locale Windows and garbles or fails on other hosts. That
    locale typically still reports cp949 (TODO confirm). If the command
    cannot be started or exits non-zero, the exception text is returned.
    """
    import locale

    encoding = locale.getpreferredencoding(False)
    try:
        raw_output = sp.check_output(command, shell=True)
    except Exception as e:
        # Best-effort: report the failure to the chat instead of raising.
        return str(e)
    return raw_output.decode(encoding, errors="replace")
317 |
def _generate_bash_title(self, command):
    """Return the title naming the shell command that was run."""
    return "Executed command {}".format(command)
321 |
def on_train_begin(self, logs=None):
    """Keras hook: send the greeting, then start the thread from get_thread().

    The thread is presumably the bot's command-polling loop (TODO confirm).
    ``logs`` is ignored.
    """
    self._command_start()
    worker = self.get_thread()
    worker.start()
325 |
def on_epoch_end(self, epoch, logs=None):
    """Keras hook: record this epoch's metrics together with the epoch index.

    Args:
        epoch: index of the epoch that just finished.
        logs: metric dict for the epoch; ``None`` is treated as empty.
    """
    # Copy rather than mutate. The caller's dict may be shared with other
    # callbacks, and writing "epoch" into it would leak there. Copying also
    # avoids a TypeError when logs is None.
    status = dict(logs) if logs is not None else {}
    status["epoch"] = epoch
    self.add_status(status)
329 |
def on_train_end(self, logs=None):
    """Keras hook: send one plot per tracked key, then a full status report.

    ``logs`` is ignored. The keys are snapshotted into a list before the
    plots are issued.
    """
    tracked_keys = list(self._current_status.keys())
    for key in tracked_keys:
        self._command_plot(message="/plot {}".format(key))
    self._command_status(message="/status all")
334 |
--------------------------------------------------------------------------------