<div>. This is required to make indentation work in code blocks.
43 | val = re.sub(div_pattern, "", val)
44 | # Remove all <span>. This is required to make underscores work in code blocks.
45 | val = re.sub(span_pattern, "", val)
46 | # HTML to markdown
47 | val = markdownify.markdownify(val).strip()
48 | # Reformat code
49 | val = reformat_code(val)
50 |
51 | # Remove noisy "[number] / [number]" at the beginning
52 | noise = re.search(regenerate_pattern, val)
53 | if noise and noise.start() == 0:
54 | val = val[noise.end() :]
55 | # Remove noisy "Copy[number] chars / [number] words"
56 | val = re.sub(copy_chars_pattern, "", val)
57 | # Remove empty code block ```\nCopy code\n```
58 | val = re.sub(copy_code_pattern, "", val)
59 |
60 | # Strip
61 | val = val.replace("\n\n\n", "\n").strip()
62 |
63 | return val
64 |
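# --- Editor's sketch, not part of the original file: a minimal illustration of
# html_to_markdown, assuming div_pattern/span_pattern match the opening tags
# and markdownify drops the orphaned closing tags. ---
#
#   html = '<div class="markdown"><span>pip install</span> fschat</div>'
#   html_to_markdown(html)   # expected -> 'pip install fschat'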
65 |
66 | def contain_blocked_words(val: str) -> bool:
67 | blocked_words = ["openai", "chatgpt"]
68 | for w in blocked_words:
69 | if w in val.lower():
70 | return True
71 | return False
72 |
73 |
74 | def clean_html_one_sample(sample):
75 | roles = ["human", "gpt"]
76 |
77 | if len(sample["conversations"]) <= 1:
78 | return (sample, 1)
79 |
80 | # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4
81 | if sample["conversations"][0]["from"] != "human":
82 | sample["conversations"] = sample["conversations"][1:]
83 | if len(sample["conversations"]) <= 1:
84 | return (sample, 1)
85 |
86 | if sample["conversations"][-1]["from"] == "human":
87 | sample["conversations"] = sample["conversations"][:-1]
88 | if len(sample["conversations"]) <= 1:
89 | return (sample, 1)
90 |
91 | for i, c in enumerate(sample["conversations"]):
92 | if c["from"] != roles[i % 2]:
93 | return (sample, 2)
94 |
95 | if contain_blocked_words(c["value"]):
96 | return (sample, 3)
97 |
98 | try:
99 | new_val = html_to_markdown(c["value"])
100 | except (bs4.builder.ParserRejectedMarkup, AssertionError):
101 | return (sample, 4)
102 |
103 | c["value"] = new_val
104 |
105 | return (sample, 0)
106 |
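# Editor's note (added comment; protocol inferred from clean_html_all below):
# the second tuple element is an error code: 0 = keep, 1 = too few turns,
# 2 = roles do not alternate human/gpt, 3 = blocked words, 4 = parser error.
# A well-formed two-turn sample is kept, with its values converted to markdown:
#
#   sample = {"id": "demo", "conversations": [
#       {"from": "human", "value": "<p>hi</p>"},
#       {"from": "gpt", "value": "<p>hello</p>"},
#   ]}
#   clean_html_one_sample(sample)   # -> (sample, 0)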
107 |
108 | def clean_html_all(content, begin, end):
109 | """
110 | Clean the source html files.
111 | """
112 | cnt_skip = 0
113 | cnt_blocked_words = 0
114 | cnt_wrong_format = 0
115 | cnt_parser_error = 0
116 | cnt_too_short = 0
117 | cnt_id_duplication = 0
118 | cnt_value_duplication = 0
119 | cnt_tag = 0
120 |
121 | content = content[begin:end]
122 | processed = []
123 | with ProcessPoolExecutor() as executor:
124 | for result in tqdm(
125 | executor.map(clean_html_one_sample, content), total=len(content)
126 | ):
127 | processed.append(result)
128 |
129 | visited = {}
130 | new_content = []
131 | for sample, error_code in tqdm(processed):
132 | cid = sample["id"]
133 | skipped = True
134 |
135 | if error_code != 0:
136 | if error_code == 1:
137 | print(f"id {cid} is too short")
138 | cnt_too_short += 1
139 | elif error_code == 2:
140 | print(f"id {cid} has a wrong format")
141 | cnt_wrong_format += 1
142 | elif error_code == 3:
143 | print(f"id {cid} contains blocked words")
144 | cnt_blocked_words += 1
145 | elif error_code == 4:
146 | print(f"id {cid} contains parser errors")
147 | cnt_parser_error += 1
148 | else:
149 | raise ValueError(f"Invalid error_code: {error_code}")
150 | elif cid in visited:
151 | print(f"id {cid} is an id duplication of {visited[cid]}")
152 | cnt_id_duplication += 1
153 | elif (
154 | sample["conversations"][1]["value"],
155 | len(sample["conversations"]),
156 | ) in visited:
157 | key = (sample["conversations"][1]["value"], len(sample["conversations"]))
158 | print(f"id {cid} is a value duplication of {visited[key]}")
159 | cnt_value_duplication += 1
160 | else:
161 | key = (sample["conversations"][1]["value"], len(sample["conversations"]))
162 | visited[cid] = visited[key] = cid
163 | skipped = False
164 |
165 | if not skipped:
166 | new_content.append(sample)
167 | else:
168 | cnt_skip += 1
169 |
170 | print(
171 | f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, "
172 | f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, "
173 | f"cnt_wrong_format: {cnt_wrong_format}, "
174 | f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, "
175 | f"cnt_value_duplication: {cnt_value_duplication}, "
176 | )
177 |
178 | return new_content
179 |
180 |
181 | def main(args):
182 | content = json.load(open(args["in_file"], "r"))
183 | content = clean_html_all(content, args["begin"], args["end"])
184 | json.dump(content, open(args["out_file"], "w"), indent=2)
185 |
186 |
187 | if __name__ == "__main__":
188 | parser = argparse.ArgumentParser()
189 | parser.add_argument("--in-file", type=str, required=True)
190 | parser.add_argument("--out-file", type=str, default="sharegpt_clean.json")
191 | parser.add_argument("--begin", type=int)
192 | parser.add_argument("--end", type=int)
193 | parser.add_argument("--debug", action="store_true")
194 | args = parser.parse_args()
195 | main(vars(args))
196 |
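# Editor's usage sketch (added; the script filename is an assumption, and
# --begin/--end optionally slice the input list before cleaning):
#
#   python clean_sharegpt.py --in-file sharegpt_raw.json --out-file sharegpt_clean.json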
--------------------------------------------------------------------------------
/pandagpt/code/datasets/samplers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """batch samplers that work with either random or sequential data samplers"""
16 | import math
17 | import os
18 | import sys
19 |
20 | import torch
21 | from torch.utils import data
22 | import numpy as np
23 |
24 |
25 | class RandomSampler(data.sampler.Sampler):
26 | r"""
27 | Based on PyTorch's RandomSampler and DistributedSampler. Essentially a RandomSampler,
28 | but this class lets the user set an epoch like DistributedSampler.
29 | Samples elements randomly. If without replacement, then sample from a shuffled dataset.
30 | If with replacement, then user can specify ``num_samples`` to draw.
31 | Arguments:
32 | data_source (Dataset): dataset to sample from
33 | num_samples (int): number of samples to draw, default=len(dataset)
34 | replacement (bool): samples are drawn with replacement if ``True``, default=False
35 | """
36 |
37 | def __init__(self, data_source, replacement=False, num_samples=None):
38 | super(RandomSampler, self).__init__(data_source)
39 | self.data_source = data_source
40 | self.replacement = replacement
41 | self._num_samples = num_samples
42 | self.epoch = -1
43 |
44 | if self._num_samples is not None and replacement is False:
45 | raise ValueError("With replacement=False, num_samples should not be specified, "
46 | "since a random permute will be performed.")
47 |
48 | if not isinstance(self.num_samples, int) or self.num_samples <= 0:
49 | raise ValueError("num_samples should be a positive integer "
50 | "value, but got num_samples={}".format(self.num_samples))
51 | if not isinstance(self.replacement, bool):
52 | raise ValueError("replacement should be a boolean value, but got "
53 | "replacement={}".format(self.replacement))
54 |
55 | @property
56 | def num_samples(self):
57 | # dataset size might change at runtime
58 | if self._num_samples is None:
59 | return len(self.data_source)
60 | return self._num_samples
61 |
62 | def __iter__(self):
63 | n = len(self.data_source)
64 | g = torch.Generator()
65 | if self.epoch >= 0:
66 | g.manual_seed(self.epoch)
67 | if self.replacement:
68 | for _ in range(self.num_samples // 32):
69 | yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
70 | yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
71 | generator=g).tolist()
72 | else:
73 | yield from torch.randperm(n, generator=g).tolist()
74 |
75 | def __len__(self):
76 | return self.num_samples
77 |
78 | def set_epoch(self, epoch):
79 | self.epoch = epoch
80 |
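# Editor's usage sketch (added; `dataset` and `num_epochs` are hypothetical):
#
#   sampler = RandomSampler(dataset)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)   # seeds the generator, so every worker
#       for idx in sampler:        # draws the same permutation this epoch
#           ...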
81 |
82 | class DistributedSequentialSampler(data.sampler.Sampler):
83 | def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2):
84 | super().__init__(num_samples)
85 | if rank == -1:
86 | rank = 0
87 | world_size = 1
88 | self.num_samples = num_samples
89 | self.rank = rank
90 | self.world_size = world_size
91 | self.start_iter = 0
92 | self.train_iters = train_iters
93 | self.batch_size = batch_size
94 | self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)]
95 |
96 | def __iter__(self):
97 | for idx in range(self.start_iter, self.train_iters * 10):
98 | batch = [(idx + bias) % self.num_samples for bias in self.batch_bias]
99 | tbatch = self._batch(batch)
100 | yield tbatch
101 |
102 | def __len__(self):
103 | return self.train_iters
104 |
105 | def _batch(self, batch):
106 | """extracts samples only pertaining to this worker's batch"""
107 | start = self.rank*self.batch_size//self.world_size
108 | end = (self.rank+1)*self.batch_size//self.world_size
109 | return batch[start:end]
110 |
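# Editor's sketch of the striding above (added comment): with num_samples=8
# and batch_size=4, batch_bias is [0, 2, 4, 6], so iteration idx yields the
# batch [idx, idx+2, idx+4, idx+6] % 8; _batch then keeps only this rank's
# contiguous share of it.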
111 |
112 | class DistributedBatchSampler(data.sampler.BatchSampler):
113 | """
114 | similar to normal implementation of distributed sampler, except implementation is at the
115 | batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
116 | data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
117 | """
118 | def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None):
119 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
120 | if rank == -1:
121 | assert False, 'should not be here'
122 | self.rank = rank
123 | self.world_size = world_size
124 | self.sampler.wrap_around = 0
125 | self.wrap_around = 0
126 | self.wrap_last = wrap_last
127 | self.start_iter = 0
128 | self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps
129 |
130 | def __iter__(self):
131 | batch = []
132 | i = 0
133 | for idx in self.data_iterator(self.sampler, wrap_around=False):
134 | batch.append(idx)
135 | if len(batch) == self.batch_size:
136 | tbatch = self._batch(batch)
137 | if i >= self.start_iter * self.effective_batch_size:
138 | yield tbatch
139 | self.start_iter = 0
140 | i += len(batch)
141 | batch = []
142 | batch_len = len(batch)
143 | if batch_len > 0 and not self.drop_last:
144 | if self.wrap_last:
145 | self.sampler.wrap_around -= (self.batch_size)
146 | self.wrap_around += (len(batch))
147 | self.wrap_around %= self.batch_size
148 | yield self._batch(batch)
149 | if self.wrap_last:
150 | self.sampler.wrap_around += self.batch_size
151 |
152 | def data_iterator(self, _iter, wrap_around=False):
153 | """iterates through data and handles wrap around"""
154 | for i, idx in enumerate(_iter):
155 | if i < self.wrap_around%self.batch_size:
156 | continue
157 | if wrap_around:
158 | self.wrap_around += 1
159 | self.wrap_around %= self.batch_size
160 | yield idx
161 |
162 | def _batch(self, batch):
163 | """extracts samples only pertaining to this worker's batch"""
164 | start = self.rank*self.batch_size//self.world_size
165 | end = (self.rank+1)*self.batch_size//self.world_size
166 | return batch[start:end]
167 |
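# Editor's sketch (added): how these pieces are typically composed; the
# rank/world_size values are hypothetical:
#
#   sampler = RandomSampler(dataset)
#   batch_sampler = DistributedBatchSampler(sampler, batch_size=8,
#                                           drop_last=True, rank=0,
#                                           world_size=2)
#   loader = data.DataLoader(dataset, batch_sampler=batch_sampler)
#   # each rank then receives batch_size // world_size = 4 samples per step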
--------------------------------------------------------------------------------
/llava/serve/gradio_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | Adopted from https://github.com/gradio-app/gradio/blob/main/gradio/components.py
3 | Fix a markdown render problem.
4 | """
5 | from __future__ import annotations
6 |
7 | from gradio.components import *
8 | from markdown2 import Markdown
9 |
10 |
11 | class _Keywords(Enum):
12 | NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as an argument for `value` in `Component.update()`
13 | FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state)
14 |
15 |
16 | @document("style")
17 | class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
18 | """
19 | Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
20 | Preprocessing: this component does *not* accept input.
21 | Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
22 |
23 | Demos: chatbot_simple, chatbot_multimodal
24 | """
25 |
26 | def __init__(
27 | self,
28 | value: List[Tuple[str | None, str | None]] | Callable | None = None,
29 | color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
30 | *,
31 | label: str | None = None,
32 | every: float | None = None,
33 | show_label: bool = True,
34 | visible: bool = True,
35 | elem_id: str | None = None,
36 | elem_classes: List[str] | str | None = None,
37 | **kwargs,
38 | ):
39 | """
40 | Parameters:
41 | value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
42 | label: component name in interface.
43 | every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
44 | show_label: if True, will display label.
45 | visible: If False, component will be hidden.
46 | elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
47 | elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
48 | """
49 | if color_map is not None:
50 | warnings.warn(
51 | "The 'color_map' parameter has been deprecated.",
52 | )
53 | #self.md = utils.get_markdown_parser()
54 | self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
55 | self.select: EventListenerMethod
56 | """
57 | Event listener for when the user selects message from Chatbot.
58 | Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index.
59 | See EventData documentation on how to use this event data.
60 | """
61 |
62 | IOComponent.__init__(
63 | self,
64 | label=label,
65 | every=every,
66 | show_label=show_label,
67 | visible=visible,
68 | elem_id=elem_id,
69 | elem_classes=elem_classes,
70 | value=value,
71 | **kwargs,
72 | )
73 |
74 | def get_config(self):
75 | return {
76 | "value": self.value,
77 | "selectable": self.selectable,
78 | **IOComponent.get_config(self),
79 | }
80 |
81 | @staticmethod
82 | def update(
83 | value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
84 | label: str | None = None,
85 | show_label: bool | None = None,
86 | visible: bool | None = None,
87 | ):
88 | updated_config = {
89 | "label": label,
90 | "show_label": show_label,
91 | "visible": visible,
92 | "value": value,
93 | "__type__": "update",
94 | }
95 | return updated_config
96 |
97 | def _process_chat_messages(
98 | self, chat_message: str | Tuple | List | Dict | None
99 | ) -> str | Dict | None:
100 | if chat_message is None:
101 | return None
102 | elif isinstance(chat_message, (tuple, list)):
103 | mime_type = processing_utils.get_mimetype(chat_message[0])
104 | return {
105 | "name": chat_message[0],
106 | "mime_type": mime_type,
107 | "alt_text": chat_message[1] if len(chat_message) > 1 else None,
108 | "data": None, # These last two fields are filled in by the frontend
109 | "is_file": True,
110 | }
111 | elif isinstance(
112 | chat_message, dict
113 | ): # This happens for previously processed messages
114 | return chat_message
115 | elif isinstance(chat_message, str):
116 | #return self.md.render(chat_message)
117 | return str(self.md.convert(chat_message))
118 | else:
119 | raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
120 |
121 | def postprocess(
122 | self,
123 | y: List[
124 | Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
125 | ],
126 | ) -> List[Tuple[str | Dict | None, str | Dict | None]]:
127 | """
128 | Parameters:
129 | y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
130 | Returns:
131 | List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
132 | """
133 | if y is None:
134 | return []
135 | processed_messages = []
136 | for message_pair in y:
137 | assert isinstance(
138 | message_pair, (tuple, list)
139 | ), f"Expected a list of lists or list of tuples. Received: {message_pair}"
140 | assert (
141 | len(message_pair) == 2
142 | ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
143 | processed_messages.append(
144 | (
145 | #self._process_chat_messages(message_pair[0]),
146 |                     '<pre style="font-family: var(--font)">' +
147 |                     message_pair[0] + "</pre>",
148 | self._process_chat_messages(message_pair[1]),
149 | )
150 | )
151 | return processed_messages
152 |
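# Editor's sketch (added comment; the filenames are hypothetical): a
# (text, media) pair such as ("hello", ("cat.png", "a cat")) postprocesses to
#   ('<pre style="font-family: var(--font)">hello</pre>',
#    {"name": "cat.png", "mime_type": "image/png", "alt_text": "a cat",
#     "data": None, "is_file": True})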
153 | def style(self, height: int | None = None, **kwargs):
154 | """
155 | This method can be used to change the appearance of the Chatbot component.
156 | """
157 | if height is not None:
158 | self._style["height"] = height
159 | if kwargs.get("color_map") is not None:
160 | warnings.warn("The 'color_map' parameter has been deprecated.")
161 |
162 | Component.style(
163 | self,
164 | **kwargs,
165 | )
166 | return self
167 |
168 |
169 |
--------------------------------------------------------------------------------
/llava/eval/webpage/index.html:
--------------------------------------------------------------------------------
[HTML body unrecoverable: the markup was stripped during extraction, leaving only fragments. Surviving content: the page title and heading "Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots", logo image placeholders for the compared chatbots, and the footer note "This website is co-authored with GPT-4."]
--------------------------------------------------------------------------------
/pandagpt/code/web_demo.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoTokenizer
2 | from copy import deepcopy
3 | import os
4 | import ipdb
5 | import gradio as gr
6 | import mdtex2html
7 | from model.openllama import OpenLLAMAPEFTModel
8 | import torch
9 | import json
10 |
11 | # init the model
12 | args = {
13 | 'model': 'openllama_peft',
14 | 'imagebind_ckpt_path': '../pretrained_ckpt/imagebind_ckpt',
15 | 'vicuna_ckpt_path': '../pretrained_ckpt/vicuna_ckpt/13b_v0',
16 | 'delta_ckpt_path': '../pretrained_ckpt/pandagpt_ckpt/13b/pytorch_model.pt',
17 | 'stage': 2,
18 | 'max_tgt_len': 128,
19 | 'lora_r': 32,
20 | 'lora_alpha': 32,
21 | 'lora_dropout': 0.1,
22 | }
23 | model = OpenLLAMAPEFTModel(**args)
24 | delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
25 | model.load_state_dict(delta_ckpt, strict=False)
26 | model = model.eval().half().cuda()
27 | print('[!] initialized the 13b model')
28 |
29 | """Override Chatbot.postprocess"""
30 |
31 |
32 | def postprocess(self, y):
33 | if y is None:
34 | return []
35 | for i, (message, response) in enumerate(y):
36 | y[i] = (
37 | None if message is None else mdtex2html.convert((message)),
38 | None if response is None else mdtex2html.convert(response),
39 | )
40 | return y
41 |
42 |
43 | gr.Chatbot.postprocess = postprocess
44 |
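# Editor's note (added comment): this assignment monkey-patches the class, so
# the mdtex2html-based postprocess applies to every gr.Chatbot created in this
# process, including the one built in the Blocks layout below.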
45 |
46 | def parse_text(text):
47 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
48 | lines = text.split("\n")
49 | lines = [line for line in lines if line != ""]
50 | count = 0
51 | for i, line in enumerate(lines):
52 | if "```" in line:
53 | count += 1
54 | items = line.split('`')
55 | if count % 2 == 1:
56 | lines[i] = f'<pre><code class="language-{items[-1]}">'
57 | else:
58 | lines[i] = f'<br></code></pre>'
59 | else:
60 | if i > 0:
61 | if count % 2 == 1:
62 | line = line.replace("`", "\`")
63 | line = line.replace("<", "&lt;")
64 | line = line.replace(">", "&gt;")
65 | line = line.replace(" ", "&nbsp;")
66 | line = line.replace("*", "&ast;")
67 | line = line.replace("_", "&lowbar;")
68 | line = line.replace("-", "&#45;")
69 | line = line.replace(".", "&#46;")
70 | line = line.replace("!", "&#33;")
71 | line = line.replace("(", "&#40;")
72 | line = line.replace(")", "&#41;")
73 | line = line.replace("$", "&#36;")
74 | lines[i] = "<br>"+line
75 | text = "".join(lines)
76 | return text
77 |
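# Editor's example (added; traced by hand through the loop above):
#
#   parse_text('```python\nprint(1)\n```')
#   # -> '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'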
78 |
79 | def re_predict(
80 | input,
81 | image_path,
82 | audio_path,
83 | video_path,
84 | thermal_path,
85 | chatbot,
86 | max_length,
87 | top_p,
88 | temperature,
89 | history,
90 | modality_cache,
91 | ):
92 | # drop the latest query and answer, then generate again
93 | q, a = history.pop()
94 | chatbot.pop()
95 | return predict(q, image_path, audio_path, video_path, thermal_path, chatbot, max_length, top_p, temperature, history, modality_cache)
96 |
97 |
98 | def predict(
99 | input,
100 | image_path,
101 | audio_path,
102 | video_path,
103 | thermal_path,
104 | chatbot,
105 | max_length,
106 | top_p,
107 | temperature,
108 | history,
109 | modality_cache,
110 | ):
111 | if image_path is None and audio_path is None and video_path is None and thermal_path is None:
112 | return [(input, "There is no input data provided! Please upload your data and start the conversation.")]
113 | else:
114 | print(f'[!] image path: {image_path}\n[!] audio path: {audio_path}\n[!] video path: {video_path}\n[!] thermal path: {thermal_path}')
115 |
116 | # prepare the prompt
117 | prompt_text = ''
118 | for idx, (q, a) in enumerate(history):
119 | if idx == 0:
120 | prompt_text += f'{q}\n### Assistant: {a}\n###'
121 | else:
122 | prompt_text += f' Human: {q}\n### Assistant: {a}\n###'
123 | if len(history) == 0:
124 | prompt_text += f'{input}'
125 | else:
126 | prompt_text += f' Human: {input}'
127 |
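# Editor's note (added comment): with history = [("what is this?", "a dog.")]
# and input = "what breed?", the loop above assembles
#   'what is this?\n### Assistant: a dog.\n### Human: what breed?'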
128 | response = model.generate({
129 | 'prompt': prompt_text,
130 | 'image_paths': [image_path] if image_path else [],
131 | 'audio_paths': [audio_path] if audio_path else [],
132 | 'video_paths': [video_path] if video_path else [],
133 | 'thermal_paths': [thermal_path] if thermal_path else [],
134 | 'top_p': top_p,
135 | 'temperature': temperature,
136 | 'max_tgt_len': max_length,
137 | 'modality_embeds': modality_cache
138 | })
139 | chatbot.append((parse_text(input), parse_text(response)))
140 | history.append((input, response))
141 | return chatbot, history, modality_cache
142 |
143 |
144 | def reset_user_input():
145 | return gr.update(value='')
146 |
147 | def reset_dialog():
148 | return [], []
149 |
150 | def reset_state():
151 | return None, None, None, None, [], [], []
152 |
153 |
154 | with gr.Blocks(scale=4) as demo:
155 | gr.HTML("""<h1 align="center">PandaGPT</h1>""")
156 |
157 | with gr.Row(scale=4):
158 | with gr.Column(scale=1):
159 | image_path = gr.Image(type="filepath", label="Image", value=None)
160 | with gr.Column(scale=1):
161 | audio_path = gr.Audio(type="filepath", label="Audio", value=None)
162 | with gr.Column(scale=1):
163 | video_path = gr.Video(type='file', label="Video")
164 | with gr.Column(scale=1):
165 | thermal_path = gr.Image(type="filepath", label="Thermal Image", value=None)
166 |
167 | chatbot = gr.Chatbot().style(height=300)
168 | with gr.Row():
169 | with gr.Column(scale=4):
170 | with gr.Column(scale=12):
171 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
172 | with gr.Column(min_width=32, scale=1):
173 | with gr.Row(scale=1):
174 | submitBtn = gr.Button("Submit", variant="primary")
175 | with gr.Row(scale=1):
176 | resubmitBtn = gr.Button("Resubmit", variant="primary")
177 | with gr.Column(scale=1):
178 | emptyBtn = gr.Button("Clear History")
179 | max_length = gr.Slider(0, 400, value=256, step=1.0, label="Maximum length", interactive=True)
180 | top_p = gr.Slider(0, 1, value=0.01, step=0.01, label="Top P", interactive=True)
181 | temperature = gr.Slider(0, 1, value=1.0, step=0.01, label="Temperature", interactive=True)
182 |
183 | history = gr.State([])
184 | modality_cache = gr.State([])
185 |
186 | submitBtn.click(
187 | predict, [
188 | user_input,
189 | image_path,
190 | audio_path,
191 | video_path,
192 | thermal_path,
193 | chatbot,
194 | max_length,
195 | top_p,
196 | temperature,
197 | history,
198 | modality_cache,
199 | ], [
200 | chatbot,
201 | history,
202 | modality_cache
203 | ],
204 | show_progress=True
205 | )
206 |
207 | resubmitBtn.click(
208 | re_predict, [
209 | user_input,
210 | image_path,
211 | audio_path,
212 | video_path,
213 | thermal_path,
214 | chatbot,
215 | max_length,
216 | top_p,
217 | temperature,
218 | history,
219 | modality_cache,
220 | ], [
221 | chatbot,
222 | history,
223 | modality_cache
224 | ],
225 | show_progress=True
226 | )
227 |
228 |
229 | submitBtn.click(reset_user_input, [], [user_input])
230 | emptyBtn.click(reset_state, outputs=[
231 | image_path,
232 | audio_path,
233 | video_path,
234 | thermal_path,
235 | chatbot,
236 | history,
237 | modality_cache
238 | ], show_progress=True)
239 |
240 | demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=24000)
241 |
--------------------------------------------------------------------------------