├── src └── qshaptools │ ├── __init__.py │ ├── __version__.py │ ├── values.py │ ├── cshap.py │ ├── postprocessing.py │ ├── qshap.py │ ├── tools.py │ ├── qvalues.py │ └── ushap.py ├── _static ├── output.png └── qshap.png ├── LICENSE ├── .gitignore └── README.rst /src/qshaptools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/qshaptools/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2" 2 | -------------------------------------------------------------------------------- /_static/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RaoulHeese/qshaptools/HEAD/_static/output.png -------------------------------------------------------------------------------- /_static/qshap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RaoulHeese/qshaptools/HEAD/_static/qshap.png -------------------------------------------------------------------------------- /src/qshaptools/values.py: -------------------------------------------------------------------------------- 1 | def value_dummy(S, const=1, **kwargs): 2 | return const 3 | 4 | 5 | def value_fun_batch_wrapper_base(S_list, wrapped_value_fun, **kwargs): 6 | # for shapley_iteration_batch 7 | value_list = [] 8 | for S in S_list: 9 | value_list.append(wrapped_value_fun(S, **kwargs)) 10 | return value_list 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Raoul Heese 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/qshaptools/cshap.py: -------------------------------------------------------------------------------- 1 | from ushap import ShapleyValues 2 | 3 | 4 | class ClassicalShapleyValues(ShapleyValues): 5 | def __init__(self, N, value_fun, value_kwargs_dict=None, shap_sample_frac=None, shap_sample_reps=1, 6 | evaluate_value_only_once=False, sample_in_memory=True, shap_sample_seed=None, shap_batch_size=None, 7 | memory=None, 8 | callback=None, delta_exponent=1, name=None, silent=False): 9 | # process options: N players, none of them locked 10 | self._N = int(N) 11 | locked_instructions = [] 12 | unlocked_instructions = list(range(self._N)) 13 | 14 | # initialize (value_kwargs_dict=None becomes a fresh {} per instance instead of a shared mutable default) 15 | super().__init__(unlocked_instructions, locked_instructions, value_fun, {} if value_kwargs_dict is None else value_kwargs_dict, shap_sample_frac, 16 | shap_sample_reps, shap_batch_size, evaluate_value_only_once, sample_in_memory, 17 | shap_sample_seed, memory, 18 | callback, delta_exponent, name, silent) 19 | 20 | def get_summary_dict(self,
property_list=None): 21 | def get_attr(name): 22 | return getattr(self, name) if hasattr(self, name) else None 23 | # property_list=None becomes a fresh [] per call instead of a shared mutable default 24 | summary = super().get_summary_dict([] if property_list is None else property_list) 25 | summary.update({'N': get_attr('_N')}) 26 | return summary 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # pycharm 132 | .idea/ 133 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ****************************** 2 | Quantum Shapley Value Toolbox 3 | ****************************** 4 | 5 | .. image:: https://img.shields.io/badge/license-MIT-lightgrey 6 | :target: https://github.com/RaoulHeese/qshaptools/blob/master/LICENSE 7 | :alt: MIT License 8 | 9 | .. image:: https://github.com/RaoulHeese/qshaptools/blob/master/_static/qshap.png?raw=true 10 | :alt: Title 11 | 12 | Experimental Python toolbox for Shapley values with uncertain value functions in general (see `Shapley Values with Uncertain Value Functions `_) and quantum Shapley values in particular (see `Explaining Quantum Circuits with Shapley Values: Towards Explainable Quantum Machine Learning <https://arxiv.org/abs/2301.09138>`_). Quantum Shapley values provide a method to measure the influence of gates within a quantum circuit with respect to a freely customizable value function, for example expressibility or entanglement capability.
13 | 14 | 15 | **Usage** 16 | 17 | For quantum Shapley values, the toolbox presumes a representation of quantum circuits via Qiskit. 18 | 19 | Minimal working example: 20 | 21 | .. code-block:: python 22 | 23 | from qiskit import Aer 24 | from qiskit.utils import QuantumInstance 25 | from qiskit.circuit.library import QAOAAnsatz 26 | from qiskit.opflow import PauliSumOp 27 | from qshap import QuantumShapleyValues 28 | from qvalues import value_H 29 | from tools import visualize_shapleys 30 | 31 | # define circuit 32 | H = PauliSumOp.from_list([('ZZI', 1), ('ZII', 2), ('ZIZ', -3)]) 33 | qc = QAOAAnsatz(cost_operator=H, reps=1) 34 | qc = qc.decompose().decompose().decompose() 35 | qc = qc.assign_parameters([0]*len(qc.parameters)) 36 | 37 | # define quantum instance 38 | quantum_instance = QuantumInstance(backend=Aer.get_backend('statevector_simulator')) 39 | 40 | # setup quantum Shapley values 41 | qsv = QuantumShapleyValues(qc, value_fun=value_H, value_kwargs_dict=dict(H=H), quantum_instance=quantum_instance) 42 | print(qsv) 43 | 44 | # evaluate quantum Shapley values 45 | qsv() 46 | 47 | # show results 48 | print(qsv.phi_dict) 49 | visualize_shapleys(qc, phi_dict=qsv.phi_dict).draw() 50 | 51 | As a result, the quantum Shapley values assigned to each gate are plotted: 52 | 53 | .. image:: https://github.com/RaoulHeese/qshaptools/blob/master/_static/output.png?raw=true 54 | :alt: Output 55 | 56 | 57 | 📖 **Citation** 58 | 59 | If you find this code useful in your research, please consider citing `Explaining Quantum Circuits with Shapley Values: Towards Explainable Quantum Machine Learning `_: 60 | 61 | .. 
code-block:: tex 62 | 63 | @misc{heese2023explaining, 64 | title={Explaining Quantum Circuits with Shapley Values: Towards Explainable Quantum Machine Learning}, 65 | author={Raoul Heese and Thore Gerlach and Sascha Mücke and Sabine Müller and Matthias Jakobs and Nico Piatkowski}, 66 | year={2023}, 67 | eprint={2301.09138}, 68 | archivePrefix={arXiv}, 69 | primaryClass={quant-ph} 70 | } 71 | 72 | *This project is currently not under development and is not actively maintained.* 73 | -------------------------------------------------------------------------------- /src/qshaptools/postprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from tools import powerset 4 | from ushap import delta_phi_calculation, w_calculation, d_calculation 5 | 6 | 7 | def shapley_value_from_memory_extended(unlocked_instructions, memory, K, 8 | memory_val_fun=lambda memory, K, key, i: np.mean([x[1] for x in memory[key]]), 9 | memory_count_fun=lambda memory, K, key, i: np.sum( 10 | [1 for x in memory[key] if x[0] == i or x[0] is None]), 11 | delta_exponent=1, desc='shap_val'): 12 | total = len(unlocked_instructions) * 2 ** (len(unlocked_instructions) - 1) 13 | phi_dict = dict() 14 | with tqdm(desc=desc, total=total) as prog: 15 | for i in unlocked_instructions: 16 | F = unlocked_instructions 17 | Fi = F.copy() 18 | Fi.pop(F.index(i)) 19 | P, _ = powerset(Fi) 20 | if K is not None: 21 | n = 0 22 | else: 23 | n = 1 24 | phi = 0 25 | for k, S in enumerate(P): 26 | S = sorted(list(S)) 27 | Si = sorted(S + [i]) 28 | key_v = tuple(S) 29 | key_vi = tuple(Si) 30 | if key_v in memory and key_vi in memory: 31 | if K is not None: 32 | count_i = memory_count_fun(memory, K, key_v, i) 33 | count_ii = memory_count_fun(memory, K, key_vi, i) 34 | assert count_i == count_ii 35 | # 36 | if count_i > 0 and count_ii > 0: 37 | v = np.mean(memory_val_fun(memory, K, key_v, i)) 38 | vi = np.mean(memory_val_fun(memory, K, 
key_vi, i)) 39 | # 40 | n_count = count_i // K 41 | assert float(n_count) == count_i / K 42 | n += n_count 43 | # 44 | phi += n_count * delta_phi_calculation(S, F, v, vi, use_weight=False, 45 | delta_exponent=delta_exponent) 46 | else: 47 | v = memory_val_fun(memory, K, key_v, i) 48 | vi = memory_val_fun(memory, K, key_vi, i) 49 | phi += delta_phi_calculation(S, F, v, vi, use_weight=True, delta_exponent=delta_exponent) 50 | prog.update(1) 51 | phi /= n 52 | phi_dict[i] = phi 53 | return phi_dict 54 | 55 | 56 | def shapley_value_from_memory(unlocked_instructions, memory, K): 57 | return shapley_value_from_memory_extended(unlocked_instructions, memory, K) 58 | 59 | # For every unlocked player i, accumulate the weights w_calculation(S, F) per value difference d_calculation(v, vi) over memorized coalition pairs (S, S + [i]). 60 | def shapley_p_from_memory(all_instructions, locked_instructions, memory, verify=False, verify_epsilon=1e-9, 61 | include_locked_instructions_in_key=True): 62 | assert all((idx in all_instructions for idx in locked_instructions)) 63 | assert type(memory) is dict 64 | unlocked_instructions = [idx for idx in all_instructions if idx not in locked_instructions] 65 | total = len(unlocked_instructions) * 2 ** (len(unlocked_instructions) - 1) 66 | # memory keys are sorted instruction tuples; locked instructions join the key only if include_locked_instructions_in_key 67 | p_dict = dict() 68 | if verify: 69 | w_sum = dict() 70 | with tqdm(desc='shap_p', total=total) as prog: 71 | for i in unlocked_instructions: 72 | p_dict[i] = {} 73 | F = unlocked_instructions 74 | Fi = F.copy() 75 | Fi.pop(F.index(i)) 76 | P, P_length = powerset(Fi) 77 | for k, S in enumerate(P): 78 | S = list(S) 79 | Si = S + [i] 80 | Sl = sorted(S + (locked_instructions if include_locked_instructions_in_key else [])) 81 | Sil = sorted(Si + (locked_instructions if include_locked_instructions_in_key else [])) 82 | key_v = tuple(Sl) 83 | key_vi = tuple(Sil) 84 | if key_v in memory and key_vi in memory: 85 | v = np.mean([entry[1] for entry in memory[key_v]]) 86 | vi = np.mean([entry[1] for entry in memory[key_vi]]) 87 | w = w_calculation(S, F) 88 | d = d_calculation(v, vi, delta_exponent=1) 89 | if d not in p_dict[i]: 90 | p_dict[i][d] = 0 91 | p_dict[i][d] += w 92 | if
verify: 93 | if i not in w_sum: 94 | w_sum[i] = 0 95 | w_sum[i] += w 96 | prog.update(1) 97 | if verify: 98 | assert all([np.abs(w - 1) <= np.abs(verify_epsilon) for w in w_sum.values()]), f'{w_sum}' 99 | return p_dict 100 | -------------------------------------------------------------------------------- /src/qshaptools/qshap.py: -------------------------------------------------------------------------------- 1 | from tools import extract_from_circuit 2 | from ushap import ShapleyValues 3 | 4 | 5 | class QuantumShapleyValues(ShapleyValues): 6 | def __init__(self, qc, value_fun, value_kwargs_dict, quantum_instance, shap_sample_frac=None, shap_sample_reps=1, 7 | evaluate_value_only_once=False, sample_in_memory=True, shap_sample_seed=None, shap_batch_size=None, 8 | qc_preprocessing_fun=None, 9 | locked_instructions=None, memory=None, callback=None, delta_exponent=1, name=None, silent=False): 10 | """ 11 | Parameters 12 | ---------- 13 | qc : qiskit.circuit.QuantumCircuit 14 | Quantum circuit of interest. Per default, all gates of the circuit are used as players if not specified as locked_instructions. 15 | value_fun : callable 16 | Value function, i.e., function to map circuits to floats: (qc_data, num_qubits, S, quantum_instance, **kwargs) -> (value) 17 | value_kwargs_dict : dict 18 | Dictionary to be passed to value_fun for every evaluation. 19 | quantum_instance : qiskit.utils.QuantumInstance 20 | Quantum instance to be passed to value_fun for every evaluation. 21 | shap_sample_frac : float (>0) or None or int (<0) (default: None) 22 | Fraction of coalitions that is considered for each player: 23 | a) None: no subsampling. Consider all 2**N coalitions. 24 | b) positive float: subsampling, sample shap_sample_frac*100% of all 2**(N-1) possible coalitions (can be > 1). 25 | c) negative integer: subsampling, sample abs(shap_sample_frac) of all 2**(N-1) possible coalitions (can be > 2**N). 
26 | shap_sample_reps : int or None (default: 1) 27 | Number of repeated evaluations for each value function. For each considered coalition, 2*shap_sample_reps value functions are calculated. The mean of all value functions for the same coalition is used to determine the Shapley values. 28 | evaluate_value_only_once : bool (default: False) 29 | If true, evaluate every value function only once and recall from memory afterward. Otherwise, each value function may be evaluated multiple times. Is required to be false for shap_sample_reps>1 to have any effect. 30 | sample_in_memory : bool (default: True) 31 | If true, store all permutations in memory when sampling value functions (with shap_sample_reps < 1). Otherwise, an on-demand strategy is used that works better for larger circuits. 32 | shap_sample_seed : int or None, optional (default: None) 33 | Random seed for numpy.random.RandomState, used for subsampling. 34 | shap_batch_size : int or None, optional (default: None) 35 | If shap_batch_size is not None, multiple coalitions (up to shap_batch_size) are evaluated at once. This can be useful to submit multiple circuits to a backend at once. 36 | Accordingly, value_fun is expected to be of the form: (qc_data, num_qubits, S_list, quantum_instance, **kwargs) -> (value_list) 37 | qc_preprocessing_fun : callable, optional (default: None) 38 | Preprocessing function for the quantum circuit: (qc) -> (qc), is ignored if None. 39 | locked_instructions : list, optional (default: None) 40 | Gate indices of the circuit that are always activated and do not act as players for Shapley values. None corresponds to [], i.e., all gates are players. 41 | memory : dict, optional (default: None) 42 | Dictionary of previous calculations to recall values from (see memory property). Is also used to store new calculations. None corresponds to a fresh/blank memory.
43 | callback : callable, optional (default: None) 44 | Function to be called before every value function evaluation: (S) -> None, is ignored if None. 45 | delta_exponent : int, optional (default: 1) 46 | Can be used to calculate higher moments, see d_calculation. For Shapley values, use 1. 47 | name : str, optional (default: None) 48 | Only for displaying purposes. Defaults to a standard name for None. 49 | silent : bool, optional (default: False) 50 | If True, hide progress bars. 51 | """ 52 | 53 | # preprocess the circuit (optional), then split its instructions into locked and unlocked (player) indices 54 | self._qc_preprocessing_fun = qc_preprocessing_fun 55 | if self._qc_preprocessing_fun is not None: 56 | qc = self._qc_preprocessing_fun(qc) 57 | self._num_qubits, self._qc_data = extract_from_circuit(qc, locked_instructions) 58 | unlocked_instructions = [idx for idx, (instr, qargs, cargs, opts) in enumerate(self._qc_data) if 59 | not opts['lock']] 60 | 61 | # setup value function kwargs; entries of value_kwargs_dict override the defaults built here 62 | self._quantum_instance = quantum_instance 63 | effective_value_kwargs_dict = {'qc_data': self._qc_data, 'num_qubits': self._num_qubits, 64 | 'quantum_instance': self._quantum_instance} 65 | effective_value_kwargs_dict.update(value_kwargs_dict) 66 | 67 | # initialize 68 | super().__init__(unlocked_instructions, locked_instructions, value_fun, effective_value_kwargs_dict, 69 | shap_sample_frac, shap_sample_reps, shap_batch_size, evaluate_value_only_once, 70 | sample_in_memory, 71 | shap_sample_seed, memory, callback, delta_exponent, name, silent) 72 | 73 | def run(self): 74 | """ 75 | Evaluate Shapley values. 76 | 77 | Returns 78 | ------- 79 | phi_dict : dict 80 | Dictionary of Shapley values of the form {player index: value, ...}. 81 | Result is also stored in phi_dict property. 82 | 83 | """ 84 | return self() 85 | 86 | def get_values(self, S_list, recall=False): 87 | """ 88 | Evaluate value functions. 89 | 90 | Parameters 91 | ---------- 92 | S_list : list 93 | List of coalitions to evaluate (i.e., a list of lists).
94 | recall : bool, optional (default: False) 95 | If true, recall value function from memory. 96 | 97 | Returns 98 | ------- 99 | values : list 100 | List of values, one float for every coalition. 101 | 102 | """ 103 | return self.eval_S_list(S_list, recall) 104 | 105 | def disp(self): 106 | """ 107 | Print settings. 108 | 109 | Returns 110 | ------- 111 | None. 112 | 113 | """ 114 | print(self.__str__()) 115 | 116 | def get_summary_dict(self, property_list=None): 117 | """ 118 | Return a summary of the most important properties in the form of a dictionary. 119 | 120 | Parameters 121 | ---------- 122 | property_list : list, optional (default: None) 123 | List of property names to additionally include in the summary. None is treated as []. 124 | 125 | Returns 126 | ------- 127 | summary : dict 128 | Dictionary containing selected properties. 129 | 130 | """ 131 | 132 | if property_list is None: 133 | property_list = [] 134 | 135 | def get_attr(name_): 136 | return getattr(self, name_) if hasattr(self, name_) else None 137 | 138 | summary = super().get_summary_dict(property_list) 139 | summary.update({'quantum_instance': get_attr('_quantum_instance'), 140 | # 'qc': get_attr('_qc'), 141 | 'qc_preprocessing_fun': get_attr('_qc_preprocessing_fun'), 142 | # 'qc_data': get_attr('_qc_data'), 143 | 'num_qubits': get_attr('_num_qubits')}) 144 | return summary 145 | -------------------------------------------------------------------------------- /src/qshaptools/tools.py: -------------------------------------------------------------------------------- 1 | import math 2 | from itertools import chain, combinations, product 3 | 4 | import numpy as np 5 | from qiskit.circuit import ParameterVector, QuantumCircuit 6 | 7 | 8 | def p_coalition(coalition_len, total_len): 9 | p = math.factorial(coalition_len) * math.factorial(total_len - coalition_len - 1) / math.factorial(total_len) 10 | return p 11 | 12 | 13 | def powerset_length(len_iterable): 14 | return 2 ** len_iterable 15 | 16 | 17 | def
powerset(iterable): 18 | s = list(iterable) 19 | P_length = powerset_length(len(s)) 20 | P = chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) 21 | return P, P_length 22 | 23 | 24 | def get_branch_proba(condition): 25 | def p_coalition_bin(b, total_len): 26 | coalition_len = sum(b) 27 | p = math.factorial(coalition_len) * math.factorial(total_len - coalition_len - 1) / math.factorial(total_len) 28 | return p 29 | 30 | def powerset_bin_cond(N, condition={}): 31 | if N - len(condition) < 0: 32 | return (list(condition.values()),) 33 | P = product([0, 1], repeat=N - len(condition)) 34 | idx_map = {i: i if i not in condition else condition[i] for i in range(N)} 35 | 36 | def m(x): 37 | x_ = [0 for _ in range(N)] 38 | n_ = 0 39 | for n in range(N): 40 | if n not in condition: 41 | x_[n] = x[idx_map[n_]] 42 | n_ += 1 43 | else: 44 | x_[n] = condition[n] 45 | return x_ 46 | 47 | return (m(x) for x in P) 48 | 49 | N = len(condition) + 1 50 | return sum([p_coalition_bin(b, N) for b in powerset_bin_cond(N - 1, condition)]) 51 | 52 | 53 | def sample_binary(rng, N): 54 | b = [] 55 | for _ in range(N): 56 | p = rng.rand() 57 | c0 = {i: v for i, v in enumerate(b)} 58 | c0[len(b)] = 0 59 | c1 = {i: v for i, v in enumerate(b)} 60 | c1[len(b)] = 1 61 | p0 = get_branch_proba(c0) 62 | p1 = get_branch_proba(c1) 63 | p0 /= p0 + p1 64 | if p <= p0: 65 | bp = 0 66 | else: 67 | bp = 1 68 | b.append(bp) 69 | return b 70 | 71 | # Rebuild a circuit from qc_data, keeping only instructions whose index is in S (all of them if S is None). 72 | def build_circuit(qc_data, num_qubits, S=None, cl_bits=True): 73 | qc = QuantumCircuit(num_qubits, num_qubits if cl_bits else 0) 74 | param_def_dict = {} 75 | for idx, qc_data_iter in enumerate(qc_data): 76 | if S is None or idx in S: 77 | try: 78 | (instr, qargs, cargs, opts) = qc_data_iter 79 | except ValueError: 80 | (instr, qargs, cargs) = qc_data_iter 81 | # entries from extract_from_circuit carry a 4th element (opts); plain 3-element entries fail the 4-name unpack with ValueError 82 | # name = instr.name 83 | # params = [param for param in instr.params] 84 | qubits = [qubit.index for qubit in qargs] 85 | clbits = [clbit.index for clbit in cargs] 86 | qc.append(instr,
qubits, clbits) 87 | # getattr(qc, name)(*params, *qubits, *clbits) 88 | for param in qc.parameters: 89 | param_def_dict[param] = None 90 | return qc, param_def_dict 91 | 92 | 93 | def extract_from_circuit(qc, locked_instructions): 94 | num_qubits = qc.num_qubits 95 | qc_data = [] 96 | for idx, (instr, qargs, cargs) in enumerate(qc.data): 97 | opts = {} 98 | if locked_instructions is not None and idx in locked_instructions: 99 | opts['lock'] = True 100 | else: 101 | opts['lock'] = False 102 | qc_data.append((instr, qargs, cargs, opts)) 103 | return num_qubits, qc_data 104 | 105 | # Evaluate a single circuit through evaluate_circuits and unwrap the one-element result list(s). 106 | def evaluate_circuit(qc, param_def_dict, quantum_instance, counts, sv, add_measurement=True): 107 | qc_list = [qc] 108 | param_def_dict_list = [param_def_dict] 109 | results = evaluate_circuits(qc_list, param_def_dict_list, quantum_instance, counts, sv, add_measurement) 110 | if counts and sv: 111 | counts_list, sv_list = results 112 | return counts_list[0], sv_list[0] 113 | return results[0] 114 | 115 | # Bind parameters, optionally add measurements, and execute all circuits in one call. NOTE: overwrites the entries of the passed-in qc_list with the bound circuits. 116 | def evaluate_circuits(qc_list, param_def_dict_list, quantum_instance, counts, sv, add_measurement=True): 117 | for idx in range(len(qc_list)): 118 | qc = qc_list[idx] 119 | param_def_dict = param_def_dict_list[idx] 120 | qc = qc.copy().assign_parameters(param_def_dict) 121 | if add_measurement: 122 | qc.measure(range(qc.num_qubits), range(qc.num_qubits)) 123 | qc_list[idx] = qc 124 | result = quantum_instance.execute(qc_list) 125 | if counts: 126 | counts_list = [result.get_counts(qc) for qc in qc_list] 127 | if sv: 128 | sv_list = [result.get_statevector(qc) for qc in qc_list] 129 | if counts and sv: 130 | return counts_list, sv_list 131 | elif counts: 132 | return counts_list 133 | elif sv: 134 | return sv_list 135 | else: 136 | return [None for qc in qc_list] 137 | 138 | 139 | def unbind_parameters(qc, name='theta'): 140 | qc = qc.copy() 141 | pqc = QuantumCircuit(qc.num_qubits, qc.num_qubits) 142 | pvec = ParameterVector(name, 0) 143 | for instr, qargs, cargs in qc:
instr.copy() 145 | if instr.params: 146 | num_params = len(instr.params) 147 | pvec.resize(len(pvec) + num_params) 148 | instr.params = pvec[-num_params:] 149 | pqc.append(instr, qargs, cargs) 150 | return pqc 151 | 152 | 153 | def merge_circuit_instructions(qc, merge_instructions_list, names_list=None): 154 | # check args 155 | l = np.array(list(chain.from_iterable(merge_instructions_list))).ravel() # flat list 156 | assert all(l[i] <= l[i + 1] for i in range(len(l) - 1)), 'instructions unsorted' 157 | assert all([i in l for i in range(max(l) + 1)]), 'instructions left out' 158 | assert all([i in l for i in range(len(qc.data))]), 'instructions missing' 159 | assert names_list is None or len(names_list) == len(merge_instructions_list), 'invalid names' 160 | 161 | # merge 162 | num_qubits = qc.num_qubits 163 | qc_merged = QuantumCircuit(num_qubits, 0) 164 | for merge_idx, merge_instructions in enumerate(merge_instructions_list): 165 | if len(merge_instructions) == 1: 166 | (instr, qubits, clbits) = qc.data[merge_instructions[0]] 167 | qc_merged.append(instr, qubits, clbits) 168 | else: 169 | merge_data = [qc.data[idx] for idx in merge_instructions] 170 | qc_sub = QuantumCircuit(num_qubits, 0) 171 | qubits_all = [] 172 | clbits_all = [] 173 | names = [] 174 | for (instr, qargs, cargs) in merge_data: 175 | qubits = [qubit.index for qubit in qargs] 176 | clbits = [clbit.index for clbit in cargs] 177 | qubits_all.extend(qubits) 178 | clbits_all.extend(clbits) 179 | names.append(instr.name) 180 | qc_sub.append(instr, qubits, clbits) 181 | qubits_all = list(set(qubits_all)) 182 | clbits_all = list(set(clbits_all)) 183 | instr = qc_sub.to_instruction() 184 | if names_list is None or names_list[merge_idx] is None: 185 | name = '(' + '@'.join(names) + ')' 186 | else: 187 | name = names_list[merge_idx] 188 | instr.name = name 189 | qc_merged.append(instr, qubits_all, clbits_all) 190 | return qc_merged 191 | 192 | 193 | def filter_instructions_by_name(qc_data, filter_fun): 194 
| filtered_idx = [] 195 | for idx, (instr, qargs, cargs) in enumerate(qc_data): 196 | if filter_fun(instr.name): 197 | filtered_idx.append(idx) 198 | return filtered_idx 199 | 200 | 201 | def remove_instructions_from_circuit(qc, allowed_idx_list): 202 | qc = qc.copy() 203 | if allowed_idx_list is not None: 204 | qc.data = [g for idx, g in enumerate(qc.data) if idx in allowed_idx_list] 205 | return qc 206 | 207 | 208 | def visualize_shapleys(qc, phi_dict=None, label_fun=None, digits=2, max_param_str=0, **kwargs): 209 | if phi_dict is None: 210 | digits = 0 211 | if label_fun is None: 212 | def label_fun(phi_, name_str_, params_str_, digits_, **kwargs): 213 | return f'{name_str_}{params_str_}:{phi_:+.{digits_}f}' 214 | qc_vis = qc.copy() 215 | for i in range(len(qc_vis.data)): 216 | if phi_dict is None or i in phi_dict: 217 | if phi_dict is not None: 218 | phi = phi_dict[i] 219 | else: 220 | phi = i 221 | if len(qc_vis.data[i][0]._params) > 0 and max_param_str > 0: 222 | params_str = ','.join([str(p) for p in qc_vis.data[i][0]._params]) 223 | params_str = '(' + params_str[:max_param_str] + ('...' 
if len(params_str) > max_param_str else '') + ')' 224 | else: 225 | params_str = '' 226 | name_str = qc_vis.data[i][0].name 227 | qc_vis.data[i][0]._label = label_fun(phi, name_str, params_str, digits, **kwargs) 228 | qc_vis.data[i][0]._params = [] 229 | return qc_vis 230 | -------------------------------------------------------------------------------- /src/qshaptools/qvalues.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import numpy as np 3 | from qiskit.opflow import AbelianGrouper 4 | from qiskit.quantum_info import Statevector 5 | from scipy.stats import entropy 6 | from tools import build_circuit, evaluate_circuit, evaluate_circuits 7 | from values import value_fun_batch_wrapper_base 8 | 9 | 10 | def value_callable(qc_data, num_qubits, S, quantum_instance, eval_fun, **eval_fun_kwargs): 11 | # value: from callable of the form eval_fun(quantum_instance, qc, param_def_dict, **eval_fun_kwargs) 12 | qc, param_def_dict = build_circuit(qc_data, num_qubits, S) 13 | 14 | # evaluate and return 15 | value = eval_fun(quantum_instance, qc, param_def_dict, **eval_fun_kwargs) 16 | return value 17 | 18 | 19 | def value_fun_batch_wrapper(qc_data, num_qubits, S_list, quantum_instance, wrapped_value_fun, **kwargs): 20 | # for shapley_iteration_batch 21 | return value_fun_batch_wrapper_base(S_list, wrapped_value_fun, qc_data=qc_data, num_qubits=num_qubits, 22 | quantum_instance=quantum_instance, **kwargs) 23 | 24 | 25 | def value_batch_callable(qc_data, num_qubits, S_list, quantum_instance, eval_batch_fun, **eval_fun_kwargs): 26 | # for shapley_iteration_batch 27 | # process S_list 28 | args_list = [build_circuit(qc_data, num_qubits, S) for S in S_list] # list of (qc, param_def_dict) 29 | 30 | # evaluate and return 31 | value_list = eval_batch_fun(quantum_instance, args_list, **eval_fun_kwargs) 32 | return value_list 33 | 34 | 35 | def value_H(qc_data, num_qubits, S, quantum_instance, H): 36 | # value: 
expectation value of given Hamiltonian H 37 | qc, param_def_dict = build_circuit(qc_data, num_qubits, S) 38 | 39 | # build measurement circuits 40 | grouper = AbelianGrouper() 41 | groups = grouper.convert(H) 42 | circuits = [] 43 | for group in groups: 44 | basis = ['I'] * group.num_qubits 45 | for pauli_string in group.primitive.paulis: 46 | for i, pauli in enumerate(pauli_string): 47 | p = str(pauli) 48 | if p != 'I': 49 | if basis[i] == 'I': 50 | basis[i] = p 51 | elif basis[i] != p: 52 | raise ValueError('PauliSumOp contains non-commuting terms!') 53 | new_qc = qc.copy() 54 | for i, pauli in enumerate(basis): 55 | if pauli == 'X': # H @ X @ H = Z 56 | new_qc.h(i) 57 | if pauli == 'Y': # H @ S^dag @ Y @ S @ H = Z, i.e. apply S^dag first, then H 58 | new_qc.sdg(i) 59 | new_qc.h(i) 60 | circuits.append(new_qc) 61 | 62 | # check simulator 63 | sv_sim = quantum_instance.is_statevector 64 | 65 | # traverse measurement circuits; each one is already rotated so that its group is measured in the Z basis 66 | value = 0 67 | for group, circuit in zip(groups, circuits): 68 | if sv_sim: 69 | sv = evaluate_circuit(circuit, param_def_dict, quantum_instance, counts=False, sv=True, add_measurement=False) 70 | probabilities = sv.probabilities_dict() 71 | else: 72 | counts = evaluate_circuit(circuit, param_def_dict, quantum_instance, counts=True, sv=False, add_measurement=True) 73 | shots = sum(counts.values()) 74 | probabilities = {b: c / shots for b, c in counts.items()} 75 | for (pauli, coeff) in zip(group.primitive.paulis, group.primitive.coeffs): 76 | val = 0 77 | p = str(pauli) 78 | for b, prob in probabilities.items(): 79 | val += prob * np.prod([(-1) ** (b[k] == '1' and p[k] != 'I') for k in range(len(b))]) 80 | value += np.real(coeff * val) 81 | return value 82 | 83 | 84 | def value_Expr(qc_data, num_qubits, S, quantum_instance, rng, num_samples, bins, p_lim_fun): 85 | # value: expressibility of parameterized circuit based on num_samples samples (following arXiv:1905.10876v1) 86 | qc, param_def_dict = build_circuit(qc_data, num_qubits, S) 87 | if p_lim_fun is None: 88 | p_lim_fun =
lambda p: (-2 * np.pi, 2 * np.pi) 89 | 90 | # expressibility tools 91 | def statevector_overlap(sv1, sv2): 92 | F = np.abs(sv1.inner(sv2)) ** 2 93 | F = np.clip(F, 0, 1) 94 | return F 95 | 96 | def calculate_Fhaar_distribution(num_qubits, bin_edges): 97 | N = 2 ** num_qubits 98 | F_hist_haar = [] 99 | for idx in range(len(bin_edges) - 1): 100 | a = bin_edges[idx] 101 | b = bin_edges[idx + 1] 102 | p = (1 - a) ** (N - 1) - (1 - b) ** (N - 1) # integral over P = (N-1)*(1-F)**(N-2) from a to b 103 | F_hist_haar.append(p) 104 | return np.array(F_hist_haar) 105 | 106 | def calculate_F_distribution(F_list, bins): 107 | F_hist, F_bin_edges = np.histogram(F_list, bins=bins, range=(0, 1)) 108 | F_hist = F_hist / len(F_list) 109 | return F_hist, F_bin_edges 110 | 111 | def kl_div(hist1, hist2, epsilon=1e-14): 112 | hist1[hist1 <= epsilon] = epsilon 113 | hist2[hist2 <= epsilon] = epsilon 114 | return entropy(hist1, hist2) 115 | 116 | def estimate_expressibility(F_hist, F_hist_haar): 117 | return kl_div(F_hist, F_hist_haar) 118 | 119 | # calculate expressibility 120 | if len(param_def_dict) > 0: 121 | F_list = np.empty(num_samples) 122 | for idx in range(num_samples): 123 | param_def_dict_list = [{p: rng.uniform(p_lim_fun(p)[0], p_lim_fun(p)[1]) for p in param_def_dict.keys()} for 124 | _ in range(2)] 125 | sv1, sv2 = evaluate_circuits([qc for _ in range(2)], param_def_dict_list, quantum_instance, counts=False, 126 | sv=True, add_measurement=False) 127 | F = statevector_overlap(sv1, sv2) 128 | F_list[idx] = F 129 | else: 130 | F_list = np.ones(num_samples) 131 | F_hist, F_bin_edges = calculate_F_distribution(F_list, bins) 132 | F_hist_haar = calculate_Fhaar_distribution(num_qubits, F_bin_edges) 133 | Expr = estimate_expressibility(F_hist, F_hist_haar) 134 | 135 | # return expressibility as value 136 | value = Expr 137 | return value 138 | 139 | 140 | def value_Ent(qc_data, num_qubits, S, quantum_instance, rng, num_samples, p_lim_fun, eps=1e-8): 141 | # value: entanglement 
capability of parameterized circuit based on num_samples samples (following arXiv:1905.10876v1) 142 | qc, param_def_dict = build_circuit(qc_data, num_qubits, S) 143 | if p_lim_fun is None: 144 | p_lim_fun = lambda p: (-2 * np.pi, 2 * np.pi) 145 | 146 | # entanglement capability tools 147 | def get_reduced_sv(sv, j, b, eps=eps): 148 | num_qubits = int(np.log2(len(sv))) 149 | reduced_sv = np.zeros(2 ** (num_qubits - 1), dtype=np.csingle) 150 | for idx, key in enumerate(product('01', repeat=num_qubits)): 151 | key = ''.join(key) 152 | if np.abs(sv[idx]) <= eps: 153 | continue 154 | reduced_key = key[:j] + key[j + 1:] 155 | reduced_idx = int(np.argmax(Statevector.from_label(reduced_key))) 156 | index_key_int = int(key[j]) 157 | reduced_sv[reduced_idx] += sv[idx] * (index_key_int == b) 158 | return reduced_sv # not normalized? 159 | 160 | def get_D(u, v): 161 | M = np.outer(u, v) 162 | return np.sum(np.abs(M - M.transpose()) ** 2) / 2 163 | 164 | def get_Q(sv): 165 | sv = np.asarray(sv).astype(np.csingle) 166 | sv /= np.sqrt(np.sum(np.abs(sv) ** 2)) # ensure normalization 167 | num_qubits = int(np.log2(len(sv))) 168 | q = 0 169 | for j in range(num_qubits): 170 | reduced_sv_0 = get_reduced_sv(sv, j, 0) 171 | reduced_sv_1 = get_reduced_sv(sv, j, 1) 172 | d = get_D(reduced_sv_0, reduced_sv_1) 173 | q += d 174 | return q * 4 / num_qubits 175 | 176 | def estimate_entanglement(Q_list): 177 | Q_mean = np.mean(Q_list) 178 | return Q_mean 179 | 180 | # calculate entanglement capability 181 | if len(param_def_dict) == 0: 182 | num_samples = 1 183 | Q_list = np.empty(num_samples) 184 | for idx in range(num_samples): 185 | param_def_dict_rand = {p: rng.uniform(p_lim_fun(p)[0], p_lim_fun(p)[1]) for p in param_def_dict.keys()} 186 | sv = evaluate_circuit(qc, param_def_dict_rand, quantum_instance, counts=False, sv=True, add_measurement=False) 187 | Q = get_Q(sv) 188 | Q_list[idx] = Q 189 | Ent = estimate_entanglement(Q_list) 190 | 191 | # return entanglement capability as value 192 | 
value = Ent 193 | return value 194 | 195 | 196 | def value_bits_fun(qc_data, num_qubits, S, quantum_instance, bits_fun): 197 | # value: extraction of measured bits with a custom function 198 | qc, param_def_dict = build_circuit(qc_data, num_qubits, S) 199 | counts = evaluate_circuit(qc, param_def_dict, quantum_instance, counts=True, sv=False) 200 | shots = sum(counts.values()) 201 | value = 0 202 | for bits, count in counts.items(): 203 | bits = [int(i) for i in bits[::-1]] 204 | value += bits_fun(bits) * count 205 | value /= shots 206 | return value 207 | -------------------------------------------------------------------------------- /src/qshaptools/ushap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from tools import p_coalition, powerset_length, powerset, sample_binary 5 | 6 | 7 | def w_calculation(S, F): 8 | return p_coalition(len(S), len(F)) 9 | 10 | 11 | def d_calculation(v, vi, delta_exponent=1): 12 | d = (vi - v) ** delta_exponent 13 | return d 14 | 15 | 16 | def delta_phi_calculation(S, F, v, vi, use_weight=True, delta_exponent=1): 17 | if use_weight: 18 | w = w_calculation(S, F) 19 | else: 20 | w = 1 21 | d = d_calculation(v, vi, delta_exponent) 22 | delta_phi = w * d 23 | return delta_phi 24 | 25 | 26 | class ShapleyValues(): 27 | def __init__(self, unlocked_instructions, locked_instructions, value_fun, value_kwargs_dict, shap_sample_frac, 28 | shap_sample_reps, shap_batch_size, evaluate_value_only_once, sample_in_memory, shap_sample_seed, 29 | memory, callback, 30 | delta_exponent, name, silent): 31 | # setup instructions 32 | if locked_instructions is None: 33 | locked_instructions = [] 34 | self._locked_instructions = sorted(list(locked_instructions)) 35 | unlocked_instructions = list(unlocked_instructions) 36 | if len(unlocked_instructions) == 0: 37 | raise NotImplementedError 38 | self._unlocked_instructions = list(unlocked_instructions) 39 | 40 | # 
setup shap_sample_frac 41 | if shap_sample_frac is not None: 42 | if shap_sample_frac < 0: 43 | P_length = 2 ** (len(unlocked_instructions) - 1) 44 | shap_sample_frac = abs(int(shap_sample_frac)) / P_length 45 | elif shap_sample_frac == 0: 46 | raise NotImplementedError 47 | self._shap_sample_frac = shap_sample_frac 48 | 49 | # setup shap_sample_reps 50 | self._evaluate_value_only_once = bool(evaluate_value_only_once) 51 | if shap_sample_reps is None: 52 | shap_sample_reps = 1 53 | self._shap_sample_reps = shap_sample_reps 54 | 55 | # setup memory 56 | if memory is None: 57 | memory = dict() 58 | self._memory = memory 59 | 60 | # setup title 61 | if name is not None and len(name) == 0: 62 | name = None 63 | self._name = name 64 | 65 | # setup other properties 66 | self._sample_in_memory = sample_in_memory 67 | self._shap_sample_seed = shap_sample_seed 68 | self._value_fun = value_fun 69 | self._value_kwargs_dict = value_kwargs_dict 70 | self._shap_batch_size = shap_batch_size 71 | self._callback = callback 72 | self._delta_exponent = delta_exponent 73 | self._silent = bool(silent) 74 | 75 | @property 76 | def phi_dict(self): 77 | if hasattr(self, 'phi_dict_'): 78 | return self.phi_dict_ 79 | return None 80 | 81 | @property 82 | def memory(self): 83 | return self._memory 84 | 85 | @property 86 | def name(self): 87 | if self._name is not None: 88 | return f'{self._name}' 89 | else: 90 | return 'shap' 91 | 92 | def _evaluate_value_fun(self, S): 93 | if self._callback is not None: 94 | self._callback(S) 95 | if self._shap_batch_size is None: 96 | return self._value_fun(S=S, **self._value_kwargs_dict) 97 | else: 98 | return self._value_fun(S_list=S, **self._value_kwargs_dict) 99 | 100 | def _sample_S_list(self, Fi, total): 101 | def calculate_num_samples(sample_frac, sample_size): 102 | return int(np.ceil(sample_frac * sample_size)) 103 | 104 | if self._sample_in_memory: 105 | P, P_length = powerset(Fi) 106 | num_samples = calculate_num_samples(self._shap_sample_frac, 
P_length) 107 | i_array = np.arange(P_length) 108 | P_array = [S for S in P] 109 | p_array = [p_coalition(len(S), total) for S in P_array] 110 | i_samples = self.rng_.choice(i_array, size=num_samples, replace=True, p=p_array) 111 | S_list = [P_array[i] for i in i_samples] 112 | else: 113 | P_length = powerset_length(len(Fi)) 114 | num_samples = calculate_num_samples(self._shap_sample_frac, P_length) 115 | S_list = [[Fi[i] for i, b_ in enumerate(sample_binary(self.rng_, total - 1)) if b_ == 1] for _ in 116 | range(num_samples)] 117 | return num_samples, S_list 118 | 119 | def _build_S_gen(self): 120 | F = self._unlocked_instructions 121 | if self._shap_sample_frac is not None: 122 | # subsample coalition S ~ w for every player 123 | total = len(F) 124 | with tqdm(desc=f'{self.name}:samp', total=total, disable=self._silent) as prog: 125 | self.S_gen_ = dict() 126 | self.S_gen_length_ = 0 127 | self.num_samples_dict_ = dict() 128 | for idx in F: 129 | Fi = F.copy() 130 | Fi.pop(F.index(idx)) 131 | num_samples, S_list = self._sample_S_list(Fi, total) 132 | self.S_gen_[idx] = S_list 133 | self.S_gen_length_ += num_samples 134 | self.num_samples_dict_[idx] = num_samples 135 | prog.update(1) 136 | else: 137 | # use all 2^N coalitions 138 | self.S_gen_, self.S_gen_length_ = powerset(F) 139 | self.num_samples_dict_ = None 140 | 141 | def _build_Si_total_list(self): 142 | self.Si_total_list_ = [] 143 | total = self.S_gen_length_ 144 | with tqdm(desc=f'{self.name}:jobs', total=total, disable=self._silent) as prog: 145 | if self._shap_sample_frac is not None: 146 | for idx, S_list in self.S_gen_.items(): 147 | for S in S_list: 148 | S = sorted(list(S)) 149 | Si = sorted(S + [idx]) 150 | for _ in range(self._shap_sample_reps): 151 | self.Si_total_list_.append([idx, S]) 152 | for _ in range(self._shap_sample_reps): 153 | self.Si_total_list_.append([idx, Si]) 154 | prog.update(1) 155 | else: 156 | S_list = self.S_gen_ 157 | for S in S_list: 158 | S = sorted(list(S)) 159 | for _ 
in range(self._shap_sample_reps): 160 | self.Si_total_list_.append([None, S]) 161 | prog.update(1) 162 | 163 | def _eval_Si_total_list(self): 164 | L = self._locked_instructions 165 | if self._evaluate_value_only_once: # remove duplicates 166 | Si_xk_dict = {} 167 | for k, x in self.Si_total_list_: 168 | if tuple(x) not in Si_xk_dict: 169 | Si_xk_dict[tuple(x)] = k 170 | Si_effective_total_list = [[k, list(x)] for x, k in Si_xk_dict.items()] 171 | else: 172 | Si_effective_total_list = self.Si_total_list_ 173 | if self._shap_batch_size is not None: 174 | Si_effective_total_list = [Si_batch_list.tolist() for Si_batch_list in 175 | np.array_split(Si_effective_total_list, 176 | np.ceil(len(Si_effective_total_list) / self._shap_batch_size))] 177 | total = sum([len(Si_batch_list) for Si_batch_list in 178 | Si_effective_total_list]) if self._shap_batch_size is not None else len(Si_effective_total_list) 179 | with tqdm(desc=f'{self.name}:eval', total=total, disable=self._silent) as prog: 180 | if self._shap_batch_size is not None: 181 | for Si_batch_list in Si_effective_total_list: 182 | S_batch_list = [sorted(Si[1] + L) for Si in Si_batch_list] 183 | value_list = self._evaluate_value_fun(S_batch_list) 184 | for Si, value in zip(Si_batch_list, value_list): 185 | i = Si[0] 186 | key = tuple(sorted(Si[1])) 187 | if key not in self._memory: 188 | self._memory[key] = [] 189 | self._memory[key].append([i, value]) 190 | prog.update(1) 191 | else: 192 | for Si in Si_effective_total_list: 193 | i = Si[0] 194 | S = sorted(Si[1]) 195 | Sl = sorted(S + L) 196 | value = self._evaluate_value_fun(Sl) 197 | key = tuple(S) 198 | if key not in self._memory: 199 | self._memory[key] = [] 200 | self._memory[key].append([i, value]) 201 | prog.update(1) 202 | 203 | def _eval_shap_idx(self, idx, prog=None): 204 | F = self._unlocked_instructions 205 | if self._shap_sample_frac is not None: 206 | P = self.S_gen_[idx] 207 | use_weight = False 208 | else: 209 | Fi = F.copy() 210 | 
Fi.pop(F.index(idx)) 211 | P, _ = powerset(Fi) 212 | use_weight = True 213 | phi = 0 214 | for S in P: 215 | S = list(sorted(S)) 216 | key_S = tuple(S) 217 | key_Si = tuple(sorted(S + [idx])) 218 | assert key_S in self._memory and key_Si in self._memory # sanity check, should never be violated 219 | v = np.mean([x[1] for x in self._memory[key_S]]) 220 | vi = np.mean([x[1] for x in self._memory[key_Si]]) 221 | delta_phi = delta_phi_calculation(S, F, v, vi, use_weight=use_weight, delta_exponent=self._delta_exponent) 222 | phi += delta_phi 223 | if prog is not None: 224 | prog.update(1) 225 | if self._shap_sample_frac is not None: 226 | phi /= len(P) 227 | return phi 228 | 229 | def _eval_shap(self): 230 | F = self._unlocked_instructions 231 | self.phi_dict_ = {} 232 | N = len(F) 233 | if self._shap_sample_frac is not None: 234 | self._n_total_effective = sum([len(S_list) for S_list in self.S_gen_.values()]) 235 | else: 236 | self._n_total_effective = N * 2 ** (N - 1) 237 | with tqdm(desc=f'{self.name}:sums', total=self._n_total_effective, disable=self._silent) as prog: 238 | for idx in F: 239 | phi = self._eval_shap_idx(idx, prog) 240 | self.phi_dict_[idx] = phi 241 | 242 | def __call__(self): 243 | self.rng_ = np.random.RandomState(self._shap_sample_seed) 244 | self._build_S_gen() 245 | self._build_Si_total_list() 246 | self._eval_Si_total_list() 247 | self._eval_shap() 248 | return self.phi_dict 249 | 250 | def __str__(self): 251 | # print settings 252 | N = len(self._unlocked_instructions) 253 | M = len(self._locked_instructions) 254 | if self._shap_sample_frac is not None: 255 | self._n_per_phi = int(np.ceil(self._shap_sample_frac * 2 ** (N - 1))) # each evaluated twice: S and S+i 256 | else: 257 | self._n_per_phi = 2 ** (N - 1) 258 | self._n_total = N * self._n_per_phi 259 | self._n_valfun = 2 ** N 260 | rep = f'[{self.name}]\n' 261 | rep += f'value_fun: {str(self._value_fun)}\n' 262 | rep += f'unlocked_instructions [{N:3d}]: {self._unlocked_instructions}\n' 263 
| rep += f'locked_instructions [{M:3d}]: {self._locked_instructions}\n' 264 | rep += f'delta_exponent: {self._delta_exponent}\n' 265 | rep += f'shap_sample_frac: {self._shap_sample_frac}\n' 266 | rep += f'shap_sample_reps: {self._shap_sample_reps}\n' 267 | rep += f'evaluate_value_only_once: {self._evaluate_value_only_once}\n' 268 | rep += f'shap_sample_seed: {self._shap_sample_seed}\n' 269 | rep += f'shap_batch_size: {self._shap_batch_size}\n' 270 | rep += f'possible value functions: {self._n_valfun}\n' 271 | rep += f'terms per phi: {self._n_per_phi}\n' 272 | rep += f'total shapley terms: {self._n_total}' 273 | return rep 274 | 275 | def clear_memory(self): 276 | self._memory = dict() 277 | 278 | def eval_S_list(self, S_list, recall): 279 | L = self._locked_instructions 280 | S_list = [list(S) for S in S_list] 281 | values = [] 282 | total = len(S_list) 283 | with tqdm(desc=f'{self.name}:vals', total=total, disable=self._silent) as prog: 284 | for S in S_list: 285 | S = sorted(list(S)) 286 | Sl = sorted(S + L) 287 | key = tuple(S) 288 | if recall and key in self._memory: 289 | value = np.mean([vi[1] for vi in self._memory[key]]) 290 | else: 291 | if self._shap_batch_size is None: 292 | value = self._evaluate_value_fun(Sl) 293 | else: 294 | value = self._evaluate_value_fun([Sl])[0] 295 | values.append(value) 296 | prog.update(1) 297 | return values 298 | 299 | def get_summary_dict(self, property_list=None): 300 | if property_list is None: 301 | property_list = [] 302 | 303 | def get_attr(name_): 304 | return getattr(self, name_) if hasattr(self, name_) else None 305 | 306 | summary = {'name': self.name, 307 | 'value_fun': get_attr('_value_fun'), 308 | 'unlocked_instructions': get_attr('_unlocked_instructions'), 309 | 'locked_instructions': get_attr('_locked_instructions'), 310 | 'delta_exponent': get_attr('_delta_exponent'), 311 | 'shap_sample_frac': get_attr('_shap_sample_frac'), 312 | 'shap_sample_reps': get_attr('_shap_sample_reps'), 313 | 
'evaluate_value_only_once': get_attr('_evaluate_value_only_once'), 314 | 'shap_sample_seed': get_attr('_shap_sample_seed'), 315 | 'shap_batch_size': get_attr('_shap_batch_size'), 316 | 'n_valfun': get_attr('_n_valfun'), 317 | 'n_per_phi': get_attr('_n_per_phi'), 318 | 'n_total': get_attr('_n_total'), 319 | 'n_total_effective': get_attr('_n_total_effective'), 320 | 'S_gen_length': get_attr('S_gen_length_'), 321 | 'num_samples_dict': get_attr('num_samples_dict_'), 322 | 'phi_dict': get_attr('phi_dict_') 323 | } 324 | for name in property_list: 325 | summary.update({name: get_attr(name)}) 326 | return summary 327 | --------------------------------------------------------------------------------