├── README.md
├── data
    ├── .DS_Store
    ├── gsm8k
    │   ├── input.jsonl
    │   ├── shots.md
    │   └── test.jsonl
    ├── human_eval
    │   ├── input.jsonl
    │   └── test.jsonl
    ├── mbpp
    │   ├── input.jsonl
    │   ├── shots.md
    │   └── test.jsonl
    └── strategy_qa
    │   ├── input.jsonl
    │   ├── shots.md
    │   └── test.jsonl
├── decodingmethod
    ├── __init__.py
    ├── moe_utils.py
    └── utils.py
├── evaluation.py
├── generate.py
├── lm_eval
    ├── __init__.py
    ├── base.py
    ├── evaluator.py
    └── tasks
    │   ├── __init__.py
    │   ├── gsm8k.py
    │   ├── human_eval.py
    │   ├── mbpp.py
    │   └── strategy_qa.py
└── modeling_models
    ├── Mixtral
        ├── config.json
        ├── configuration_mixtral.py
        └── modeling_mixtral.py
    └── deepseek-moe
        ├── config.json
        ├── configuration_deepseek.py
        └── modeling_deepseek.py


/README.md:
--------------------------------------------------------------------------------
 1 | Unchosen Experts Can Contribute Too: Unleashing MoE Models' Power by Self-Contrast
 2 | ===
 3 | ## Setup
 4 | ### Requirements
 5 | ```
 6 | pip install datasets
 7 | pip install evaluate
 8 | pip install absl-py
 9 | pip install nltk
10 | pip install pylint
11 | pip install antlr4-python3-runtime==4.11.1
12 | pip install transformers==4.40.0
13 | ```
14 | ### Add modeling models
15 |
16 |
17 | Replace configuration_mixtral.py and modeling_mixtral.py in transformers/src/transformers/models/mixtral with the corresponding .py files in ./modeling_models/Mixtral.
18 | Replace the model's own config.json with the config.json in ./modeling_models/Mixtral.
19 |
20 | ### GPU Requirements
21 |
22 | For Mixtral, you need 4 A100 40G or 2 A100 80G GPUs.
23 | For DeepSeekMoE, you need 2 A100 40G or 1 A100 80G GPU.
24 |
25 | ## Inference
26 | `--decoding_method` selects one of the decoding methods (greedy, dynamic, cs, dola, cd, scmoe).
27 | `--num_experts_per_tok` is the number of activated experts for the MoE model, or for the strong activation of SCMoE. Defaults to 2 for Mixtral.
28 | `--student_num_experts_per_tok` is the number of activated experts for the weak activation of SCMoE.
29 | `--routed_tok` is the rank id of the routed expert used for the weak activation. With rank-$k$ routing, ids range from 0 to 7 for Mixtral.
30 | `--cd_beta` is the parameter $\beta$ in SCMoE.
31 |
32 |
33 |
34 | ```
35 | task=gsm8k
36 | model_name=Mixtral
37 | decoding_method=scmoe
38 | python3 generate.py \
39 |     --decoding_method ${decoding_method} \
40 |     --infile ./data/${task}/input.jsonl \
41 |     --outfile ./results/${task}/${model_name}_${decoding_method}.jsonl \
42 |     --model ${model_name} \
43 |     --cd_beta ${cd3_beta} \
44 |     --gpus_per_model 4 \
45 |     --world_size 4 \
46 |     --batch_size 1 \
47 |     --num_experts_per_tok 2 \
48 |     --student_num_experts_per_tok 1 \
49 |     --student_routed_tok 2 \
50 |     --max_new_tokens 512
51 | ```
52 | ## Evaluation
53 | ```
54 | task=gsm8k
55 | model_name=Mixtral
56 | decoding_method=scmoe
57 | python3 evaluation.py \
58 |     --model ${model_name} \
59 |     --task_name ${task} \
60 |     --load_generations_path ./results/${task}/${model_name}_${decoding_method}.jsonl \
61 |     --metric_output_path ./results/${task}_results.jsonl \
62 |     --allow_code_execution
63 | ```


--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DavidFanzz/SCMoE/857fad79769c61b25ad87d0618246490bf5551a2/data/.DS_Store


--------------------------------------------------------------------------------
/data/gsm8k/shots.md:
--------------------------------------------------------------------------------
 1 | Question: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
 2 | Answer: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.
 3 |
 4 | Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
 5 | Answer: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.
 6 |
 7 | Question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
 8 | Answer: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.
 9 |
10 | Question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
11 | Answer: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.
12 |
13 | Question: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
14 | Answer: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.
15 |
16 | Question: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
17 | Answer: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.
18 |
19 | Question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
20 | Answer: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.
21 |
22 | Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
23 | Answer: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.

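
Note on the exemplar format: every answer in shots.md closes with the sentence "The answer is N.", so a generation that follows the same pattern can be scored by reading the number out of that sentence. The sketch below illustrates that idea only. The helper name `extract_final_answer` and the regular expression are assumptions made for illustration and are not taken from lm_eval/tasks/gsm8k.py, whose actual parsing may differ. It assumes the first "The answer is" sentence is the relevant one, since a model prompted with these shots may go on to write further Question/Answer pairs.

```python
import re
from typing import Optional

# Hypothetical pattern (an assumption, not the repository's code): matches the
# closing sentence used by every exemplar, allowing an optional "$", a minus
# sign, thousands separators, and a decimal part.
ANSWER_RE = re.compile(r"The answer is \$?(-?[\d,]*\.?\d+)")


def extract_final_answer(generation: str) -> Optional[str]:
    """Return the first stated answer as a plain numeric string, or None."""
    match = ANSWER_RE.search(generation)
    if match is None:
        # The generation never produced the closing sentence.
        return None
    # Drop thousands separators so "1,234" and "1234" compare equal.
    return match.group(1).replace(",", "")


if __name__ == "__main__":
    sample = "There are 15 trees originally. So 21 - 15 = 6. The answer is 6."
    print(extract_final_answer(sample))
```

Returning a string rather than a float keeps the comparison with the reference answer exact; converting to a number is left to the caller.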
-------------------------------------------------------------------------------- /data/human_eval/input.jsonl: -------------------------------------------------------------------------------- 1 | {"instructions":"from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n"} 2 | {"instructions":"from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n"} 3 | {"instructions":"\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n"} 4 | {"instructions":"from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. 
Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n"} 5 | {"instructions":"from typing import List\n\n\ndef mean_absolute_deviation(numbers: List[float]) -> float:\n \"\"\" For a given list of input numbers, calculate Mean Absolute Deviation\n around the mean of this dataset.\n Mean Absolute Deviation is the average absolute difference between each\n element and a centerpoint (mean in this case):\n MAD = average | x - x_mean |\n >>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])\n 1.0\n \"\"\"\n"} 6 | {"instructions":"from typing import List\n\n\ndef intersperse(numbers: List[int], delimeter: int) -> List[int]:\n \"\"\" Insert a number 'delimeter' between every two consecutive elements of input list `numbers'\n >>> intersperse([], 4)\n []\n >>> intersperse([1, 2, 3], 4)\n [1, 4, 2, 4, 3]\n \"\"\"\n"} 7 | {"instructions":"from typing import List\n\n\ndef parse_nested_parens(paren_string: str) -> List[int]:\n \"\"\" Input to this function is a string represented multiple groups for nested parentheses separated by spaces.\n For each of the group, output the deepest level of nesting of parentheses.\n E.g. 
(()()) has maximum two levels of nesting while ((())) has three.\n\n >>> parse_nested_parens('(()()) ((())) () ((())()())')\n [2, 3, 1, 3]\n \"\"\"\n"} 8 | {"instructions":"from typing import List\n\n\ndef filter_by_substring(strings: List[str], substring: str) -> List[str]:\n \"\"\" Filter an input list of strings only for ones that contain given substring\n >>> filter_by_substring([], 'a')\n []\n >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')\n ['abc', 'bacd', 'array']\n \"\"\"\n"} 9 | {"instructions":"from typing import List, Tuple\n\n\ndef sum_product(numbers: List[int]) -> Tuple[int, int]:\n \"\"\" For a given list of integers, return a tuple consisting of a sum and a product of all the integers in a list.\n Empty sum should be equal to 0 and empty product should be equal to 1.\n >>> sum_product([])\n (0, 1)\n >>> sum_product([1, 2, 3, 4])\n (10, 24)\n \"\"\"\n"} 10 | {"instructions":"from typing import List, Tuple\n\n\ndef rolling_max(numbers: List[int]) -> List[int]:\n \"\"\" From a given list of integers, generate a list of rolling maximum element found until given moment\n in the sequence.\n >>> rolling_max([1, 2, 3, 2, 3, 4, 2])\n [1, 2, 3, 3, 3, 4, 4]\n \"\"\"\n"} 11 | {"instructions":"\n\ndef is_palindrome(string: str) -> bool:\n \"\"\" Test if given string is a palindrome \"\"\"\n return string == string[::-1]\n\n\ndef make_palindrome(string: str) -> str:\n \"\"\" Find the shortest palindrome that begins with a supplied string.\n Algorithm idea is simple:\n - Find the longest postfix of supplied string that is a palindrome.\n - Append to the end of the string reverse of a string prefix that comes before the palindromic suffix.\n >>> make_palindrome('')\n ''\n >>> make_palindrome('cat')\n 'catac'\n >>> make_palindrome('cata')\n 'catac'\n \"\"\"\n"} 12 | {"instructions":"from typing import List\n\n\ndef string_xor(a: str, b: str) -> str:\n \"\"\" Input are two strings a and b consisting only of 1s and 0s.\n Perform binary XOR on these 
inputs and return result also as a string.\n >>> string_xor('010', '110')\n '100'\n \"\"\"\n"} 13 | {"instructions":"from typing import List, Optional\n\n\ndef longest(strings: List[str]) -> Optional[str]:\n \"\"\" Out of list of strings, return the longest one. Return the first one in case of multiple\n strings of the same length. Return None in case the input list is empty.\n >>> longest([])\n\n >>> longest(['a', 'b', 'c'])\n 'a'\n >>> longest(['a', 'bb', 'ccc'])\n 'ccc'\n \"\"\"\n"} 14 | {"instructions":"\n\ndef greatest_common_divisor(a: int, b: int) -> int:\n \"\"\" Return a greatest common divisor of two integers a and b\n >>> greatest_common_divisor(3, 5)\n 1\n >>> greatest_common_divisor(25, 15)\n 5\n \"\"\"\n"} 15 | {"instructions":"from typing import List\n\n\ndef all_prefixes(string: str) -> List[str]:\n \"\"\" Return list of all prefixes from shortest to longest of the input string\n >>> all_prefixes('abc')\n ['a', 'ab', 'abc']\n \"\"\"\n"} 16 | {"instructions":"\n\ndef string_sequence(n: int) -> str:\n \"\"\" Return a string containing space-delimited numbers starting from 0 upto n inclusive.\n >>> string_sequence(0)\n '0'\n >>> string_sequence(5)\n '0 1 2 3 4 5'\n \"\"\"\n"} 17 | {"instructions":"\n\ndef count_distinct_characters(string: str) -> int:\n \"\"\" Given a string, find out how many distinct characters (regardless of case) does it consist of\n >>> count_distinct_characters('xyzXYZ')\n 3\n >>> count_distinct_characters('Jerry')\n 4\n \"\"\"\n"} 18 | {"instructions":"from typing import List\n\n\ndef parse_music(music_string: str) -> List[int]:\n \"\"\" Input to this function is a string representing musical notes in a special ASCII format.\n Your task is to parse this string and return list of integers corresponding to how many beats does each\n not last.\n\n Here is a legend:\n 'o' - whole note, lasts four beats\n 'o|' - half note, lasts two beats\n '.|' - quater note, lasts one beat\n\n >>> parse_music('o o| .| o| o| .| .| .| .| o o')\n [4, 
2, 1, 2, 2, 1, 1, 1, 1, 4, 4]\n \"\"\"\n"} 19 | {"instructions":"\n\ndef how_many_times(string: str, substring: str) -> int:\n \"\"\" Find how many times a given substring can be found in the original string. Count overlaping cases.\n >>> how_many_times('', 'a')\n 0\n >>> how_many_times('aaa', 'a')\n 3\n >>> how_many_times('aaaa', 'aa')\n 3\n \"\"\"\n"} 20 | {"instructions":"from typing import List\n\n\ndef sort_numbers(numbers: str) -> str:\n \"\"\" Input is a space-delimited string of numberals from 'zero' to 'nine'.\n Valid choices are 'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight' and 'nine'.\n Return the string with numbers sorted from smallest to largest\n >>> sort_numbers('three one five')\n 'one three five'\n \"\"\"\n"} 21 | {"instructions":"from typing import List, Tuple\n\n\ndef find_closest_elements(numbers: List[float]) -> Tuple[float, float]:\n \"\"\" From a supplied list of numbers (of length at least two) select and return two that are the closest to each\n other and return them in order (smaller number, larger number).\n >>> find_closest_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.2])\n (2.0, 2.2)\n >>> find_closest_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0])\n (2.0, 2.0)\n \"\"\"\n"} 22 | {"instructions":"from typing import List\n\n\ndef rescale_to_unit(numbers: List[float]) -> List[float]:\n \"\"\" Given list of numbers (of at least two elements), apply a linear transform to that list,\n such that the smallest number will become 0 and the largest will become 1\n >>> rescale_to_unit([1.0, 2.0, 3.0, 4.0, 5.0])\n [0.0, 0.25, 0.5, 0.75, 1.0]\n \"\"\"\n"} 23 | {"instructions":"from typing import List, Any\n\n\ndef filter_integers(values: List[Any]) -> List[int]:\n \"\"\" Filter given list of any python values only for integers\n >>> filter_integers(['a', 3.14, 5])\n [5]\n >>> filter_integers([1, 2, 3, 'abc', {}, []])\n [1, 2, 3]\n \"\"\"\n"} 24 | {"instructions":"\n\ndef strlen(string: str) -> int:\n \"\"\" Return length of given string\n 
>>> strlen('')\n 0\n >>> strlen('abc')\n 3\n \"\"\"\n"} 25 | {"instructions":"\n\ndef largest_divisor(n: int) -> int:\n \"\"\" For a given number n, find the largest number that divides n evenly, smaller than n\n >>> largest_divisor(15)\n 5\n \"\"\"\n"} 26 | {"instructions":"from typing import List\n\n\ndef factorize(n: int) -> List[int]:\n \"\"\" Return list of prime factors of given integer in the order from smallest to largest.\n Each of the factors should be listed number of times corresponding to how many times it appeares in factorization.\n Input number should be equal to the product of all factors\n >>> factorize(8)\n [2, 2, 2]\n >>> factorize(25)\n [5, 5]\n >>> factorize(70)\n [2, 5, 7]\n \"\"\"\n"} 27 | {"instructions":"from typing import List\n\n\ndef remove_duplicates(numbers: List[int]) -> List[int]:\n \"\"\" From a list of integers, remove all elements that occur more than once.\n Keep order of elements left the same as in the input.\n >>> remove_duplicates([1, 2, 3, 2, 4])\n [1, 3, 4]\n \"\"\"\n"} 28 | {"instructions":"\n\ndef flip_case(string: str) -> str:\n \"\"\" For a given string, flip lowercase characters to uppercase and uppercase to lowercase.\n >>> flip_case('Hello')\n 'hELLO'\n \"\"\"\n"} 29 | {"instructions":"from typing import List\n\n\ndef concatenate(strings: List[str]) -> str:\n \"\"\" Concatenate list of strings into a single string\n >>> concatenate([])\n ''\n >>> concatenate(['a', 'b', 'c'])\n 'abc'\n \"\"\"\n"} 30 | {"instructions":"from typing import List\n\n\ndef filter_by_prefix(strings: List[str], prefix: str) -> List[str]:\n \"\"\" Filter an input list of strings only for ones that start with a given prefix.\n >>> filter_by_prefix([], 'a')\n []\n >>> filter_by_prefix(['abc', 'bcd', 'cde', 'array'], 'a')\n ['abc', 'array']\n \"\"\"\n"} 31 | {"instructions":"\n\ndef get_positive(l: list):\n \"\"\"Return only positive numbers in the list.\n >>> get_positive([-1, 2, -4, 5, 6])\n [2, 5, 6]\n >>> get_positive([5, 3, -5, 2, -3, 3, 9, 
0, 123, 1, -10])\n [5, 3, 2, 3, 9, 123, 1]\n \"\"\"\n"} 32 | {"instructions":"\n\ndef is_prime(n):\n \"\"\"Return true if a given number is prime, and false otherwise.\n >>> is_prime(6)\n False\n >>> is_prime(101)\n True\n >>> is_prime(11)\n True\n >>> is_prime(13441)\n True\n >>> is_prime(61)\n True\n >>> is_prime(4)\n False\n >>> is_prime(1)\n False\n \"\"\"\n"} 33 | {"instructions":"import math\n\n\ndef poly(xs: list, x: float):\n \"\"\"\n Evaluates polynomial with coefficients xs at point x.\n return xs[0] + xs[1] * x + xs[1] * x^2 + .... xs[n] * x^n\n \"\"\"\n return sum([coeff * math.pow(x, i) for i, coeff in enumerate(xs)])\n\n\ndef find_zero(xs: list):\n \"\"\" xs are coefficients of a polynomial.\n find_zero find x such that poly(x) = 0.\n find_zero returns only only zero point, even if there are many.\n Moreover, find_zero only takes list xs having even number of coefficients\n and largest non zero coefficient as it guarantees\n a solution.\n >>> round(find_zero([1, 2]), 2) # f(x) = 1 + 2x\n -0.5\n >>> round(find_zero([-6, 11, -6, 1]), 2) # (x - 1) * (x - 2) * (x - 3) = -6 + 11x - 6x^2 + x^3\n 1.0\n \"\"\"\n"} 34 | {"instructions":"\n\ndef sort_third(l: list):\n \"\"\"This function takes a list l and returns a list l' such that\n l' is identical to l in the indicies that are not divisible by three, while its values at the indicies that are divisible by three are equal\n to the values of the corresponding indicies of l, but sorted.\n >>> sort_third([1, 2, 3])\n [1, 2, 3]\n >>> sort_third([5, 6, 3, 4, 8, 9, 2])\n [2, 6, 3, 4, 8, 9, 5]\n \"\"\"\n"} 35 | {"instructions":"\n\ndef unique(l: list):\n \"\"\"Return sorted unique elements in a list\n >>> unique([5, 3, 5, 2, 3, 3, 9, 0, 123])\n [0, 2, 3, 5, 9, 123]\n \"\"\"\n"} 36 | {"instructions":"\n\ndef max_element(l: list):\n \"\"\"Return maximum element in the list.\n >>> max_element([1, 2, 3])\n 3\n >>> max_element([5, 3, -5, 2, -3, 3, 9, 0, 123, 1, -10])\n 123\n \"\"\"\n"} 37 | {"instructions":"\n\ndef 
fizz_buzz(n: int):\n \"\"\"Return the number of times the digit 7 appears in integers less than n which are divisible by 11 or 13.\n >>> fizz_buzz(50)\n 0\n >>> fizz_buzz(78)\n 2\n >>> fizz_buzz(79)\n 3\n \"\"\"\n"} 38 | {"instructions":"\n\ndef sort_even(l: list):\n \"\"\"This function takes a list l and returns a list l' such that\n l' is identical to l in the odd indicies, while its values at the even indicies are equal\n to the values of the even indicies of l, but sorted.\n >>> sort_even([1, 2, 3])\n [1, 2, 3]\n >>> sort_even([5, 6, 3, 4])\n [3, 6, 5, 4]\n \"\"\"\n"} 39 | {"instructions":"\n\ndef encode_cyclic(s: str):\n \"\"\"\n returns encoded string by cycling groups of three characters.\n \"\"\"\n # split string to groups. Each of length 3.\n groups = [s[(3 * i):min((3 * i + 3), len(s))] for i in range((len(s) + 2) \/\/ 3)]\n # cycle elements in each group. Unless group has fewer elements than 3.\n groups = [(group[1:] + group[0]) if len(group) == 3 else group for group in groups]\n return \"\".join(groups)\n\n\ndef decode_cyclic(s: str):\n \"\"\"\n takes as input string encoded with encode_cyclic function. 
Returns decoded string.\n \"\"\"\n"} 40 | {"instructions":"\n\ndef prime_fib(n: int):\n \"\"\"\n prime_fib returns n-th number that is a Fibonacci number and it's also prime.\n >>> prime_fib(1)\n 2\n >>> prime_fib(2)\n 3\n >>> prime_fib(3)\n 5\n >>> prime_fib(4)\n 13\n >>> prime_fib(5)\n 89\n \"\"\"\n"} 41 | {"instructions":"\n\ndef triples_sum_to_zero(l: list):\n \"\"\"\n triples_sum_to_zero takes a list of integers as an input.\n it returns True if there are three distinct elements in the list that\n sum to zero, and False otherwise.\n\n >>> triples_sum_to_zero([1, 3, 5, 0])\n False\n >>> triples_sum_to_zero([1, 3, -2, 1])\n True\n >>> triples_sum_to_zero([1, 2, 3, 7])\n False\n >>> triples_sum_to_zero([2, 4, -5, 3, 9, 7])\n True\n >>> triples_sum_to_zero([1])\n False\n \"\"\"\n"} 42 | {"instructions":"\n\ndef car_race_collision(n: int):\n \"\"\"\n Imagine a road that's a perfectly straight infinitely long line.\n n cars are driving left to right; simultaneously, a different set of n cars\n are driving right to left. The two sets of cars start out being very far from\n each other. All cars move in the same speed. 
Two cars are said to collide\n when a car that's moving left to right hits a car that's moving right to left.\n However, the cars are infinitely sturdy and strong; as a result, they continue moving\n in their trajectory as if they did not collide.\n\n This function outputs the number of such collisions.\n \"\"\"\n"} 43 | {"instructions":"\n\ndef incr_list(l: list):\n \"\"\"Return list with elements incremented by 1.\n >>> incr_list([1, 2, 3])\n [2, 3, 4]\n >>> incr_list([5, 3, 5, 2, 3, 3, 9, 0, 123])\n [6, 4, 6, 3, 4, 4, 10, 1, 124]\n \"\"\"\n"} 44 | {"instructions":"\n\ndef pairs_sum_to_zero(l):\n \"\"\"\n pairs_sum_to_zero takes a list of integers as an input.\n it returns True if there are two distinct elements in the list that\n sum to zero, and False otherwise.\n >>> pairs_sum_to_zero([1, 3, 5, 0])\n False\n >>> pairs_sum_to_zero([1, 3, -2, 1])\n False\n >>> pairs_sum_to_zero([1, 2, 3, 7])\n False\n >>> pairs_sum_to_zero([2, 4, -5, 3, 5, 7])\n True\n >>> pairs_sum_to_zero([1])\n False\n \"\"\"\n"} 45 | {"instructions":"\n\ndef change_base(x: int, base: int):\n \"\"\"Change numerical base of input number x to base.\n return string representation after the conversion.\n base numbers are less than 10.\n >>> change_base(8, 3)\n '22'\n >>> change_base(8, 2)\n '1000'\n >>> change_base(7, 2)\n '111'\n \"\"\"\n"} 46 | {"instructions":"\n\ndef triangle_area(a, h):\n \"\"\"Given length of a side and high return area for a triangle.\n >>> triangle_area(5, 3)\n 7.5\n \"\"\"\n"} 47 | {"instructions":"\n\ndef fib4(n: int):\n \"\"\"The Fib4 number sequence is a sequence similar to the Fibbonacci sequnece that's defined as follows:\n fib4(0) -> 0\n fib4(1) -> 0\n fib4(2) -> 2\n fib4(3) -> 0\n fib4(n) -> fib4(n-1) + fib4(n-2) + fib4(n-3) + fib4(n-4).\n Please write a function to efficiently compute the n-th element of the fib4 number sequence. 
Do not use recursion.\n >>> fib4(5)\n 4\n >>> fib4(6)\n 8\n >>> fib4(7)\n 14\n \"\"\"\n"} 48 | {"instructions":"\n\ndef median(l: list):\n \"\"\"Return median of elements in the list l.\n >>> median([3, 1, 2, 4, 5])\n 3\n >>> median([-10, 4, 6, 1000, 10, 20])\n 15.0\n \"\"\"\n"} 49 | {"instructions":"\n\ndef is_palindrome(text: str):\n \"\"\"\n Checks if given string is a palindrome\n >>> is_palindrome('')\n True\n >>> is_palindrome('aba')\n True\n >>> is_palindrome('aaaaa')\n True\n >>> is_palindrome('zbcd')\n False\n \"\"\"\n"} 50 | {"instructions":"\n\ndef modp(n: int, p: int):\n \"\"\"Return 2^n modulo p (be aware of numerics).\n >>> modp(3, 5)\n 3\n >>> modp(1101, 101)\n 2\n >>> modp(0, 101)\n 1\n >>> modp(3, 11)\n 8\n >>> modp(100, 101)\n 1\n \"\"\"\n"} 51 | {"instructions":"\n\ndef encode_shift(s: str):\n \"\"\"\n returns encoded string by shifting every character by 5 in the alphabet.\n \"\"\"\n return \"\".join([chr(((ord(ch) + 5 - ord(\"a\")) % 26) + ord(\"a\")) for ch in s])\n\n\ndef decode_shift(s: str):\n \"\"\"\n takes as input string encoded with encode_shift function. 
Returns decoded string.\n \"\"\"\n"} 52 | {"instructions":"\n\ndef remove_vowels(text):\n \"\"\"\n remove_vowels is a function that takes string and returns string without vowels.\n >>> remove_vowels('')\n ''\n >>> remove_vowels(\"abcdef\\nghijklm\")\n 'bcdf\\nghjklm'\n >>> remove_vowels('abcdef')\n 'bcdf'\n >>> remove_vowels('aaaaa')\n ''\n >>> remove_vowels('aaBAA')\n 'B'\n >>> remove_vowels('zbcd')\n 'zbcd'\n \"\"\"\n"} 53 | {"instructions":"\n\ndef below_threshold(l: list, t: int):\n \"\"\"Return True if all numbers in the list l are below threshold t.\n >>> below_threshold([1, 2, 4, 10], 100)\n True\n >>> below_threshold([1, 20, 4, 10], 5)\n False\n \"\"\"\n"} 54 | {"instructions":"\n\ndef add(x: int, y: int):\n \"\"\"Add two numbers x and y\n >>> add(2, 3)\n 5\n >>> add(5, 7)\n 12\n \"\"\"\n"} 55 | {"instructions":"\n\ndef same_chars(s0: str, s1: str):\n \"\"\"\n Check if two words have the same characters.\n >>> same_chars('eabcdzzzz', 'dddzzzzzzzddeddabc')\n True\n >>> same_chars('abcd', 'dddddddabc')\n True\n >>> same_chars('dddddddabc', 'abcd')\n True\n >>> same_chars('eabcd', 'dddddddabc')\n False\n >>> same_chars('abcd', 'dddddddabce')\n False\n >>> same_chars('eabcdzzzz', 'dddzzzzzzzddddabc')\n False\n \"\"\"\n"} 56 | {"instructions":"\n\ndef fib(n: int):\n \"\"\"Return n-th Fibonacci number.\n >>> fib(10)\n 55\n >>> fib(1)\n 1\n >>> fib(8)\n 21\n \"\"\"\n"} 57 | {"instructions":"\n\ndef correct_bracketing(brackets: str):\n \"\"\" brackets is a string of \"<\" and \">\".\n return True if every opening bracket has a corresponding closing bracket.\n\n >>> correct_bracketing(\"<\")\n False\n >>> correct_bracketing(\"<>\")\n True\n >>> correct_bracketing(\"<<><>>\")\n True\n >>> correct_bracketing(\"><<>\")\n False\n \"\"\"\n"} 58 | {"instructions":"\n\ndef monotonic(l: list):\n \"\"\"Return True is list elements are monotonically increasing or decreasing.\n >>> monotonic([1, 2, 4, 20])\n True\n >>> monotonic([1, 20, 4, 10])\n False\n >>> monotonic([4, 1, 
0, -10])\n True\n \"\"\"\n"} 59 | {"instructions":"\n\ndef common(l1: list, l2: list):\n \"\"\"Return sorted unique common elements for two lists.\n >>> common([1, 4, 3, 34, 653, 2, 5], [5, 7, 1, 5, 9, 653, 121])\n [1, 5, 653]\n >>> common([5, 3, 2, 8], [3, 2])\n [2, 3]\n\n \"\"\"\n"} 60 | {"instructions":"\n\ndef largest_prime_factor(n: int):\n \"\"\"Return the largest prime factor of n. Assume n > 1 and is not a prime.\n >>> largest_prime_factor(13195)\n 29\n >>> largest_prime_factor(2048)\n 2\n \"\"\"\n"} 61 | {"instructions":"\n\ndef sum_to_n(n: int):\n \"\"\"sum_to_n is a function that sums numbers from 1 to n.\n >>> sum_to_n(30)\n 465\n >>> sum_to_n(100)\n 5050\n >>> sum_to_n(5)\n 15\n >>> sum_to_n(10)\n 55\n >>> sum_to_n(1)\n 1\n \"\"\"\n"} 62 | {"instructions":"\n\ndef correct_bracketing(brackets: str):\n \"\"\" brackets is a string of \"(\" and \")\".\n return True if every opening bracket has a corresponding closing bracket.\n\n >>> correct_bracketing(\"(\")\n False\n >>> correct_bracketing(\"()\")\n True\n >>> correct_bracketing(\"(()())\")\n True\n >>> correct_bracketing(\")(()\")\n False\n \"\"\"\n"} 63 | {"instructions":"\n\ndef derivative(xs: list):\n \"\"\" xs represent coefficients of a polynomial.\n xs[0] + xs[1] * x + xs[2] * x^2 + ....\n Return derivative of this polynomial in the same form.\n >>> derivative([3, 1, 2, 4, 5])\n [1, 4, 12, 20]\n >>> derivative([1, 2, 3])\n [2, 6]\n \"\"\"\n"} 64 | {"instructions":"\n\ndef fibfib(n: int):\n \"\"\"The FibFib number sequence is a sequence similar to the Fibbonacci sequnece that's defined as follows:\n fibfib(0) == 0\n fibfib(1) == 0\n fibfib(2) == 1\n fibfib(n) == fibfib(n-1) + fibfib(n-2) + fibfib(n-3).\n Please write a function to efficiently compute the n-th element of the fibfib number sequence.\n >>> fibfib(1)\n 0\n >>> fibfib(5)\n 4\n >>> fibfib(8)\n 24\n \"\"\"\n"} 65 | {"instructions":"\nFIX = \"\"\"\nAdd more test cases.\n\"\"\"\n\ndef vowels_count(s):\n \"\"\"Write a function vowels_count 
which takes a string representing\n a word as input and returns the number of vowels in the string.\n Vowels in this case are 'a', 'e', 'i', 'o', 'u'. Here, 'y' is also a\n vowel, but only when it is at the end of the given word.\n\n Example:\n >>> vowels_count(\"abcde\")\n 2\n >>> vowels_count(\"ACEDY\")\n 3\n \"\"\"\n"} 66 | {"instructions":"\ndef circular_shift(x, shift):\n \"\"\"Circular shift the digits of the integer x, shift the digits right by shift\n and return the result as a string.\n If shift > number of digits, return digits reversed.\n >>> circular_shift(12, 1)\n \"21\"\n >>> circular_shift(12, 2)\n \"12\"\n \"\"\"\n"} 67 | {"instructions":"\ndef digitSum(s):\n \"\"\"Task\n Write a function that takes a string as input and returns the sum of the upper characters only'\n ASCII codes.\n\n Examples:\n digitSum(\"\") => 0\n digitSum(\"abAB\") => 131\n digitSum(\"abcCd\") => 67\n digitSum(\"helloE\") => 69\n digitSum(\"woArBld\") => 131\n digitSum(\"aAaaaXa\") => 153\n \"\"\"\n"} 68 | {"instructions":"\ndef fruit_distribution(s,n):\n \"\"\"\n In this task, you will be given a string that represents a number of apples and oranges \n that are distributed in a basket of fruit this basket contains \n apples, oranges, and mango fruits. 
Given the string that represents the total number of \n the oranges and apples and an integer that represent the total number of the fruits \n in the basket return the number of the mango fruits in the basket.\n for examble:\n fruit_distribution(\"5 apples and 6 oranges\", 19) ->19 - 5 - 6 = 8\n fruit_distribution(\"0 apples and 1 oranges\",3) -> 3 - 0 - 1 = 2\n fruit_distribution(\"2 apples and 3 oranges\", 100) -> 100 - 2 - 3 = 95\n fruit_distribution(\"100 apples and 1 oranges\",120) -> 120 - 100 - 1 = 19\n \"\"\"\n"} 69 | {"instructions":"\ndef pluck(arr):\n \"\"\"\n \"Given an array representing a branch of a tree that has non-negative integer nodes\n your task is to pluck one of the nodes and return it.\n The plucked node should be the node with the smallest even value.\n If multiple nodes with the same smallest even value are found return the node that has smallest index.\n\n The plucked node should be returned in a list, [ smalest_value, its index ],\n If there are no even values or the given array is empty, return [].\n\n Example 1:\n Input: [4,2,3]\n Output: [2, 1]\n Explanation: 2 has the smallest even value, and 2 has the smallest index.\n\n Example 2:\n Input: [1,2,3]\n Output: [2, 1]\n Explanation: 2 has the smallest even value, and 2 has the smallest index. \n\n Example 3:\n Input: []\n Output: []\n \n Example 4:\n Input: [5, 0, 3, 0, 4, 2]\n Output: [0, 1]\n Explanation: 0 is the smallest value, but there are two zeros,\n so we will choose the first zero, which has the smallest index.\n\n Constraints:\n * 1 <= nodes.length <= 10000\n * 0 <= node.value\n \"\"\"\n"} 70 | {"instructions":"\ndef search(lst):\n '''\n You are given a non-empty list of positive integers. Return the greatest integer that is greater than \n zero, and has a frequency greater than or equal to the value of the integer itself. 
\n The frequency of an integer is the number of times it appears in the list.\n If no such a value exist, return -1.\n Examples:\n search([4, 1, 2, 2, 3, 1]) == 2\n search([1, 2, 2, 3, 3, 3, 4, 4, 4]) == 3\n search([5, 5, 4, 4, 4]) == -1\n '''\n"} 71 | {"instructions":"\ndef strange_sort_list(lst):\n '''\n Given list of integers, return list in strange order.\n Strange sorting, is when you start with the minimum value,\n then maximum of the remaining integers, then minimum and so on.\n\n Examples:\n strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]\n strange_sort_list([5, 5, 5, 5]) == [5, 5, 5, 5]\n strange_sort_list([]) == []\n '''\n"} 72 | {"instructions":"\ndef triangle_area(a, b, c):\n '''\n Given the lengths of the three sides of a triangle. Return the area of\n the triangle rounded to 2 decimal points if the three sides form a valid triangle. \n Otherwise return -1\n Three sides make a valid triangle when the sum of any two sides is greater \n than the third side.\n Example:\n triangle_area(3, 4, 5) == 6.00\n triangle_area(1, 2, 10) == -1\n '''\n"} 73 | {"instructions":"\ndef will_it_fly(q,w):\n '''\n Write a function that returns True if the object q will fly, and False otherwise.\n The object q will fly if it's balanced (it is a palindromic list) and the sum of its elements is less than or equal the maximum possible weight w.\n\n Example:\n will_it_fly([1, 2], 5) \u279e False \n # 1+2 is less than the maximum possible weight, but it's unbalanced.\n\n will_it_fly([3, 2, 3], 1) \u279e False\n # it's balanced, but 3+2+3 is more than the maximum possible weight.\n\n will_it_fly([3, 2, 3], 9) \u279e True\n # 3+2+3 is less than the maximum possible weight, and it's balanced.\n\n will_it_fly([3], 5) \u279e True\n # 3 is less than the maximum possible weight, and it's balanced.\n '''\n"} 74 | {"instructions":"\ndef smallest_change(arr):\n \"\"\"\n Given an array arr of integers, find the minimum number of elements that\n need to be changed to make the array 
palindromic. A palindromic array is an array that\n is read the same backwards and forwards. In one change, you can change one element to any other element.\n\n For example:\n smallest_change([1,2,3,5,4,7,9,6]) == 4\n smallest_change([1, 2, 3, 4, 3, 2, 2]) == 1\n smallest_change([1, 2, 3, 2, 1]) == 0\n \"\"\"\n"} 75 | {"instructions":"\ndef total_match(lst1, lst2):\n '''\n Write a function that accepts two lists of strings and returns the list that has \n total number of chars in the all strings of the list less than the other list.\n\n if the two lists have the same number of chars, return the first list.\n\n Examples\n total_match([], []) \u279e []\n total_match(['hi', 'admin'], ['hI', 'Hi']) \u279e ['hI', 'Hi']\n total_match(['hi', 'admin'], ['hi', 'hi', 'admin', 'project']) \u279e ['hi', 'admin']\n total_match(['hi', 'admin'], ['hI', 'hi', 'hi']) \u279e ['hI', 'hi', 'hi']\n total_match(['4'], ['1', '2', '3', '4', '5']) \u279e ['4']\n '''\n"} 76 | {"instructions":"\ndef is_multiply_prime(a):\n \"\"\"Write a function that returns true if the given number is the multiplication of 3 prime numbers\n and false otherwise.\n Knowing that (a) is less then 100. 
\n Example:\n is_multiply_prime(30) == True\n 30 = 2 * 3 * 5\n \"\"\"\n"} 77 | {"instructions":"\ndef is_simple_power(x, n):\n \"\"\"Your task is to write a function that returns true if a number x is a simple\n power of n and false in other cases.\n x is a simple power of n if n**int=x\n For example:\n is_simple_power(1, 4) => true\n is_simple_power(2, 2) => true\n is_simple_power(8, 2) => true\n is_simple_power(3, 2) => false\n is_simple_power(3, 1) => false\n is_simple_power(5, 3) => false\n \"\"\"\n"} 78 | {"instructions":"\ndef iscube(a):\n '''\n Write a function that takes an integer a and returns True \n if this ingeger is a cube of some integer number.\n Note: you may assume the input is always valid.\n Examples:\n iscube(1) ==> True\n iscube(2) ==> False\n iscube(-1) ==> True\n iscube(64) ==> True\n iscube(0) ==> True\n iscube(180) ==> False\n '''\n"} 79 | {"instructions":"\ndef hex_key(num):\n \"\"\"You have been tasked to write a function that receives \n a hexadecimal number as a string and counts the number of hexadecimal \n digits that are primes (prime number, or a prime, is a natural number \n greater than 1 that is not a product of two smaller natural numbers).\n Hexadecimal digits are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F.\n Prime numbers are 2, 3, 5, 7, 11, 13, 17,...\n So you have to determine a number of the following digits: 2, 3, 5, 7, \n B (=decimal 11), D (=decimal 13).\n Note: you may assume the input is always correct or empty string, \n and symbols A,B,C,D,E,F are always uppercase.\n Examples:\n For num = \"AB\" the output should be 1.\n For num = \"1077E\" the output should be 2.\n For num = \"ABED1A33\" the output should be 4.\n For num = \"123456789ABCDEF0\" the output should be 6.\n For num = \"2020\" the output should be 2.\n \"\"\"\n"} 80 | {"instructions":"\ndef decimal_to_binary(decimal):\n \"\"\"You will be given a number in decimal form and your task is to convert it to\n binary format. 
The function should return a string, with each character representing a binary\n number. Each character in the string will be '0' or '1'.\n\n There will be an extra couple of characters 'db' at the beginning and at the end of the string.\n The extra characters are there to help with the format.\n\n Examples:\n decimal_to_binary(15) # returns \"db1111db\"\n decimal_to_binary(32) # returns \"db100000db\"\n \"\"\"\n"} 81 | {"instructions":"\ndef is_happy(s):\n \"\"\"You are given a string s.\n Your task is to check if the string is happy or not.\n A string is happy if its length is at least 3 and every 3 consecutive letters are distinct\n For example:\n is_happy(a) => False\n is_happy(aa) => False\n is_happy(abcd) => True\n is_happy(aabb) => False\n is_happy(adb) => True\n is_happy(xyy) => False\n \"\"\"\n"} 82 | {"instructions":"\ndef numerical_letter_grade(grades):\n \"\"\"It is the last week of the semester and the teacher has to give the grades\n to students. The teacher has been making her own algorithm for grading.\n The only problem is, she has lost the code she used for grading.\n She has given you a list of GPAs for some students and you have to write \n a function that can output a list of letter grades using the following table:\n GPA | Letter grade\n 4.0 A+\n > 3.7 A \n > 3.3 A- \n > 3.0 B+\n > 2.7 B \n > 2.3 B-\n > 2.0 C+\n > 1.7 C\n > 1.3 C-\n > 1.0 D+ \n > 0.7 D \n > 0.0 D-\n 0.0 E\n \n\n Example:\n grade_equation([4.0, 3, 1.7, 2, 3.5]) ==> ['A+', 'B', 'C-', 'C', 'A-']\n \"\"\"\n"} 83 | {"instructions":"\ndef prime_length(string):\n \"\"\"Write a function that takes a string and returns True if the string\n length is a prime number or False otherwise\n Examples\n prime_length('Hello') == True\n prime_length('abcdcba') == True\n prime_length('kittens') == True\n prime_length('orange') == False\n \"\"\"\n"} 84 | {"instructions":"\ndef starts_one_ends(n):\n \"\"\"\n Given a positive integer n, return the count of the numbers of n-digit\n positive integers 
that start or end with 1.\n \"\"\"\n"} 85 | {"instructions":"\ndef solve(N):\n \"\"\"Given a positive integer N, return the total sum of its digits in binary.\n \n Example\n For N = 1000, the sum of digits will be 1 the output should be \"1\".\n For N = 150, the sum of digits will be 6 the output should be \"110\".\n For N = 147, the sum of digits will be 12 the output should be \"1100\".\n \n Variables:\n @N integer\n Constraints: 0 \u2264 N \u2264 10000.\n Output:\n a string of binary number\n \"\"\"\n"} 86 | {"instructions":"\ndef add(lst):\n \"\"\"Given a non-empty list of integers lst. add the even elements that are at odd indices..\n\n\n Examples:\n add([4, 2, 6, 7]) ==> 2 \n \"\"\"\n"} 87 | {"instructions":"\ndef anti_shuffle(s):\n \"\"\"\n Write a function that takes a string and returns an ordered version of it.\n Ordered version of string, is a string where all words (separated by space)\n are replaced by a new word where all the characters arranged in\n ascending order based on ascii value.\n Note: You should keep the order of words and blank spaces in the sentence.\n\n For example:\n anti_shuffle('Hi') returns 'Hi'\n anti_shuffle('hello') returns 'ehllo'\n anti_shuffle('Hello World!!!') returns 'Hello !!!Wdlor'\n \"\"\"\n"} 88 | {"instructions":"\ndef get_row(lst, x):\n \"\"\"\n You are given a 2 dimensional data, as a nested lists,\n which is similar to matrix, however, unlike matrices,\n each row may contain a different number of columns.\n Given lst, and integer x, find integers x in the list,\n and return list of tuples, [(x1, y1), (x2, y2) ...] 
such that\n each tuple is a coordinate - (row, columns), starting with 0.\n Sort coordinates initially by rows in ascending order.\n Also, sort coordinates of the row by columns in descending order.\n \n Examples:\n get_row([\n [1,2,3,4,5,6],\n [1,2,3,4,1,6],\n [1,2,3,4,5,1]\n ], 1) == [(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)]\n get_row([], 1) == []\n get_row([[], [1], [1, 2, 3]], 3) == [(2, 2)]\n \"\"\"\n"} 89 | {"instructions":"\ndef sort_array(array):\n \"\"\"\n Given an array of non-negative integers, return a copy of the given array after sorting,\n you will sort the given array in ascending order if the sum( first index value, last index value) is odd,\n or sort it in descending order if the sum( first index value, last index value) is even.\n\n Note:\n * don't change the given array.\n\n Examples:\n * sort_array([]) => []\n * sort_array([5]) => [5]\n * sort_array([2, 4, 3, 0, 1, 5]) => [0, 1, 2, 3, 4, 5]\n * sort_array([2, 4, 3, 0, 1, 5, 6]) => [6, 5, 4, 3, 2, 1, 0]\n \"\"\"\n"} 90 | {"instructions":"\ndef encrypt(s):\n \"\"\"Create a function encrypt that takes a string as an argument and\n returns a string encrypted with the alphabet being rotated. \n The alphabet should be rotated in a manner such that the letters \n shift down by two multiplied to two places.\n For example:\n encrypt('hi') returns 'lm'\n encrypt('asdfghjkl') returns 'ewhjklnop'\n encrypt('gf') returns 'kj'\n encrypt('et') returns 'ix'\n \"\"\"\n"} 91 | {"instructions":"\ndef next_smallest(lst):\n \"\"\"\n You are given a list of integers.\n Write a function next_smallest() that returns the 2nd smallest element of the list.\n Return None if there is no such element.\n \n next_smallest([1, 2, 3, 4, 5]) == 2\n next_smallest([5, 1, 4, 3, 2]) == 2\n next_smallest([]) == None\n next_smallest([1, 1]) == None\n \"\"\"\n"} 92 | {"instructions":"\ndef is_bored(S):\n \"\"\"\n You'll be given a string of words, and your task is to count the number\n of boredoms. 
A boredom is a sentence that starts with the word \"I\".\n Sentences are delimited by '.', '?' or '!'.\n \n For example:\n >>> is_bored(\"Hello world\")\n 0\n >>> is_bored(\"The sky is blue. The sun is shining. I love this weather\")\n 1\n \"\"\"\n"} 93 | {"instructions":"\ndef any_int(x, y, z):\n '''\n Create a function that takes 3 numbers.\n Returns true if one of the numbers is equal to the sum of the other two, and all numbers are integers.\n Returns false in any other cases.\n \n Examples\n any_int(5, 2, 7) \u279e True\n \n any_int(3, 2, 2) \u279e False\n\n any_int(3, -2, 1) \u279e True\n \n any_int(3.6, -2.2, 2) \u279e False\n \n\n \n '''\n"} 94 | {"instructions":"\ndef encode(message):\n \"\"\"\n Write a function that takes a message, and encodes in such a \n way that it swaps case of all letters, replaces all vowels in \n the message with the letter that appears 2 places ahead of that \n vowel in the english alphabet. \n Assume only letters. \n \n Examples:\n >>> encode('test')\n 'TGST'\n >>> encode('This is a message')\n 'tHKS KS C MGSSCGG'\n \"\"\"\n"} 95 | {"instructions":"\n\ndef skjkasdkd(lst):\n \"\"\"You are given a list of integers.\n You need to find the largest prime value and return the sum of its digits.\n\n Examples:\n For lst = [0,3,2,1,3,5,7,4,5,5,5,2,181,32,4,32,3,2,32,324,4,3] the output should be 10\n For lst = [1,0,1,8,2,4597,2,1,3,40,1,2,1,2,4,2,5,1] the output should be 25\n For lst = [1,3,1,32,5107,34,83278,109,163,23,2323,32,30,1,9,3] the output should be 13\n For lst = [0,724,32,71,99,32,6,0,5,91,83,0,5,6] the output should be 11\n For lst = [0,81,12,3,1,21] the output should be 3\n For lst = [0,8,1,2,1,7] the output should be 7\n \"\"\"\n"} 96 | {"instructions":"\ndef check_dict_case(dict):\n \"\"\"\n Given a dictionary, return True if all keys are strings in lower \n case or all keys are strings in upper case, else return False.\n The function should return False is the given dictionary is empty.\n Examples:\n 
check_dict_case({\"a\":\"apple\", \"b\":\"banana\"}) should return True.\n check_dict_case({\"a\":\"apple\", \"A\":\"banana\", \"B\":\"banana\"}) should return False.\n check_dict_case({\"a\":\"apple\", 8:\"banana\", \"a\":\"apple\"}) should return False.\n check_dict_case({\"Name\":\"John\", \"Age\":\"36\", \"City\":\"Houston\"}) should return False.\n check_dict_case({\"STATE\":\"NC\", \"ZIP\":\"12345\" }) should return True.\n \"\"\"\n"} 97 | {"instructions":"\ndef count_up_to(n):\n \"\"\"Implement a function that takes an non-negative integer and returns an array of the first n\n integers that are prime numbers and less than n.\n for example:\n count_up_to(5) => [2,3]\n count_up_to(11) => [2,3,5,7]\n count_up_to(0) => []\n count_up_to(20) => [2,3,5,7,11,13,17,19]\n count_up_to(1) => []\n count_up_to(18) => [2,3,5,7,11,13,17]\n \"\"\"\n"} 98 | {"instructions":"\ndef multiply(a, b):\n \"\"\"Complete the function that takes two integers and returns \n the product of their unit digits.\n Assume the input is always valid.\n Examples:\n multiply(148, 412) should return 16.\n multiply(19, 28) should return 72.\n multiply(2020, 1851) should return 0.\n multiply(14,-15) should return 20.\n \"\"\"\n"} 99 | {"instructions":"\ndef count_upper(s):\n \"\"\"\n Given a string s, count the number of uppercase vowels in even indices.\n \n For example:\n count_upper('aBCdEf') returns 1\n count_upper('abcdefg') returns 0\n count_upper('dBBE') returns 0\n \"\"\"\n"} 100 | {"instructions":"\ndef closest_integer(value):\n '''\n Create a function that takes a value (string) representing a number\n and returns the closest integer to it. If the number is equidistant\n from two integers, round it away from zero.\n\n Examples\n >>> closest_integer(\"10\")\n 10\n >>> closest_integer(\"15.3\")\n 15\n\n Note:\n Rounding away from zero means that if the given number is equidistant\n from two integers, the one you should return is the one that is the\n farthest from zero. 
For example closest_integer(\"14.5\") should\n return 15 and closest_integer(\"-14.5\") should return -15.\n '''\n"} 101 | {"instructions":"\ndef make_a_pile(n):\n \"\"\"\n Given a positive integer n, you have to make a pile of n levels of stones.\n The first level has n stones.\n The number of stones in the next level is:\n - the next odd number if n is odd.\n - the next even number if n is even.\n Return the number of stones in each level in a list, where element at index\n i represents the number of stones in the level (i+1).\n\n Examples:\n >>> make_a_pile(3)\n [3, 5, 7]\n \"\"\"\n"} 102 | {"instructions":"\ndef words_string(s):\n \"\"\"\n You will be given a string of words separated by commas or spaces. Your task is\n to split the string into words and return an array of the words.\n \n For example:\n words_string(\"Hi, my name is John\") == [\"Hi\", \"my\", \"name\", \"is\", \"John\"]\n words_string(\"One, two, three, four, five, six\") == [\"One\", \"two\", \"three\", \"four\", \"five\", \"six\"]\n \"\"\"\n"} 103 | {"instructions":"\ndef choose_num(x, y):\n \"\"\"This function takes two positive numbers x and y and returns the\n biggest even integer number that is in the range [x, y] inclusive. If \n there's no such number, then the function should return -1.\n\n For example:\n choose_num(12, 15) = 14\n choose_num(13, 12) = -1\n \"\"\"\n"} 104 | {"instructions":"\ndef rounded_avg(n, m):\n \"\"\"You are given two positive integers n and m, and your task is to compute the\n average of the integers from n through m (including n and m). \n Round the answer to the nearest integer and convert that to binary.\n If n is greater than m, return -1.\n Example:\n rounded_avg(1, 5) => \"0b11\"\n rounded_avg(7, 5) => -1\n rounded_avg(10, 20) => \"0b1111\"\n rounded_avg(20, 33) => \"0b11010\"\n \"\"\"\n"} 105 | {"instructions":"\ndef unique_digits(x):\n \"\"\"Given a list of positive integers x. 
return a sorted list of all \n elements that hasn't any even digit.\n\n Note: Returned list should be sorted in increasing order.\n \n For example:\n >>> unique_digits([15, 33, 1422, 1])\n [1, 15, 33]\n >>> unique_digits([152, 323, 1422, 10])\n []\n \"\"\"\n"} 106 | {"instructions":"\ndef by_length(arr):\n \"\"\"\n Given an array of integers, sort the integers that are between 1 and 9 inclusive,\n reverse the resulting array, and then replace each digit by its corresponding name from\n \"One\", \"Two\", \"Three\", \"Four\", \"Five\", \"Six\", \"Seven\", \"Eight\", \"Nine\".\n\n For example:\n arr = [2, 1, 1, 4, 5, 8, 2, 3] \n -> sort arr -> [1, 1, 2, 2, 3, 4, 5, 8] \n -> reverse arr -> [8, 5, 4, 3, 2, 2, 1, 1]\n return [\"Eight\", \"Five\", \"Four\", \"Three\", \"Two\", \"Two\", \"One\", \"One\"]\n \n If the array is empty, return an empty array:\n arr = []\n return []\n \n If the array has any strange number ignore it:\n arr = [1, -1 , 55] \n -> sort arr -> [-1, 1, 55]\n -> reverse arr -> [55, 1, -1]\n return = ['One']\n \"\"\"\n"} 107 | {"instructions":"\ndef f(n):\n \"\"\" Implement the function f that takes n as a parameter,\n and returns a list of size n, such that the value of the element at index i is the factorial of i if i is even\n or the sum of numbers from 1 to i otherwise.\n i starts from 1.\n the factorial of i is the multiplication of the numbers from 1 to i (1 * 2 * ... * i).\n Example:\n f(5) == [1, 2, 6, 24, 15]\n \"\"\"\n"} 108 | {"instructions":"\ndef even_odd_palindrome(n):\n \"\"\"\n Given a positive integer n, return a tuple that has the number of even and odd\n integer palindromes that fall within the range(1, n), inclusive.\n\n Example 1:\n\n Input: 3\n Output: (1, 2)\n Explanation:\n Integer palindrome are 1, 2, 3. one of them is even, and two of them are odd.\n\n Example 2:\n\n Input: 12\n Output: (4, 6)\n Explanation:\n Integer palindrome are 1, 2, 3, 4, 5, 6, 7, 8, 9, 11. four of them are even, and 6 of them are odd.\n\n Note:\n 1. 
1 <= n <= 10^3\n 2. returned tuple has the number of even and odd integer palindromes respectively.\n \"\"\"\n"} 109 | {"instructions":"\ndef count_nums(arr):\n \"\"\"\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n \"\"\"\n"} 110 | {"instructions":"\ndef move_one_ball(arr):\n \"\"\"We have an array 'arr' of N integers arr[1], arr[2], ..., arr[N].The\n numbers in the array will be randomly ordered. Your task is to determine if\n it is possible to get an array sorted in non-decreasing order by performing \n the following operation on the given array:\n You are allowed to perform right shift operation any number of times.\n \n One right shift operation means shifting all elements of the array by one\n position in the right direction. The last element of the array will be moved to\n the starting position in the array i.e. 0th index. 
\n\n If it is possible to obtain the sorted array by performing the above operation\n then return True else return False.\n If the given array is empty then return True.\n\n Note: The given list is guaranteed to have unique elements.\n\n For Example:\n \n move_one_ball([3, 4, 5, 1, 2])==>True\n Explanation: By performin 2 right shift operations, non-decreasing order can\n be achieved for the given array.\n move_one_ball([3, 5, 4, 1, 2])==>False\n Explanation:It is not possible to get non-decreasing order for the given\n array by performing any number of right shift operations.\n \n \"\"\"\n"} 111 | {"instructions":"\ndef exchange(lst1, lst2):\n \"\"\"In this problem, you will implement a function that takes two lists of numbers,\n and determines whether it is possible to perform an exchange of elements\n between them to make lst1 a list of only even numbers.\n There is no limit on the number of exchanged elements between lst1 and lst2.\n If it is possible to exchange elements between the lst1 and lst2 to make\n all the elements of lst1 to be even, return \"YES\".\n Otherwise, return \"NO\".\n For example:\n exchange([1, 2, 3, 4], [1, 2, 3, 4]) => \"YES\"\n exchange([1, 2, 3, 4], [1, 5, 3, 4]) => \"NO\"\n It is assumed that the input lists will be non-empty.\n \"\"\"\n"} 112 | {"instructions":"\ndef histogram(test):\n \"\"\"Given a string representing a space separated lowercase letters, return a dictionary\n of the letter with the most repetition and containing the corresponding count.\n If several letters have the same occurrence, return all of them.\n \n Example:\n histogram('a b c') == {'a': 1, 'b': 1, 'c': 1}\n histogram('a b b a') == {'a': 2, 'b': 2}\n histogram('a b c a b') == {'a': 2, 'b': 2}\n histogram('b b b b a') == {'b': 4}\n histogram('') == {}\n\n \"\"\"\n"} 113 | {"instructions":"\ndef reverse_delete(s,c):\n \"\"\"Task\n We are given two strings s and c, you have to deleted all the characters in s that are equal to any character in c\n then check if 
the result string is palindrome.\n A string is called palindrome if it reads the same backward as forward.\n You should return a tuple containing the result string and True\/False for the check.\n Example\n For s = \"abcde\", c = \"ae\", the result should be ('bcd',False)\n For s = \"abcdef\", c = \"b\" the result should be ('acdef',False)\n For s = \"abcdedcba\", c = \"ab\", the result should be ('cdedc',True)\n \"\"\"\n"} 114 | {"instructions":"\ndef odd_count(lst):\n \"\"\"Given a list of strings, where each string consists of only digits, return a list.\n Each element i of the output should be \"the number of odd elements in the\n string i of the input.\" where all the i's should be replaced by the number\n of odd digits in the i'th string of the input.\n\n >>> odd_count(['1234567'])\n [\"the number of odd elements 4n the str4ng 4 of the 4nput.\"]\n >>> odd_count(['3',\"11111111\"])\n [\"the number of odd elements 1n the str1ng 1 of the 1nput.\",\n \"the number of odd elements 8n the str8ng 8 of the 8nput.\"]\n \"\"\"\n"} 115 | {"instructions":"\ndef minSubArraySum(nums):\n \"\"\"\n Given an array of integers nums, find the minimum sum of any non-empty sub-array\n of nums.\n Example\n minSubArraySum([2, 3, 4, 1, 2, 4]) == 1\n minSubArraySum([-1, -2, -3]) == -6\n \"\"\"\n"} 116 | {"instructions":"\ndef max_fill(grid, capacity):\n import math\n \"\"\"\n You are given a rectangular grid of wells. 
Each row represents a single well,\n and each 1 in a row represents a single unit of water.\n Each well has a corresponding bucket that can be used to extract water from it, \n and all buckets have the same capacity.\n Your task is to use the buckets to empty the wells.\n Output the number of times you need to lower the buckets.\n\n Example 1:\n Input: \n grid : [[0,0,1,0], [0,1,0,0], [1,1,1,1]]\n bucket_capacity : 1\n Output: 6\n\n Example 2:\n Input: \n grid : [[0,0,1,1], [0,0,0,0], [1,1,1,1], [0,1,1,1]]\n bucket_capacity : 2\n Output: 5\n \n Example 3:\n Input: \n grid : [[0,0,0], [0,0,0]]\n bucket_capacity : 5\n Output: 0\n\n Constraints:\n * all wells have the same length\n * 1 <= grid.length <= 10^2\n * 1 <= grid[:,1].length <= 10^2\n * grid[i][j] -> 0 | 1\n * 1 <= capacity <= 10\n \"\"\"\n"} 117 | {"instructions":"\ndef sort_array(arr):\n \"\"\"\n In this Kata, you have to sort an array of non-negative integers according to\n number of ones in their binary representation in ascending order.\n For similar number of ones, sort based on decimal value.\n\n It must be implemented like this:\n >>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]\n >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]\n >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]\n \"\"\"\n"} 118 | {"instructions":"\ndef select_words(s, n):\n \"\"\"Given a string s and a natural number n, you have been tasked to implement \n a function that returns a list of all words from string s that contain exactly \n n consonants, in order these words appear in the string s.\n If the string s is empty then the function should return an empty list.\n Note: you may assume the input string contains only letters and spaces.\n Examples:\n select_words(\"Mary had a little lamb\", 4) ==> [\"little\"]\n select_words(\"Mary had a little lamb\", 3) ==> [\"Mary\", \"lamb\"]\n select_words(\"simple white space\", 2) ==> []\n select_words(\"Hello world\", 4) ==> [\"world\"]\n select_words(\"Uncle sam\", 3) ==> 
[\"Uncle\"]\n \"\"\"\n"} 119 | {"instructions":"\ndef get_closest_vowel(word):\n \"\"\"You are given a word. Your task is to find the closest vowel that stands between \n two consonants from the right side of the word (case sensitive).\n \n Vowels in the beginning and ending doesn't count. Return empty string if you didn't\n find any vowel met the above condition. \n\n You may assume that the given string contains English letter only.\n\n Example:\n get_closest_vowel(\"yogurt\") ==> \"u\"\n get_closest_vowel(\"FULL\") ==> \"U\"\n get_closest_vowel(\"quick\") ==> \"\"\n get_closest_vowel(\"ab\") ==> \"\"\n \"\"\"\n"} 120 | {"instructions":"\ndef match_parens(lst):\n '''\n You are given a list of two strings, both strings consist of open\n parentheses '(' or close parentheses ')' only.\n Your job is to check if it is possible to concatenate the two strings in\n some order, that the resulting string will be good.\n A string S is considered to be good if and only if all parentheses in S\n are balanced. For example: the string '(())()' is good, while the string\n '())' is not.\n Return 'Yes' if there's a way to make a good string, and return 'No' otherwise.\n\n Examples:\n match_parens(['()(', ')']) == 'Yes'\n match_parens([')', ')']) == 'No'\n '''\n"} 121 | {"instructions":"\ndef maximum(arr, k):\n \"\"\"\n Given an array arr of integers and a positive integer k, return a sorted list \n of length k with the maximum k numbers in arr.\n\n Example 1:\n\n Input: arr = [-3, -4, 5], k = 3\n Output: [-4, -3, 5]\n\n Example 2:\n\n Input: arr = [4, -4, 4], k = 2\n Output: [4, 4]\n\n Example 3:\n\n Input: arr = [-3, 2, 1, 2, -1, -2, 1], k = 1\n Output: [2]\n\n Note:\n 1. The length of the array will be in the range of [1, 1000].\n 2. The elements in the array will be in the range of [-1000, 1000].\n 3. 
0 <= k <= len(arr)\n \"\"\"\n"} 122 | {"instructions":"\ndef solution(lst):\n \"\"\"Given a non-empty list of integers, return the sum of all of the odd elements that are in even positions.\n \n\n Examples\n solution([5, 8, 7, 1]) ==> 12\n solution([3, 3, 3, 3, 3]) ==> 9\n solution([30, 13, 24, 321]) ==>0\n \"\"\"\n"} 123 | {"instructions":"\ndef add_elements(arr, k):\n \"\"\"\n Given a non-empty array of integers arr and an integer k, return\n the sum of the elements with at most two digits from the first k elements of arr.\n\n Example:\n\n Input: arr = [111,21,3,4000,5,6,7,8,9], k = 4\n Output: 24 # sum of 21 + 3\n\n Constraints:\n 1. 1 <= len(arr) <= 100\n 2. 1 <= k <= len(arr)\n \"\"\"\n"} 124 | {"instructions":"\ndef get_odd_collatz(n):\n \"\"\"\n Given a positive integer n, return a sorted list that has the odd numbers in collatz sequence.\n\n The Collatz conjecture is a conjecture in mathematics that concerns a sequence defined\n as follows: start with any positive integer n. Then each term is obtained from the \n previous term as follows: if the previous term is even, the next term is one half of \n the previous term. If the previous term is odd, the next term is 3 times the previous\n term plus 1. The conjecture is that no matter what value of n, the sequence will always reach 1.\n\n Note: \n 1. Collatz(1) is [1].\n 2. returned list sorted in increasing order.\n\n For example:\n get_odd_collatz(5) returns [1, 5] # The collatz sequence for 5 is [5, 16, 8, 4, 2, 1], so the odd numbers are only 1, and 5.\n \"\"\"\n"} 125 | {"instructions":"\ndef valid_date(date):\n \"\"\"You have to write a function which validates a given date string and\n returns True if the date is valid otherwise False.\n The date is valid if all of the following rules are satisfied:\n 1. The date string is not empty.\n 2. The number of days is not less than 1 or higher than 31 days for months 1,3,5,7,8,10,12. 
And the number of days is not less than 1 or higher than 30 days for months 4,6,9,11. And, the number of days is not less than 1 or higher than 29 for the month 2.\n 3. The months should not be less than 1 or higher than 12.\n 4. The date should be in the format: mm-dd-yyyy\n\n for example: \n valid_date('03-11-2000') => True\n\n valid_date('15-01-2012') => False\n\n valid_date('04-0-2040') => False\n\n valid_date('06-04-2020') => True\n\n valid_date('06\/04\/2020') => False\n \"\"\"\n"} 126 | {"instructions":"\ndef split_words(txt):\n '''\n Given a string of words, return a list of words split on whitespace, if no whitespaces exists in the text you\n should split on commas ',' if no commas exists you should return the number of lower-case letters with odd order in the\n alphabet, ord('a') = 0, ord('b') = 1, ... ord('z') = 25\n Examples\n split_words(\"Hello world!\") \u279e [\"Hello\", \"world!\"]\n split_words(\"Hello,world!\") \u279e [\"Hello\", \"world!\"]\n split_words(\"abcdef\") == 3 \n '''\n"} 127 | {"instructions":"\ndef is_sorted(lst):\n '''\n Given a list of numbers, return whether or not they are sorted\n in ascending order. If list has more than 1 duplicate of the same\n number, return False. Assume no negative numbers and only integers.\n\n Examples\n is_sorted([5]) \u279e True\n is_sorted([1, 2, 3, 4, 5]) \u279e True\n is_sorted([1, 3, 2, 4, 5]) \u279e False\n is_sorted([1, 2, 3, 4, 5, 6]) \u279e True\n is_sorted([1, 2, 3, 4, 5, 6, 7]) \u279e True\n is_sorted([1, 3, 2, 4, 5, 6, 7]) \u279e False\n is_sorted([1, 2, 2, 3, 3, 4]) \u279e True\n is_sorted([1, 2, 2, 2, 3, 4]) \u279e False\n '''\n"} 128 | {"instructions":"\ndef intersection(interval1, interval2):\n \"\"\"You are given two intervals,\n where each interval is a pair of integers. 
For example, interval = (start, end) = (1, 2).\n The given intervals are closed which means that the interval (start, end)\n includes both start and end.\n For each given interval, it is assumed that its start is less or equal its end.\n Your task is to determine whether the length of intersection of these two \n intervals is a prime number.\n Example, the intersection of the intervals (1, 3), (2, 4) is (2, 3)\n which its length is 1, which not a prime number.\n If the length of the intersection is a prime number, return \"YES\",\n otherwise, return \"NO\".\n If the two intervals don't intersect, return \"NO\".\n\n\n [input\/output] samples:\n intersection((1, 2), (2, 3)) ==> \"NO\"\n intersection((-1, 1), (0, 4)) ==> \"NO\"\n intersection((-3, -1), (-5, 5)) ==> \"YES\"\n \"\"\"\n"} 129 | {"instructions":"\ndef prod_signs(arr):\n \"\"\"\n You are given an array arr of integers and you need to return\n sum of magnitudes of integers multiplied by product of all signs\n of each number in the array, represented by 1, -1 or 0.\n Note: return None for empty arr.\n\n Example:\n >>> prod_signs([1, 2, 2, -4]) == -9\n >>> prod_signs([0, 1]) == 0\n >>> prod_signs([]) == None\n \"\"\"\n"} 130 | {"instructions":"\ndef minPath(grid, k):\n \"\"\"\n Given a grid with N rows and N columns (N >= 2) and a positive integer k, \n each cell of the grid contains a value. Every integer in the range [1, N * N]\n inclusive appears exactly once on the cells of the grid.\n\n You have to find the minimum path of length k in the grid. 
You can start\n from any cell, and in each step you can move to any of the neighbor cells,\n in other words, you can go to cells which share an edge with you current\n cell.\n Please note that a path of length k means visiting exactly k cells (not\n necessarily distinct).\n You CANNOT go off the grid.\n A path A (of length k) is considered less than a path B (of length k) if\n after making the ordered lists of the values on the cells that A and B go\n through (let's call them lst_A and lst_B), lst_A is lexicographically less\n than lst_B, in other words, there exist an integer index i (1 <= i <= k)\n such that lst_A[i] < lst_B[i] and for any j (1 <= j < i) we have\n lst_A[j] = lst_B[j].\n It is guaranteed that the answer is unique.\n Return an ordered list of the values on the cells that the minimum path go through.\n\n Examples:\n\n Input: grid = [ [1,2,3], [4,5,6], [7,8,9]], k = 3\n Output: [1, 2, 1]\n\n Input: grid = [ [5,9,3], [4,1,6], [7,8,2]], k = 1\n Output: [1]\n \"\"\"\n"} 131 | {"instructions":"\ndef tri(n):\n \"\"\"Everyone knows Fibonacci sequence, it was studied deeply by mathematicians in \n the last couple centuries. 
However, what people don't know is Tribonacci sequence.\n Tribonacci sequence is defined by the recurrence:\n tri(1) = 3\n tri(n) = 1 + n \/ 2, if n is even.\n tri(n) = tri(n - 1) + tri(n - 2) + tri(n + 1), if n is odd.\n For example:\n tri(2) = 1 + (2 \/ 2) = 2\n tri(4) = 3\n tri(3) = tri(2) + tri(1) + tri(4)\n = 2 + 3 + 3 = 8 \n You are given a non-negative integer number n, you have to a return a list of the \n first n + 1 numbers of the Tribonacci sequence.\n Examples:\n tri(3) = [1, 3, 2, 8]\n \"\"\"\n"} 132 | {"instructions":"\ndef digits(n):\n \"\"\"Given a positive integer n, return the product of the odd digits.\n Return 0 if all digits are even.\n For example:\n digits(1) == 1\n digits(4) == 0\n digits(235) == 15\n \"\"\"\n"} 133 | {"instructions":"\ndef is_nested(string):\n '''\n Create a function that takes a string as input which contains only square brackets.\n The function should return True if and only if there is a valid subsequence of brackets \n where at least one bracket in the subsequence is nested.\n\n is_nested('[[]]') \u279e True\n is_nested('[]]]]]]][[[[[]') \u279e False\n is_nested('[][]') \u279e False\n is_nested('[]') \u279e False\n is_nested('[[][]]') \u279e True\n is_nested('[[]][[') \u279e True\n '''\n"} 134 | {"instructions":"\n\ndef sum_squares(lst):\n \"\"\"You are given a list of numbers.\n You need to return the sum of squared numbers in the given list,\n round each element in the list to the upper int(Ceiling) first.\n Examples:\n For lst = [1,2,3] the output should be 14\n For lst = [1,4,9] the output should be 98\n For lst = [1,3,5,7] the output should be 84\n For lst = [1.4,4.2,0] the output should be 29\n For lst = [-2.4,1,1] the output should be 6\n \n\n \"\"\"\n"} 135 | {"instructions":"\ndef check_if_last_char_is_a_letter(txt):\n '''\n Create a function that returns True if the last character\n of a given string is an alphabetical character and is not\n a part of a word, and False otherwise.\n Note: \"word\" is a group of 
characters separated by space.\n\n Examples:\n check_if_last_char_is_a_letter(\"apple pie\") \u279e False\n check_if_last_char_is_a_letter(\"apple pi e\") \u279e True\n check_if_last_char_is_a_letter(\"apple pi e \") \u279e False\n check_if_last_char_is_a_letter(\"\") \u279e False \n '''\n"} 136 | {"instructions":"\ndef can_arrange(arr):\n \"\"\"Create a function which returns the largest index of an element which\n is not greater than or equal to the element immediately preceding it. If\n no such element exists then return -1. The given array will not contain\n duplicate values.\n\n Examples:\n can_arrange([1,2,4,3,5]) = 3\n can_arrange([1,2,3]) = -1\n \"\"\"\n"} 137 | {"instructions":"\ndef largest_smallest_integers(lst):\n '''\n Create a function that returns a tuple (a, b), where 'a' is\n the largest of negative integers, and 'b' is the smallest\n of positive integers in a list.\n If there is no negative or positive integers, return them as None.\n\n Examples:\n largest_smallest_integers([2, 4, 1, 3, 5, 7]) == (None, 1)\n largest_smallest_integers([]) == (None, None)\n largest_smallest_integers([0]) == (None, None)\n '''\n"} 138 | {"instructions":"\ndef compare_one(a, b):\n \"\"\"\n Create a function that takes integers, floats, or strings representing\n real numbers, and returns the larger variable in its given variable type.\n Return None if the values are equal.\n Note: If a real number is represented as a string, the floating point might be . 
or ,\n\n compare_one(1, 2.5) \u279e 2.5\n compare_one(1, \"2,3\") \u279e \"2,3\"\n compare_one(\"5,1\", \"6\") \u279e \"6\"\n compare_one(\"1\", 1) \u279e None\n \"\"\"\n"} 139 | {"instructions":"\ndef is_equal_to_sum_even(n):\n \"\"\"Evaluate whether the given number n can be written as the sum of exactly 4 positive even numbers\n Example\n is_equal_to_sum_even(4) == False\n is_equal_to_sum_even(6) == False\n is_equal_to_sum_even(8) == True\n \"\"\"\n"} 140 | {"instructions":"\ndef special_factorial(n):\n \"\"\"The Brazilian factorial is defined as:\n brazilian_factorial(n) = n! * (n-1)! * (n-2)! * ... * 1!\n where n > 0\n\n For example:\n >>> special_factorial(4)\n 288\n\n The function will receive an integer as input and should return the special\n factorial of this integer.\n \"\"\"\n"} 141 | {"instructions":"\ndef fix_spaces(text):\n \"\"\"\n Given a string text, replace all spaces in it with underscores, \n and if a string has more than 2 consecutive spaces, \n then replace all consecutive spaces with - \n \n fix_spaces(\"Example\") == \"Example\"\n fix_spaces(\"Example 1\") == \"Example_1\"\n fix_spaces(\" Example 2\") == \"_Example_2\"\n fix_spaces(\" Example 3\") == \"_Example-3\"\n \"\"\"\n"} 142 | {"instructions":"\ndef file_name_check(file_name):\n \"\"\"Create a function which takes a string representing a file's name, and returns\n 'Yes' if the the file's name is valid, and returns 'No' otherwise.\n A file's name is considered to be valid if and only if all the following conditions \n are met:\n - There should not be more than three digits ('0'-'9') in the file's name.\n - The file's name contains exactly one dot '.'\n - The substring before the dot should not be empty, and it starts with a letter from \n the latin alphapet ('a'-'z' and 'A'-'Z').\n - The substring after the dot should be one of these: ['txt', 'exe', 'dll']\n Examples:\n file_name_check(\"example.txt\") # => 'Yes'\n file_name_check(\"1example.dll\") # => 'No' (the name should start 
with a latin alphapet letter)\n \"\"\"\n"} 143 | {"instructions":"\n\n\ndef sum_squares(lst):\n \"\"\"\"\n This function will take a list of integers. For all entries in the list, the function shall square the integer entry if its index is a \n multiple of 3 and will cube the integer entry if its index is a multiple of 4 and not a multiple of 3. The function will not \n change the entries in the list whose indexes are not a multiple of 3 or 4. The function shall then return the sum of all entries. \n \n Examples:\n For lst = [1,2,3] the output should be 6\n For lst = [] the output should be 0\n For lst = [-1,-5,2,-1,-5] the output should be -126\n \"\"\"\n"} 144 | {"instructions":"\ndef words_in_sentence(sentence):\n \"\"\"\n You are given a string representing a sentence,\n the sentence contains some words separated by a space,\n and you have to return a string that contains the words from the original sentence,\n whose lengths are prime numbers,\n the order of the words in the new string should be the same as the original one.\n\n Example 1:\n Input: sentence = \"This is a test\"\n Output: \"is\"\n\n Example 2:\n Input: sentence = \"lets go for swimming\"\n Output: \"go for\"\n\n Constraints:\n * 1 <= len(sentence) <= 100\n * sentence contains only letters\n \"\"\"\n"} 145 | {"instructions":"\ndef simplify(x, n):\n \"\"\"Your task is to implement a function that will simplify the expression\n x * n. The function returns True if x * n evaluates to a whole number and False\n otherwise. 
Both x and n, are string representation of a fraction, and have the following format,\n <numerator>\/<denominator> where both numerator and denominator are positive whole numbers.\n\n You can assume that x, and n are valid fractions, and do not have zero as denominator.\n\n simplify(\"1\/5\", \"5\/1\") = True\n simplify(\"1\/6\", \"2\/1\") = False\n simplify(\"7\/10\", \"10\/2\") = False\n \"\"\"\n"} 146 | {"instructions":"\ndef order_by_points(nums):\n \"\"\"\n Write a function which sorts the given list of integers\n in ascending order according to the sum of their digits.\n Note: if there are several items with similar sum of their digits,\n order them based on their index in original list.\n\n For example:\n >>> order_by_points([1, 11, -1, -11, -12]) == [-1, -11, 1, -12, 11]\n >>> order_by_points([]) == []\n \"\"\"\n"} 147 | {"instructions":"\ndef specialFilter(nums):\n \"\"\"Write a function that takes an array of numbers as input and returns \n the number of elements in the array that are greater than 10 and both \n first and last digits of a number are odd (1, 3, 5, 7, 9).\n For example:\n specialFilter([15, -73, 14, -15]) => 1 \n specialFilter([33, -2, -3, 45, 21, 109]) => 2\n \"\"\"\n"} 148 | {"instructions":"\ndef get_max_triples(n):\n \"\"\"\n You are given a positive integer n. You have to create an integer array a of length n.\n For each i (1 \u2264 i \u2264 n), the value of a[i] = i * i - i + 1.\n Return the number of triples (a[i], a[j], a[k]) of a where i < j < k, \n and a[i] + a[j] + a[k] is a multiple of 3.\n\n Example :\n Input: n = 5\n Output: 1\n Explanation: \n a = [1, 3, 7, 13, 21]\n The only valid triple is (1, 7, 13).\n \"\"\"\n"} 149 | {"instructions":"\ndef bf(planet1, planet2):\n '''\n There are eight planets in our solar system: the closerst to the Sun \n is Mercury, the next one is Venus, then Earth, Mars, Jupiter, Saturn, \n Uranus, Neptune.\n Write a function that takes two planet names as strings planet1 and planet2. 
\n The function should return a tuple containing all planets whose orbits are \n located between the orbit of planet1 and the orbit of planet2, sorted by \n the proximity to the sun. \n The function should return an empty tuple if planet1 or planet2\n are not correct planet names. \n Examples\n bf(\"Jupiter\", \"Neptune\") ==> (\"Saturn\", \"Uranus\")\n bf(\"Earth\", \"Mercury\") ==> (\"Venus\")\n bf(\"Mercury\", \"Uranus\") ==> (\"Venus\", \"Earth\", \"Mars\", \"Jupiter\", \"Saturn\")\n '''\n"} 150 | {"instructions":"\ndef sorted_list_sum(lst):\n \"\"\"Write a function that accepts a list of strings as a parameter,\n deletes the strings that have odd lengths from it,\n and returns the resulted list with a sorted order,\n The list is always a list of strings and never an array of numbers,\n and it may contain duplicates.\n The order of the list should be ascending by length of each word, and you\n should return the list sorted by that rule.\n If two words have the same length, sort the list alphabetically.\n The function should return a list of strings in sorted order.\n You may assume that all words will have the same length.\n For example:\n assert list_sort([\"aa\", \"a\", \"aaa\"]) => [\"aa\"]\n assert list_sort([\"ab\", \"a\", \"aaa\", \"cd\"]) => [\"ab\", \"cd\"]\n \"\"\"\n"} 151 | {"instructions":"\ndef x_or_y(n, x, y):\n \"\"\"A simple program which should return the value of x if n is \n a prime number and should return the value of y otherwise.\n\n Examples:\n for x_or_y(7, 34, 12) == 34\n for x_or_y(15, 8, 5) == 5\n \n \"\"\"\n"} 152 | {"instructions":"\ndef double_the_difference(lst):\n '''\n Given a list of numbers, return the sum of squares of the numbers\n in the list that are odd. 
Ignore numbers that are negative or not integers.\n \n double_the_difference([1, 3, 2, 0]) == 1 + 9 + 0 + 0 = 10\n double_the_difference([-1, -2, 0]) == 0\n double_the_difference([9, -2]) == 81\n double_the_difference([0]) == 0 \n \n If the input list is empty, return 0.\n '''\n"} 153 | {"instructions":"\ndef compare(game,guess):\n \"\"\"I think we all remember that feeling when the result of some long-awaited\n event is finally known. The feelings and thoughts you have at that moment are\n definitely worth noting down and comparing.\n Your task is to determine if a person correctly guessed the results of a number of matches.\n You are given two arrays of scores and guesses of equal length, where each index shows a match. \n Return an array of the same length denoting how far off each guess was. If they have guessed correctly,\n the value is 0, and if not, the value is the absolute difference between the guess and the score.\n \n \n example:\n\n compare([1,2,3,4,5,1],[1,2,3,4,2,-2]) -> [0,0,0,0,3,3]\n compare([0,5,0,0,0,4],[4,1,1,0,0,-2]) -> [4,4,1,0,0,6]\n \"\"\"\n"} 154 | {"instructions":"\ndef Strongest_Extension(class_name, extensions):\n \"\"\"You will be given the name of a class (a string) and a list of extensions.\n The extensions are to be used to load additional classes to the class. The\n strength of the extension is as follows: Let CAP be the number of the uppercase\n letters in the extension's name, and let SM be the number of lowercase letters \n in the extension's name, the strength is given by the fraction CAP - SM. 
\n You should find the strongest extension and return a string in this \n format: ClassName.StrongestExtensionName.\n If there are two or more extensions with the same strength, you should\n choose the one that comes first in the list.\n For example, if you are given \"Slices\" as the class and a list of the\n extensions: ['SErviNGSliCes', 'Cheese', 'StuFfed'] then you should\n return 'Slices.SErviNGSliCes' since 'SErviNGSliCes' is the strongest extension \n (its strength is -1).\n Example:\n for Strongest_Extension('my_class', ['AA', 'Be', 'CC']) == 'my_class.AA'\n \"\"\"\n"} 155 | {"instructions":"\ndef cycpattern_check(a , b):\n \"\"\"You are given 2 words. You need to return True if the second word or any of its rotations is a substring in the first word\n cycpattern_check(\"abcd\",\"abd\") => False\n cycpattern_check(\"hello\",\"ell\") => True\n cycpattern_check(\"whassup\",\"psus\") => False\n cycpattern_check(\"abab\",\"baa\") => True\n cycpattern_check(\"efef\",\"eeff\") => False\n cycpattern_check(\"himenss\",\"simen\") => True\n\n \"\"\"\n"} 156 | {"instructions":"\ndef even_odd_count(num):\n \"\"\"Given an integer. return a tuple that has the number of even and odd digits respectively.\n\n Example:\n even_odd_count(-12) ==> (1, 1)\n even_odd_count(123) ==> (1, 2)\n \"\"\"\n"} 157 | {"instructions":"\ndef int_to_mini_roman(number):\n \"\"\"\n Given a positive integer, obtain its roman numeral equivalent as a string,\n and return it in lowercase.\n Restrictions: 1 <= num <= 1000\n\n Examples:\n >>> int_to_mini_roman(19) == 'xix'\n >>> int_to_mini_roman(152) == 'clii'\n >>> int_to_mini_roman(426) == 'cdxxvi'\n \"\"\"\n"} 158 | {"instructions":"\ndef right_angle_triangle(a, b, c):\n '''\n Given the lengths of the three sides of a triangle. 
Return True if the three\n sides form a right-angled triangle, False otherwise.\n A right-angled triangle is a triangle in which one angle is right angle or \n 90 degree.\n Example:\n right_angle_triangle(3, 4, 5) == True\n right_angle_triangle(1, 2, 3) == False\n '''\n"} 159 | {"instructions":"\ndef find_max(words):\n \"\"\"Write a function that accepts a list of strings.\n The list contains different words. Return the word with maximum number\n of unique characters. If multiple strings have maximum number of unique\n characters, return the one which comes first in lexicographical order.\n\n find_max([\"name\", \"of\", \"string\"]) == \"string\"\n find_max([\"name\", \"enam\", \"game\"]) == \"enam\"\n find_max([\"aaaaaaa\", \"bb\" ,\"cc\"]) == \"\"aaaaaaa\"\n \"\"\"\n"} 160 | {"instructions":"\ndef eat(number, need, remaining):\n \"\"\"\n You're a hungry rabbit, and you already have eaten a certain number of carrots,\n but now you need to eat more carrots to complete the day's meals.\n you should return an array of [ total number of eaten carrots after your meals,\n the number of carrots left after your meals ]\n if there are not enough remaining carrots, you will eat all remaining carrots, but will still be hungry.\n \n Example:\n * eat(5, 6, 10) -> [11, 4]\n * eat(4, 8, 9) -> [12, 1]\n * eat(1, 10, 10) -> [11, 0]\n * eat(2, 11, 5) -> [7, 0]\n \n Variables:\n @number : integer\n the number of carrots that you have eaten.\n @need : integer\n the number of carrots that you need to eat.\n @remaining : integer\n the number of remaining carrots thet exist in stock\n \n Constrain:\n * 0 <= number <= 1000\n * 0 <= need <= 1000\n * 0 <= remaining <= 1000\n\n Have fun :)\n \"\"\"\n"} 161 | {"instructions":"\ndef do_algebra(operator, operand):\n \"\"\"\n Given two lists operator, and operand. The first list has basic algebra operations, and \n the second list is a list of integers. 
Use the two given lists to build the algebric \n expression and return the evaluation of this expression.\n\n The basic algebra operations:\n Addition ( + ) \n Subtraction ( - ) \n Multiplication ( * ) \n Floor division ( \/\/ ) \n Exponentiation ( ** ) \n\n Example:\n operator['+', '*', '-']\n array = [2, 3, 4, 5]\n result = 2 + 3 * 4 - 5\n => result = 9\n\n Note:\n The length of operator list is equal to the length of operand list minus one.\n Operand is a list of of non-negative integers.\n Operator list has at least one operator, and operand list has at least two operands.\n\n \"\"\"\n"} 162 | {"instructions":"\ndef solve(s):\n \"\"\"You are given a string s.\n if s[i] is a letter, reverse its case from lower to upper or vise versa, \n otherwise keep it as it is.\n If the string contains no letters, reverse the string.\n The function should return the resulted string.\n Examples\n solve(\"1234\") = \"4321\"\n solve(\"ab\") = \"AB\"\n solve(\"#a@C\") = \"#A@c\"\n \"\"\"\n"} 163 | {"instructions":"\ndef string_to_md5(text):\n \"\"\"\n Given a string 'text', return its md5 hash equivalent string.\n If 'text' is an empty string, return None.\n\n >>> string_to_md5('Hello world') == '3e25960a79dbc69b674cd4ec67a72c62'\n \"\"\"\n"} 164 | {"instructions":"\ndef generate_integers(a, b):\n \"\"\"\n Given two positive integers a and b, return the even digits between a\n and b, in ascending order.\n\n For example:\n generate_integers(2, 8) => [2, 4, 6, 8]\n generate_integers(8, 2) => [2, 4, 6, 8]\n generate_integers(10, 14) => []\n \"\"\"\n"} 165 | -------------------------------------------------------------------------------- /data/mbpp/shots.md: -------------------------------------------------------------------------------- 1 | You are an expert Python programmer, and here is your task: Write a function to find the similar elements from the given two tuple lists. 
Your code should pass these tests: 2 | 3 | assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5) 4 | assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4) 5 | assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14) 6 | [BEGIN] 7 | def similar_elements(test_tup1, test_tup2): 8 | res = tuple(set(test_tup1) & set(test_tup2)) 9 | return (res) 10 | [DONE] 11 | 12 | You are an expert Python programmer, and here is your task: Write a python function to identify non-prime numbers. Your code should pass these tests: 13 | 14 | assert is_not_prime(2) == False 15 | assert is_not_prime(10) == True 16 | assert is_not_prime(35) == True 17 | [BEGIN] 18 | import math 19 | def is_not_prime(n): 20 | result = False 21 | for i in range(2,int(math.sqrt(n)) + 1): 22 | if n % i == 0: 23 | result = True 24 | return result 25 | [DONE] 26 | 27 | You are an expert Python programmer, and here is your task: Write a function to find squares of individual elements in a list using lambda function. Your code should pass these tests: 28 | 29 | assert square_nums([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])==[1, 4, 9, 16, 25, 36, 49, 64, 81, 100] 30 | assert square_nums([10,20,30])==([100,400,900]) 31 | assert square_nums([12,15])==([144,225]) 32 | [BEGIN] 33 | def square_nums(nums): 34 | square_nums = list(map(lambda x: x ** 2, nums)) 35 | return square_nums 36 | [DONE] -------------------------------------------------------------------------------- /data/strategy_qa/shots.md: -------------------------------------------------------------------------------- 1 | Question: Do hamsters provide food for any animals? 2 | Answer: Hamsters are prey animals. Prey are food for predators. Thus, hamsters provide food for some animals. So the answer is yes. 3 | 4 | Question: Could Brooke Shields succeed at University of Pennsylvania? 5 | Answer: Brooke Shields went to Princeton University. Princeton University is about as academically rigorous as the University of Pennsylvania. 
Thus, Brooke Shields could also succeed at the University of Pennsylvania. So the answer is yes. 6 | 7 | Question: Yes or no: Hydrogen's atomic number squared exceeds number of Spice Girls? 8 | Answer: Hydrogen has an atomic number of 1. 1 squared is 1. There are 5 Spice Girls. Thus, Hydrogen's atomic number squared is less than 5. So the answer is no. 9 | 10 | Question: Yes or no: Is it common to see frost during some college commencements? 11 | Answer: College commencement ceremonies can happen in December, May, and June. December is in the winter, so there can be frost. Thus, there could be frost at some commencements. So the answer is yes. -------------------------------------------------------------------------------- /decodingmethod/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | dola, 3 | contrastive_decoding, 4 | ) 5 | 6 | from .moe_utils import ( 7 | scmoe, 8 | scmoe_with_sampling, 9 | ) -------------------------------------------------------------------------------- /decodingmethod/moe_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from transformers.generation.stopping_criteria import ( 4 | StoppingCriteriaList, 5 | ) 6 | 7 | @torch.no_grad() 8 | def scmoe( 9 | model, 10 | teacher_t, 11 | student_t, 12 | tokenizer, 13 | input_ids, 14 | attention_mask, 15 | max_new_tokens, 16 | eos_token_id=None, 17 | early_stop=False, 18 | alpha=0.1, 19 | beta=0.5, 20 | stopping_criteria=None, 21 | teacher_routed_tok=[0, 1], 22 | teacher_num_experts_per_tok=2, 23 | student_routed_tok=[0], 24 | student_num_experts_per_tok=1, 25 | ): 26 | 27 | batch_size, prefix_len = input_ids.size() 28 | model_kwargs = {} 29 | model_kwargs_student = {} 30 | model_kwargs["attention_mask"] = attention_mask 31 | model_kwargs_student["attention_mask"] = attention_mask 32 | eos_token_id = eos_token_id if 
eos_token_id is not None else tokenizer.eos_token_id 33 | eos_token_id_tensor = ( 34 | torch.tensor([eos_token_id]).to(model.device) 35 | if eos_token_id is not None 36 | else None 37 | ) 38 | unfinished_sequences = input_ids.new(batch_size).fill_(1) 39 | stopping_criteria = ( 40 | stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 41 | ) 42 | for step in range(max_new_tokens): 43 | model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) 44 | outputs = model( 45 | **model_inputs, 46 | return_dict=True, 47 | output_hidden_states=False, 48 | routed_tok=teacher_routed_tok, 49 | num_experts_per_tok=teacher_num_experts_per_tok, 50 | ) 51 | next_token_scores = outputs.logits[:, -1, :] 52 | next_token_scores = next_token_scores / teacher_t 53 | cutoff = ( 54 | torch.log(torch.tensor(alpha, device=next_token_scores.device)) 55 | + next_token_scores.max(dim=-1, keepdim=True).values 56 | ) 57 | 58 | model_inputs_student = model.prepare_inputs_for_generation( 59 | input_ids, **model_kwargs_student 60 | ) 61 | outputs_student = model( 62 | **model_inputs_student, 63 | return_dict=True, 64 | output_hidden_states=False, 65 | routed_tok=student_routed_tok, 66 | num_experts_per_tok=student_num_experts_per_tok, 67 | ) 68 | next_token_logits_student = outputs_student.logits[:, -1, :] 69 | next_token_logits_student = next_token_logits_student / student_t 70 | diffs = (1 + beta) * next_token_scores - beta * next_token_logits_student 71 | cdlogits = diffs.masked_fill(next_token_scores < cutoff, -float("inf")) 72 | if not early_stop and eos_token_id != None: 73 | cdlogits[:, eos_token_id] = -float("inf") 74 | 75 | next_tokens = torch.argmax(cdlogits, dim=-1) 76 | next_tokens = next_tokens * unfinished_sequences + tokenizer.pad_token_id * ( 77 | 1 - unfinished_sequences 78 | ) 79 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 80 | 81 | if eos_token_id_tensor is not None: 82 | unfinished_sequences = 
unfinished_sequences.mul( 83 | next_tokens.tile(eos_token_id_tensor.shape[0], 1) 84 | .ne(eos_token_id_tensor.unsqueeze(1)) 85 | .prod(dim=0) 86 | ) 87 | 88 | unfinished_sequences = unfinished_sequences & ~stopping_criteria( 89 | input_ids, None 90 | ) 91 | if unfinished_sequences.max() == 0 or step == max_new_tokens - 1: 92 | stopped = True 93 | else: 94 | stopped = False 95 | 96 | if stopped: 97 | break 98 | 99 | model_kwargs = model._update_model_kwargs_for_generation( 100 | outputs, 101 | model_kwargs, 102 | is_encoder_decoder=model.config.is_encoder_decoder, 103 | ) 104 | model_kwargs_student = model._update_model_kwargs_for_generation( 105 | outputs_student, 106 | model_kwargs_student, 107 | is_encoder_decoder=model.config.is_encoder_decoder, 108 | ) 109 | return input_ids 110 | 111 | @torch.no_grad() 112 | def scmoe_with_sampling( 113 | model, 114 | teacher_t, 115 | student_t, 116 | tokenizer, 117 | input_ids, 118 | attention_mask, 119 | max_new_tokens, 120 | eos_token_id=None, 121 | early_stop=False, 122 | alpha=0.1, 123 | beta=0.5, 124 | stopping_criteria=None, 125 | teacher_routed_tok=[0, 1], 126 | teacher_num_experts_per_tok=2, 127 | student_routed_tok=[0], 128 | student_num_experts_per_tok=1, 129 | ): 130 | 131 | batch_size, prefix_len = input_ids.size() 132 | model_kwargs = {} 133 | model_kwargs_student = {} 134 | model_kwargs["attention_mask"] = attention_mask 135 | model_kwargs_student["attention_mask"] = attention_mask 136 | eos_token_id = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id 137 | eos_token_id_tensor = ( 138 | torch.tensor([eos_token_id]).to(model.device) 139 | if eos_token_id is not None 140 | else None 141 | ) 142 | unfinished_sequences = input_ids.new(batch_size).fill_(1) 143 | stopping_criteria = ( 144 | stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 145 | ) 146 | for step in range(max_new_tokens): 147 | model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) 
148 | outputs = model( 149 | **model_inputs, 150 | return_dict=True, 151 | output_hidden_states=False, 152 | routed_tok=teacher_routed_tok, 153 | num_experts_per_tok=teacher_num_experts_per_tok, 154 | ) 155 | next_token_scores = outputs.logits[:, -1, :] 156 | next_token_scores = next_token_scores / teacher_t 157 | cutoff = ( 158 | torch.log(torch.tensor(alpha, device=next_token_scores.device)) 159 | + next_token_scores.max(dim=-1, keepdim=True).values 160 | ) 161 | 162 | model_inputs_student = model.prepare_inputs_for_generation( 163 | input_ids, **model_kwargs_student 164 | ) 165 | outputs_student = model( 166 | **model_inputs_student, 167 | return_dict=True, 168 | output_hidden_states=False, 169 | routed_tok=student_routed_tok, 170 | num_experts_per_tok=student_num_experts_per_tok, 171 | ) 172 | next_token_logits_student = outputs_student.logits[:, -1, :] 173 | next_token_logits_student = next_token_logits_student / student_t 174 | diffs = (1 + beta) * next_token_scores - beta * next_token_logits_student 175 | cdlogits = diffs.masked_fill(next_token_scores < cutoff, -float("inf")) 176 | if not early_stop and eos_token_id != None: 177 | cdlogits[:, eos_token_id] = -float("inf") 178 | 179 | cdscores = F.softmax(cdlogits, dim=-1) 180 | next_tokens = torch.multinomial(cdscores, num_samples=1).squeeze(-1) 181 | next_tokens = next_tokens * unfinished_sequences + tokenizer.pad_token_id * ( 182 | 1 - unfinished_sequences 183 | ) 184 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 185 | 186 | if eos_token_id_tensor is not None: 187 | unfinished_sequences = unfinished_sequences.mul( 188 | next_tokens.tile(eos_token_id_tensor.shape[0], 1) 189 | .ne(eos_token_id_tensor.unsqueeze(1)) 190 | .prod(dim=0) 191 | ) 192 | 193 | unfinished_sequences = unfinished_sequences & ~stopping_criteria( 194 | input_ids, None 195 | ) 196 | if unfinished_sequences.max() == 0 or step == max_new_tokens - 1: 197 | stopped = True 198 | else: 199 | stopped = False 200 | 201 | if 
stopped: 202 | break 203 | 204 | model_kwargs = model._update_model_kwargs_for_generation( 205 | outputs, 206 | model_kwargs, 207 | is_encoder_decoder=model.config.is_encoder_decoder, 208 | ) 209 | model_kwargs_student = model._update_model_kwargs_for_generation( 210 | outputs_student, 211 | model_kwargs_student, 212 | is_encoder_decoder=model.config.is_encoder_decoder, 213 | ) 214 | return input_ids -------------------------------------------------------------------------------- /decodingmethod/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | from transformers.generation.stopping_criteria import ( 5 | StoppingCriteriaList, 6 | ) 7 | 8 | from transformers.generation.logits_process import ( 9 | RepetitionPenaltyLogitsProcessor, 10 | LogitsProcessorList, 11 | ) 12 | 13 | 14 | def relative_top_filter( 15 | scores: torch.FloatTensor, 16 | relative_top: float = 0.1, 17 | filter_value: float = -float("Inf"), 18 | min_tokens_to_keep: int = 1, 19 | ) -> torch.FloatTensor: 20 | scores_normalized = scores.log_softmax(dim=-1) 21 | sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True) 22 | min_thresh = sorted_logits[..., min_tokens_to_keep - 1] 23 | probs_max = torch.max(scores_normalized, dim=-1).values 24 | probs_thresh = probs_max + np.log(relative_top) 25 | probs_thresh = torch.min(min_thresh, probs_thresh) 26 | probs_thresh = probs_thresh.unsqueeze(-1) 27 | scores_normalized[scores_normalized < probs_thresh] = filter_value 28 | return scores_normalized 29 | 30 | 31 | @torch.no_grad() 32 | def dola( 33 | model, 34 | tokenizer, 35 | input_ids, 36 | attention_mask, 37 | max_new_tokens=512, 38 | repetition_penalty=1.2, 39 | mature_layer=None, 40 | base_layer=None, 41 | candidate_premature_layers=None, 42 | relative_top=0.1, 43 | eos_token_id=None, 44 | stopping_criteria=None, 45 | early_stop=False, 46 | ): 47 | """ 
48 | - max_new_tokens: maximum number of decoding steps. 49 | - repetition_penalty: penalty passed to RepetitionPenaltyLogitsProcessor. 50 | - mature_layer: index into outputs.hidden_states of the layer whose logits give the final distribution. 51 | - base_layer: if set, the fixed premature layer to contrast against (it must also appear in candidate_premature_layers); if None, the premature layer is chosen dynamically at each step. 52 | - candidate_premature_layers: list of layer indices considered for dynamic selection; the one with the largest JS divergence from mature_layer is used. 53 | - relative_top: relative threshold passed to relative_top_filter (skipped when <= 0). early_stop: if False, the EOS token is masked out at every step. 54 | """ 55 | batch_size, prefix_len = input_ids.size() 56 | model_kwargs = {} 57 | prompt_len = torch.sum(attention_mask, dim=1) 58 | model_kwargs["attention_mask"] = attention_mask 59 | 60 | eos_token_id = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id 61 | 62 | eos_token_id_tensor = ( 63 | torch.tensor([eos_token_id]).to(model.device) 64 | if eos_token_id is not None 65 | else None 66 | ) 67 | unfinished_sequences = input_ids.new(batch_size).fill_(1) 68 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 69 | early_exit_layers = candidate_premature_layers + [mature_layer] 70 | premature_layer_dist = {l: 0 for l in candidate_premature_layers} 71 | 72 | processors = LogitsProcessorList() 73 | processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) 74 | for step in range(max_new_tokens): 75 | model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) 76 | # print("model inputs:",model_inputs) 77 | outputs = model(**model_inputs, return_dict=True, output_hidden_states=True) 78 | if early_exit_layers is not None: 79 | dict_outputs = {} 80 | # loss_dict = {} 81 | for i, early_exit_layer in enumerate(early_exit_layers): 82 | # 
print(outputs.hidden_states.shape) 83 | # print(early_exit_layer) 84 | logits = model.lm_head(outputs.hidden_states[early_exit_layer]) 85 | dict_outputs[early_exit_layer] = logits 86 | 87 | if base_layer is not None: 88 | base_logits = dict_outputs[base_layer][:, -1, :] 89 | final_logits = dict_outputs[mature_layer][:, -1, :] 90 | if relative_top > 0.0: 91 | final_logits = relative_top_filter(final_logits, 
relative_top) 92 | base_logits = base_logits.log_softmax(dim=-1) 93 | mask = final_logits[0] < -1e3 94 | base_logits[0][mask] = -1e3 95 | 96 | logits = final_logits - base_logits 97 | next_token_logits = logits 98 | else: 99 | # 1. Stacking all premature_layers into a new dimension 100 | stacked_premature_layers = torch.stack( 101 | [dict_outputs[i][:, -1, :] for i in candidate_premature_layers], dim=0 102 | ) 103 | 104 | # 2. Calculate the softmax values for mature_layer and all premature_layers 105 | softmax_mature_layer = F.softmax( 106 | dict_outputs[mature_layer][:, -1, :], dim=-1 107 | ) # shape: (batch_size, num_features) 108 | softmax_premature_layers = F.softmax( 109 | stacked_premature_layers, dim=-1 110 | ) # shape: (num_premature_layers, batch_size, num_features) 111 | 112 | # 3. Calculate M, the average distribution 113 | M = 0.5 * ( 114 | softmax_mature_layer[None, :, :] + softmax_premature_layers 115 | ) # shape: (num_premature_layers, batch_size, num_features) 116 | 117 | # 4. Calculate log-softmax for the KL divergence 118 | log_softmax_mature_layer = F.log_softmax( 119 | dict_outputs[mature_layer][:, -1, :], dim=-1 120 | ) # shape: (batch_size, num_features) 121 | log_softmax_premature_layers = F.log_softmax( 122 | stacked_premature_layers, dim=-1 123 | ) # shape: (num_premature_layers, batch_size, num_features) 124 | 125 | # 5. Calculate the KL divergences and then the JS divergences 126 | kl1 = F.kl_div( 127 | log_softmax_mature_layer[None, :, :], M, reduction="none" 128 | ).mean( 129 | -1 130 | ) # shape: (num_premature_layers, batch_size) 131 | kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean( 132 | -1 133 | ) # shape: (num_premature_layers, batch_size) 134 | js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size) 135 | 136 | # 6. 
Reduce the batchmean 137 | js_divs = js_divs.mean(-1) # shape: (num_premature_layers,) 138 | premature_layer = candidate_premature_layers[ 139 | int(js_divs.argmax().cpu().item()) 140 | ] 141 | premature_layer_dist[premature_layer] += 1 142 | 143 | base_logits = dict_outputs[premature_layer][:, -1, :] 144 | final_logits = dict_outputs[mature_layer][:, -1, :] 145 | 146 | if relative_top > 0.0: 147 | final_logits = relative_top_filter(final_logits, relative_top) 148 | base_logits = base_logits.log_softmax(dim=-1) 149 | mask = final_logits[0] < -1e3 150 | base_logits[0][mask] = -1e3 151 | logits = final_logits - base_logits 152 | next_token_logits = logits 153 | # pre-process distribution 154 | import copy 155 | 156 | new_next_token_logits = copy.deepcopy(next_token_logits) 157 | new_next_token_logits = new_next_token_logits.to(input_ids.device) 158 | next_tokens_scores = processors(input_ids, new_next_token_logits) 159 | 160 | # mask out EOS unless early stopping is enabled 161 | if not early_stop and eos_token_id is not None: 162 | next_tokens_scores[:, eos_token_id] = -float("inf") 163 | 164 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 165 | # emit the pad token for sequences that have already finished 166 | next_tokens = next_tokens * unfinished_sequences + tokenizer.pad_token_id * ( 167 | 1 - unfinished_sequences 168 | ) 169 | 170 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 171 | 172 | if eos_token_id_tensor is not None: 173 | unfinished_sequences = unfinished_sequences.mul( 174 | next_tokens.tile(eos_token_id_tensor.shape[0], 1) 175 | .ne(eos_token_id_tensor.unsqueeze(1)) 176 | .prod(dim=0) 177 | ) 178 | 179 | unfinished_sequences = unfinished_sequences & ~stopping_criteria( 180 | input_ids, None 181 | ) 182 | 183 | if unfinished_sequences.max() == 0 or step == max_new_tokens - 1: 184 | stopped = True 185 | else: 186 | stopped = False 187 | 188 | if stopped: 189 | break 190 | 191 | model_kwargs = 
model._update_model_kwargs_for_generation( 192 | outputs, model_kwargs, 
is_encoder_decoder=model.config.is_encoder_decoder 193 | ) 194 | return input_ids 195 | 196 | @torch.no_grad() 197 | def contrastive_decoding( 198 | teacher_model, 199 | student_model, 200 | teacher_t, 201 | student_t, 202 | tokenizer, 203 | input_ids, 204 | attention_mask, 205 | max_new_tokens, 206 | eos_token_id=None, 207 | early_stop=False, 208 | alpha=0.1, 209 | beta=0.5, 210 | stopping_criteria=None, 211 | ): 212 | # formulation of "CONTRASTIVE DECODING IMPROVES REASONING IN LARGE LANGUAGE MODELS" 213 | batch_size, prefix_len = input_ids.size() 214 | model_kwargs = {} 215 | model_kwargs_student = {} 216 | model_kwargs["attention_mask"] = attention_mask 217 | model_kwargs_student["attention_mask"] = attention_mask 218 | eos_token_id = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id 219 | eos_token_id_tensor = ( 220 | torch.tensor([eos_token_id]).to(teacher_model.device) 221 | if eos_token_id is not None 222 | else None 223 | ) 224 | unfinished_sequences = input_ids.new(batch_size).fill_(1) 225 | stopping_criteria = ( 226 | stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 227 | ) 228 | for step in range(max_new_tokens): 229 | model_inputs = teacher_model.prepare_inputs_for_generation( 230 | input_ids, **model_kwargs 231 | ) 232 | outputs = teacher_model( 233 | **model_inputs, return_dict=True, output_hidden_states=True 234 | ) 235 | next_token_scores = outputs.logits[:, -1, :] 236 | next_token_scores = next_token_scores / teacher_t 237 | cutoff = ( 238 | torch.log(torch.tensor(alpha, device=next_token_scores.device)) 239 | + next_token_scores.max(dim=-1, keepdim=True).values 240 | ) 241 | 242 | model_inputs_student = student_model.prepare_inputs_for_generation( 243 | input_ids, **model_kwargs_student 244 | ) 245 | outputs_student = student_model( 246 | **model_inputs_student, 247 | return_dict=True, 248 | output_attentions=True, 249 | output_hidden_states=True, 250 | ) 251 | 252 | next_token_logits_student = 
outputs_student.logits[:, -1, :] 253 | next_token_logits_student = next_token_logits_student / student_t 254 | diffs = (1 + beta) * next_token_scores - beta * next_token_logits_student 255 | cdlogits = diffs.masked_fill(next_token_scores < cutoff, -float("inf")) 256 | if not early_stop and eos_token_id != None: 257 | cdlogits[:, eos_token_id] = -float("inf") 258 | 259 | # next_tokens_scores = next_token_scores - alpha * next_token_logits_student 260 | next_tokens = torch.argmax(cdlogits, dim=-1) 261 | next_tokens = next_tokens * unfinished_sequences + tokenizer.pad_token_id * ( 262 | 1 - unfinished_sequences 263 | ) 264 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 265 | 266 | if eos_token_id_tensor is not None: 267 | unfinished_sequences = unfinished_sequences.mul( 268 | next_tokens.tile(eos_token_id_tensor.shape[0], 1) 269 | .ne(eos_token_id_tensor.unsqueeze(1)) 270 | .prod(dim=0) 271 | ) 272 | 273 | unfinished_sequences = unfinished_sequences & ~stopping_criteria( 274 | input_ids, None 275 | ) 276 | if unfinished_sequences.max() == 0 or step == max_new_tokens - 1: 277 | stopped = True 278 | else: 279 | stopped = False 280 | 281 | if stopped: 282 | break 283 | 284 | model_kwargs = teacher_model._update_model_kwargs_for_generation( 285 | outputs, 286 | model_kwargs, 287 | is_encoder_decoder=teacher_model.config.is_encoder_decoder, 288 | ) 289 | model_kwargs_student = student_model._update_model_kwargs_for_generation( 290 | outputs_student, 291 | model_kwargs_student, 292 | is_encoder_decoder=student_model.config.is_encoder_decoder, 293 | ) 294 | return input_ids -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import json 3 | import warnings 4 | 5 | import datasets 6 | import torch 7 | import transformers 8 | import argparse 9 | from lm_eval.evaluator import Evaluator 10 | from lm_eval.tasks 
import ALL_TASKS 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument( 16 | "--model", 17 | default="codeparrot/codeparrot-small", 18 | help="Model to evaluate, provide a repo name in Hugging Face hub or a local path", 19 | ) 20 | parser.add_argument( 21 | "--task_name", 22 | default=None, 23 | help="Task to evaluate on, can be a single task", 24 | ) 25 | parser.add_argument( 26 | "--allow_code_execution", 27 | action="store_true", 28 | help="Allow code evaluation to execute external/untrusted Python code on your machine", 29 | ) 30 | parser.add_argument( 31 | "--load_generations_path", 32 | type=str, 33 | default=None, 34 | help="Path of file with previously generated solutions, if provided generation is skipped and only evaluation is done", 35 | ) 36 | parser.add_argument( 37 | "--metric_output_path", 38 | type=str, 39 | default="evaluation_results.json", 40 | help="Path to save the results", 41 | ) 42 | parser.add_argument( 43 | "--check_references", 44 | action="store_true", 45 | help="Don't run generation but benchmark groundtruth (useful for debugging)", 46 | ) 47 | 48 | parser.add_argument( 49 | "--postprocessed_output_path", 50 | type=str, 51 | default=None, 52 | help="Path to save the postprocessed generations", 53 | ) 54 | 55 | return parser.parse_args() 56 | 57 | if __name__ == "__main__": 58 | args = parse_args() 59 | print(f"Selected Tasks: {args.task_name}") 60 | 61 | results = {} 62 | if args.load_generations_path: 63 | # here we don't generate code but only evaluate previously computed generations 64 | print("evaluation only mode") 65 | evaluator = Evaluator(args) 66 | results["results"] = evaluator.evaluate(args.task_name) 67 | 68 | # Save all args to config 69 | for k,v in vars(args).items(): 70 | results[k] = v 71 | # Save jsonl 72 | dumped = json.dumps(results) 73 | print(dumped) 74 | with open(args.metric_output_path, "a+") as f: 75 | f.write(dumped+"\n") 76 | 77 | 
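Because `evaluation.py` opens `--metric_output_path` in append mode and writes one JSON object per run, the file accumulates one record per evaluation. The following is a minimal sketch of reading those records back, assuming each line is the `results` dict dumped above. `load_metric_records` is a hypothetical helper, not part of the repository, and the accuracy values in the demo are made-up placeholders rather than measured results.

```python
import json
import os
import tempfile


def load_metric_records(path):
    """Return every JSON record from a JSONL metrics file, oldest first."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines between runs
                records.append(json.loads(line))
    return records


# Demo: mimic two evaluation runs appending to the same file.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gsm8k_results.jsonl")
    for placeholder_accuracy in (0.25, 0.5):  # placeholder numbers only
        record = {"results": placeholder_accuracy, "task_name": "gsm8k"}
        with open(path, "a+", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
    records = load_metric_records(path)
    latest = records[-1]  # the most recent run is the last line

print(len(records), latest["results"])
```

Taking the last record gives the most recent run, which matters when the same output path is reused across decoding methods.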
-------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import transformers 3 | print(transformers.__version__) 4 | import json 5 | import torch 6 | import argparse 7 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 8 | import os 9 | import time 10 | from tqdm import tqdm 11 | from decodingmethod import ( 12 | dola, 13 | contrastive_decoding, 14 | scmoe, 15 | scmoe_with_sampling, 16 | ) 17 | from multiprocessing import Process 18 | import numpy as np 19 | import pandas as pd 20 | from transformers.generation.stopping_criteria import ( 21 | StoppingCriteria, 22 | StoppingCriteriaList, 23 | STOPPING_CRITERIA_INPUTS_DOCSTRING, 24 | add_start_docstrings, 25 | ) 26 | from typing import Any, Dict, List, Optional, Tuple, Union, Sequence 27 | from itertools import combinations 28 | 29 | 30 | def generate_combinations(total_length, num_to_remove): 31 | indices = range(total_length) 32 | result = list(combinations(indices, num_to_remove)) 33 | return [list(i) for i in result] 34 | 35 | 36 | class StopAtSpecificTokenCriteria(StoppingCriteria): 37 | 38 | def __init__(self, token_id_list): 39 | self.token_id_list = token_id_list 40 | self.stop_tag = None 41 | 42 | @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) 43 | def __call__( 44 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 45 | ) -> bool: 46 | for _ in range(len(self.token_id_list)): 47 | stop_states = [ 48 | np.array_equal( 49 | self.token_id_list[_], 50 | input_ids[i][-len(self.token_id_list[_]) :].detach().cpu().numpy(), 51 | ) 52 | for i in range(input_ids.size(0)) 53 | ] 54 | if self.stop_tag is None: 55 | self.stop_tag = stop_states 56 | else: 57 | self.stop_tag = [ 58 | self.stop_tag[i] or stop_states[i] 59 | for i in range(len(self.stop_tag)) 60 | ] 61 | if all(self.stop_tag): 62 | self.stop_tag = None 63 | return 
True 64 | return False 65 | 66 | 67 | def args_parse(): 68 | parser = argparse.ArgumentParser(description="Generate completions with the selected decoding method.") 69 | parser.add_argument( 70 | "--infile", type=str, help="path to the input JSONL file of prompts" 71 | ) 72 | parser.add_argument( 73 | "--outfile", type=str, help="path to the output JSONL file of generations" 74 | ) 75 | parser.add_argument("--batch_size", type=int, default=4) 76 | parser.add_argument("--world_size", type=int, default=4) 77 | parser.add_argument("--max_length", type=int, default=4096) 78 | parser.add_argument("--max_new_tokens", type=int, default=256) 79 | parser.add_argument("--early_stop", action="store_true") 80 | parser.add_argument("--decoding_method", type=str, default="topp") 81 | parser.add_argument("--gpus_per_model", type=int, default=2) 82 | parser.add_argument( 83 | "--model_name_or_path", default="", type=str 84 | ) 85 | 86 | parser.add_argument( 87 | "--student_model_name_or_path", 88 | default="", 89 | type=str, 90 | ) 91 | 92 | parser.add_argument("--cs_alpha", default=0.6, type=float) 93 | parser.add_argument("--cs_k", default=5, type=int) 94 | 95 | parser.add_argument( 96 | "--cd_alpha", default=0.1, type=float 97 | ) 98 | parser.add_argument("--cd_beta", default=0.5, type=float) 99 | parser.add_argument( 100 | "--cd_tt", default=1.0, type=float, help="teacher temperature" 101 | ) 102 | parser.add_argument("--cd_st", default=1.0, type=float, help="student temperature") 103 | 104 | parser.add_argument("--dola_early_exit_layers", default="0,2,4,6,8,10,12,14,32", type=str) 105 | parser.add_argument("--dola_mature_layer", default=32, type=int) 106 | 107 | parser.add_argument("--num_experts_per_tok", default=2, type=int) 108 | parser.add_argument("--routed_tok", default=0, type=int) 109 | parser.add_argument("--student_num_experts_per_tok", default=1, type=int) 110 | parser.add_argument("--student_routed_tok", default=0, type=int) 111 | 112 | parser.add_argument("--dynamic_expert_routing_threshold", default=0.6, 
type=float) 113 | 114 | args = parser.parse_args() 115 | return args 116 | 117 | 118 | def generate(rank, args): 119 | visible_devices = [ 120 | str(rank * args.gpus_per_model + i) 121 | for i in range(args.gpus_per_model) 122 | ] 123 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible_devices) 124 | tokenizer = AutoTokenizer.from_pretrained( 125 | args.model_name_or_path, trust_remote_code=True 126 | ) 127 | if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: 128 | tokenizer.pad_token_id = tokenizer.eos_token_id 129 | print("pad_token_id is None, set it to eos_token_id") 130 | elif tokenizer.eos_token_id is None and tokenizer.pad_token_id is not None: 131 | tokenizer.eos_token_id = tokenizer.pad_token_id 132 | print("eos_token_id is None, set it to pad_token_id") 133 | elif tokenizer.eos_token_id is None and tokenizer.pad_token_id is None: 134 | print("both eos_token_id and pad_token_id are None") 135 | 136 | config = AutoConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True) 137 | if "deepseek-moe" in args.model_name_or_path or "Mixtral" in args.model_name_or_path: 138 | config.num_experts_per_tok = args.num_experts_per_tok 139 | if args.routed_tok == 0: 140 | config.routed_tok = [_ for _ in range(args.num_experts_per_tok)] 141 | else: 142 | config.routed_tok = generate_combinations(8 if "Mixtral" in args.model_name_or_path else 64, args.num_experts_per_tok)[args.routed_tok] 143 | if args.decoding_method == "dynamic": 144 | config.dynamic_expert_routing_threshold = args.dynamic_expert_routing_threshold 145 | 146 | model = AutoModelForCausalLM.from_pretrained( 147 | args.model_name_or_path, 148 | torch_dtype=torch.bfloat16, 149 | device_map="auto", 150 | trust_remote_code=True, 151 | config=config, 152 | ).eval() 153 | 154 | if args.decoding_method == "cd": 155 | config = AutoConfig.from_pretrained( 156 | args.student_model_name_or_path, trust_remote_code=True 157 | ) 158 | student_model = 
AutoModelForCausalLM.from_pretrained( 159 | args.student_model_name_or_path, 160 | torch_dtype=torch.bfloat16, 161 | device_map="auto", 162 | config=config, 163 | trust_remote_code=True, 164 | ).eval() 165 | 166 | prompt_lst = [] 167 | 168 | with open(args.infile) as f: 169 | idx = 0 170 | for line in f.readlines(): 171 | d = json.loads(line.strip()) 172 | d["idx"] = idx 173 | prompt_lst.append(d) 174 | idx += 1 175 | 176 | print(f"the total number of prompts: {len(prompt_lst)}") 177 | prompt_lst = prompt_lst[rank :: args.num_processes] 178 | print(f"the total number of prompts for rank {rank}: {len(prompt_lst)}") 179 | 180 | s = time.time() 181 | 182 | for start in tqdm(range(0, len(prompt_lst), args.batch_size), disable=rank != 0): 183 | stopping_criteria = StoppingCriteriaList() 184 | if "deepseek-moe" in args.model_name_or_path: 185 | if "mbpp" in args.infile: 186 | stopping_criteria.append( 187 | StopAtSpecificTokenCriteria(token_id_list=[[58, 95742, 60]]) 188 | ) 189 | if "human_eval" in args.infile: 190 | stopping_criteria.append( 191 | StopAtSpecificTokenCriteria( 192 | # StopAtSpecificTokenCriteria.__init__ accepts only token_id_list 193 | token_id_list=[[185, 1558], [185, 351, 5589, 1531, 1442, 2318]], 194 | ) 195 | ) 196 | else: 197 | stopping_criteria.append( 198 | StopAtSpecificTokenCriteria(token_id_list=[[185, 185]]) 199 | ) 200 | elif "Mixtral" in args.model_name_or_path: 201 | if "mbpp" in args.infile: 202 | stopping_criteria.append( 203 | StopAtSpecificTokenCriteria( 204 | token_id_list=[[28792, 28757, 6349, 28793]] 205 | ) 206 | ) 207 | if "human_eval" in args.infile: 208 | stopping_criteria.append( 209 | StopAtSpecificTokenCriteria( 210 | token_id_list=[[13, 1270], [13, 335, 1848, 861, 860, 859]] 211 | ) 212 | ) 213 | else: 214 | stopping_criteria.append( 215 | StopAtSpecificTokenCriteria(token_id_list=[[13, 13]]) 216 | ) 217 | 218 | if start % 20 == 0 and rank == 0: 219 | print(f"rank {rank} has generated {start} prompts") 220 | cur_prompt_lst = prompt_lst[start : start + 
args.batch_size] 221 | prompt_text = [f"{x['instructions']}" for x in cur_prompt_lst] 222 | model_inputs = tokenizer( 223 | prompt_text, padding=True, add_special_tokens=True, return_tensors="pt" 224 | ) 225 | input_ids = model_inputs["input_ids"].to(model.device) 226 | attention_mask = model_inputs["attention_mask"].to(model.device) 227 | prompt_len = input_ids.size(1) 228 | if args.decoding_method == "greedy" or args.decoding_method == "dynamic": 229 | outputs = model.generate( 230 | input_ids, 231 | attention_mask=attention_mask, 232 | max_new_tokens=args.max_new_tokens, 233 | do_sample=False, 234 | stopping_criteria=stopping_criteria, 235 | ) 236 | if args.decoding_method == "cs": 237 | outputs = model.generate( 238 | input_ids, 239 | attention_mask=attention_mask, 240 | max_new_tokens=args.max_new_tokens, 241 | penalty_alpha=args.cs_alpha, 242 | top_k=args.cs_k, 243 | stopping_criteria=stopping_criteria, 244 | ) 245 | if args.decoding_method == "dola": 246 | early_exit_layers = [int(x) for x in args.dola_early_exit_layers.split(",")] 247 | outputs = dola( 248 | model, 249 | tokenizer, 250 | input_ids, 251 | attention_mask, 252 | max_new_tokens=args.max_new_tokens, 253 | repetition_penalty=1.2, 254 | mature_layer=args.dola_mature_layer, 255 | base_layer=None, 256 | candidate_premature_layers=early_exit_layers, 257 | relative_top=0.1, 258 | eos_token_id=None, 259 | stopping_criteria=stopping_criteria, 260 | early_stop=args.early_stop, 261 | ) 262 | if args.decoding_method == "cd": 263 | outputs = contrastive_decoding( 264 | model, 265 | student_model, 266 | teacher_t=args.cd_tt, 267 | student_t=args.cd_st, 268 | tokenizer=tokenizer, 269 | input_ids=input_ids, 270 | attention_mask=attention_mask, 271 | max_new_tokens=args.max_new_tokens, 272 | early_stop=True, 273 | alpha=args.cd_alpha, 274 | beta=args.cd_beta, 275 | stopping_criteria=stopping_criteria, 276 | ) 277 | if args.decoding_method == "scmoe": 278 | if args.routed_tok == 0: 279 | MoE_mapping_teacher = 
[[_ for _ in range(args.num_experts_per_tok)]] 280 | else: 281 | MoE_mapping_teacher = generate_combinations( 282 | 8 if "Mixtral" in args.model_name_or_path else 64, 283 | args.num_experts_per_tok, 284 | ) 285 | if (args.student_routed_tok == 8 or args.student_routed_tok == 64) and args.student_num_experts_per_tok == 1: 286 | MoE_mapping_student = [[_] for _ in range(args.student_routed_tok + 1)] 287 | elif args.student_routed_tok == 0: 288 | MoE_mapping_student = [[_ for _ in range(args.student_num_experts_per_tok)]] 289 | else: 290 | MoE_mapping_student = generate_combinations( 291 | 8 if "Mixtral" in args.model_name_or_path else 64, 292 | args.student_num_experts_per_tok, 293 | ) 294 | outputs = scmoe( 295 | model, 296 | teacher_t=args.cd_tt, 297 | student_t=args.cd_st, 298 | tokenizer=tokenizer, 299 | input_ids=input_ids, 300 | attention_mask=attention_mask, 301 | max_new_tokens=args.max_new_tokens, 302 | early_stop=args.early_stop, 303 | alpha=args.cd_alpha, 304 | beta=args.cd_beta, 305 | stopping_criteria=stopping_criteria, 306 | teacher_routed_tok=MoE_mapping_teacher[args.routed_tok], 307 | teacher_num_experts_per_tok=args.num_experts_per_tok, 308 | student_routed_tok=MoE_mapping_student[args.student_routed_tok], 309 | student_num_experts_per_tok=args.student_num_experts_per_tok, 310 | ) 311 | if args.decoding_method == "scmoe_with_sampling": 312 | if args.routed_tok == 0: 313 | MoE_mapping_teacher = [[_ for _ in range(args.num_experts_per_tok)]] 314 | else: 315 | MoE_mapping_teacher = generate_combinations( 316 | 8 if "Mixtral" in args.model_name_or_path else 64, 317 | args.num_experts_per_tok, 318 | ) 319 | if args.student_routed_tok == 0: 320 | MoE_mapping_student = [[_ for _ in range(args.student_num_experts_per_tok)]] 321 | else: 322 | MoE_mapping_student = generate_combinations( 323 | 8 if "Mixtral" in args.model_name_or_path else 64, 324 | args.student_num_experts_per_tok, 325 | ) 326 | outputs = scmoe_with_sampling( 327 | model, 328 
| teacher_t=args.cd_tt, 329 | student_t=args.cd_st, 330 | tokenizer=tokenizer, 331 | input_ids=input_ids, 332 | attention_mask=attention_mask, 333 | max_new_tokens=args.max_new_tokens, 334 | early_stop=args.early_stop, 335 | alpha=args.cd_alpha, 336 | beta=args.cd_beta, 337 | stopping_criteria=stopping_criteria, 338 | teacher_routed_tok=MoE_mapping_teacher[args.routed_tok], 339 | teacher_num_experts_per_tok=args.num_experts_per_tok, 340 | student_routed_tok=MoE_mapping_student[args.student_routed_tok], 341 | student_num_experts_per_tok=args.student_num_experts_per_tok, 342 | ) 343 | 344 | generation_text = tokenizer.batch_decode( 345 | outputs[:, prompt_len:], 346 | clean_up_tokenization_spaces=True, 347 | skip_special_tokens=True, 348 | ) 349 | for prompt, generation in zip(cur_prompt_lst, generation_text): 350 | json_str = json.dumps( 351 | { 352 | "idx": prompt["idx"], 353 | "completion": generation.strip(), 354 | } 355 | ) 356 | with open(args.outfile + f"{rank}", "a", encoding="utf-8") as f: 357 | f.write(json_str + "\n") 358 | 359 | t = time.time() 360 | print("time used: ", t - s) 361 | 362 | if __name__ == "__main__": 363 | args = args_parse() 364 | print(args) 365 | assert args.world_size % args.gpus_per_model == 0 366 | args.num_processes = args.world_size // args.gpus_per_model 367 | process_list = [] 368 | for i in range(args.num_processes): 369 | p = Process(target=generate, args=(i, args)) 370 | p.start() 371 | process_list.append(p) 372 | for p in process_list: 373 | p.join() 374 | all_ret = pd.DataFrame() 375 | for rank in range(args.num_processes): 376 | with open(args.outfile + f"{rank}", "r", encoding="utf-8") as f: 377 | all_ret = pd.concat( 378 | [all_ret, pd.read_json(f, lines=True)], ignore_index=True 379 | ) 380 | all_ret.sort_values(by="idx", inplace=True) 381 | all_ret.to_json(args.outfile, orient="records", lines=True, force_ascii=False) 382 | for rank in range(args.num_processes): 383 | os.remove(args.outfile + f"{rank}") 384 | 
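The routing flags in `generate.py` are integer indices, not expert ranks: index 0 means "the top-k ranks", and any other index selects one entry from the list produced by `generate_combinations`. The following is a minimal sketch of that mapping under stated assumptions. `resolve_routed_ranks` is a hypothetical helper introduced only for illustration, the expert count is passed explicitly instead of being inferred from the model name, and the student-side special case for index 8 or 64 is left out.

```python
from itertools import combinations


def generate_combinations(total_length, num_to_remove):
    """All size-`num_to_remove` subsets of range(total_length), in lexicographic order."""
    return [list(c) for c in combinations(range(total_length), num_to_remove)]


def resolve_routed_ranks(routed_tok, num_experts_per_tok, num_experts):
    """Map a routed_tok index to the list of router ranks that get activated.

    Hypothetical helper: mirrors the branch logic in generate.py but is not
    part of the repository.
    """
    if routed_tok == 0:
        # Index 0 is the default: activate ranks 0 .. k-1 (ordinary top-k routing).
        return list(range(num_experts_per_tok))
    # Any other index picks one combination of ranks.
    return generate_combinations(num_experts, num_experts_per_tok)[routed_tok]


default_teacher = resolve_routed_ranks(0, 2, 8)  # top-2 routing
rank_2_student = resolve_routed_ranks(2, 1, 8)   # single expert at rank 2
pair_index_7 = resolve_routed_ranks(7, 2, 8)     # eighth pair in lexicographic order

print(default_teacher, rank_2_student, pair_index_7)
```

With one expert per token the index and the rank coincide, so `--student_routed_tok 2` selects the expert ranked 2 by the router. With two or more experts per token the index walks the lexicographic list of rank combinations, so consecutive indices do not correspond to consecutive ranks.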
-------------------------------------------------------------------------------- /lm_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DavidFanzz/SCMoE/857fad79769c61b25ad87d0618246490bf5551a2/lm_eval/__init__.py -------------------------------------------------------------------------------- /lm_eval/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from warnings import warn 3 | 4 | from datasets import load_dataset 5 | 6 | 7 | class Task(ABC): 8 | """A task represents an entire benchmark including its dataset, problems, 9 | answers, generation settings and evaluation methods. 10 | """ 11 | 12 | # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub 13 | DATASET_PATH: str = None 14 | 15 | # The name of a subset within `DATASET_PATH`. 16 | DATASET_NAME: str = None 17 | 18 | def __init__(self, stop_words=None, requires_execution=True): 19 | """ 20 | :param stop_words: list 21 | list of stop words if the generation uses a stopping criterion during generation 22 | :param requires_execution: bool 23 | whether the task requires code execution during evaluation or not 24 | """ 25 | self.stop_words = stop_words 26 | self.requires_execution = requires_execution 27 | try: 28 | self.dataset = load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME) 29 | except Exception as e: 30 | warn( 31 | f"Loading the dataset failed with {str(e)}. This task will use a locally downloaded dataset, not from the HF hub." 
32 | ) 33 | 34 | @abstractmethod 35 | def get_dataset(self): 36 | """Returns dataset for the task or an iterable of any object that get_prompt can handle""" 37 | return [] 38 | 39 | def fewshot_examples(self): 40 | """Loads and returns the few-shot examples for the task if they exist.""" 41 | pass 42 | 43 | @abstractmethod 44 | def get_reference(self, doc): 45 | """Builds the reference solution for the doc. 46 | :param doc: dict[str: str] 47 | sample from the test dataset 48 | """ 49 | pass 50 | 51 | @abstractmethod 52 | def postprocess_generation(self, generation, idx): 53 | """Defines the postprocessing for a LM generation. 54 | :param generation: str 55 | code generation from LM 56 | :param idx: int 57 | index of doc in the dataset to which the generation belongs 58 | """ 59 | pass 60 | 61 | @abstractmethod 62 | def process_results(self, generations, references): 63 | """Takes the list of LM generations and evaluates them against ground truth references, 64 | returning the metric for the generations as in {"metric_name": result}. 65 | :param generations: list(list(str)) 66 | list of lists containing generations 67 | :param references: list(str) 68 | list of str containing references 69 | :return: dict[str: float] 70 | """ 71 | pass 72 | -------------------------------------------------------------------------------- /lm_eval/evaluator.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import json 3 | import os 4 | import warnings 5 | 6 | from lm_eval import tasks 7 | 8 | _WARNING = """ 9 | ################################################################################ 10 | !!!WARNING!!! 11 | ################################################################################ 12 | The "code_eval"/"apps_metric" you are about to use executes untrusted 13 | model-generated code in Python. 
14 | Although it is highly unlikely that model-generated code will do something 15 | overtly malicious in response to this test suite, model-generated code may act 16 | destructively due to a lack of model capability or alignment. 17 | Users are strongly encouraged to sandbox this evaluation suite so that it 18 | does not perform destructive actions on their host or network. For more 19 | information on how OpenAI sandboxes its code, see the paper "Evaluating Large 20 | Language Models Trained on Code" (https://arxiv.org/abs/2107.03374). 21 | Once you have read this disclaimer and taken appropriate precautions, set the argument 22 | "allow_code_execution" to True. 23 | ################################################################################\ 24 | """ 25 | 26 | 27 | class Evaluator: 28 | def __init__(self, args): 29 | self.args = args 30 | 31 | # setup arguments 32 | self.metric_output_path = args.metric_output_path 33 | 34 | # code evaluation permission 35 | self.allow_code_execution = args.allow_code_execution 36 | 37 | def get_generate_text(self, task_name): 38 | task = tasks.get_task(task_name, self.args) 39 | dataset = task.get_dataset() 40 | # if args.limit is None, use all samples 41 | n_tasks = len(dataset) 42 | references = [task.get_reference(dataset[i]) for i in range(n_tasks)] 43 | 44 | if self.args.check_references: 45 | if "get_solution" in inspect.signature(task.get_reference).parameters: 46 | solutions = [[task.get_reference(dataset[i], get_solution=True)] for i in range(n_tasks)] 47 | else: 48 | solutions = [[ref] for ref in references] 49 | return solutions, references 50 | if self.args.load_generations_path: 51 | generations = [] 52 | with open(self.args.load_generations_path) as fp: 53 | for line in fp: 54 | json_obj = json.loads(line) 55 | generations.append(json_obj) 56 | print( 57 | f"generations loaded, {n_tasks} selected from {len(generations)}." 
58 | ) 59 | generations = generations[:n_tasks] 60 | generations = [[_['completion']] for _ in generations] 61 | 62 | return generations, references 63 | 64 | def evaluate(self, task_name): 65 | task = tasks.get_task(task_name, self.args) 66 | if task.requires_execution and not self.allow_code_execution: 67 | raise ValueError(_WARNING) 68 | 69 | generations, references = self.get_generate_text(task_name) 70 | 71 | # make sure tokenizer plays nice with multiprocessing 72 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 73 | if self.allow_code_execution and task.requires_execution: 74 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 75 | print("Evaluating generations...") 76 | results = task.process_results(generations, references) 77 | return results 78 | -------------------------------------------------------------------------------- /lm_eval/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from pprint import pprint 3 | from . 
import (mbpp, human_eval, strategy_qa, gsm8k) 4 | 5 | TASK_REGISTRY = { 6 | "mbpp": mbpp.MBPP, 7 | "strategy_qa": strategy_qa.StrategyQA, 8 | "human_eval": human_eval.HumanEval, 9 | 'gsm8k': gsm8k.GSM8K, 10 | } 11 | 12 | ALL_TASKS = sorted(list(TASK_REGISTRY)) 13 | 14 | 15 | def get_task(task_name, args=None): 16 | try: 17 | kwargs = {} 18 | if "postprocessed_output_path" in inspect.signature(TASK_REGISTRY[task_name]).parameters: 19 | kwargs["postprocessed_output_path"] = args.postprocessed_output_path 20 | return TASK_REGISTRY[task_name](**kwargs) 21 | except KeyError: 22 | print("Available tasks:") 23 | pprint(TASK_REGISTRY) 24 | raise KeyError(f"Missing task {task_name}") 25 | -------------------------------------------------------------------------------- /lm_eval/tasks/gsm8k.py: -------------------------------------------------------------------------------- 1 | import re 2 | from evaluate import load 3 | from lm_eval.base import Task 4 | import multiprocessing 5 | from math import isclose 6 | from typing import Union 7 | from sympy import simplify, N 8 | from sympy.parsing.sympy_parser import parse_expr 9 | from sympy.parsing.latex import parse_latex 10 | 11 | class GSM8K(Task): 12 | """A task represents an entire benchmark including its dataset, problems, 13 | answers, generation settings and evaluation methods. 
14 | """ 15 | 16 | DATASET_PATH = "data/gsm8k" 17 | 18 | def __init__(self, postprocessed_output_path): 19 | self.postprocessed_output_path = postprocessed_output_path 20 | super().__init__( 21 | requires_execution=True, 22 | ) 23 | 24 | def get_dataset(self): 25 | """Returns dataset for the task or an iterable of any object that get_prompt can handle""" 26 | dataset = self.dataset["test"] 27 | # the wrong split of gsm8k can be loaded with an old datasets cache 28 | assert ( 29 | len(dataset) == 1319 30 | ), "please ensure you have the full gsm8k test split (1319 examples), try deleting its old cache" 31 | return dataset 32 | 33 | def get_reference(self, doc): 34 | """Builds the reference solution for the doc (sample from the test dataset).""" 35 | return "".join(doc["label"]) 36 | 37 | def postprocess_generation(self, generation): 38 | """Extracts the final answer from a LM generation. 39 | :param generation: list(str) 40 | single-element list holding the generated text 41 | :return: str 42 | normalized answer string parsed from the completion 43 | """ 44 | completion = generation[0].strip() 45 | if "\n\nQuestion:" in completion: 46 | completion = completion.split("\n\nQuestion:")[0] 47 | return extract_answer(completion) 48 | 49 | 50 | def process_results(self, generations, references): 51 | """Takes the list of LM generations and evaluates them against ground truth references, 52 | returning the metric for the generations. 
53 | :param generations: list(list(str)) 54 | list of lists containing generations 55 | :param references: list(str) 56 | list of str containing references 57 | """ 58 | generations = [self.postprocess_generation(_) for _ in generations] 59 | results = [math_equal(references[_], generations[_]) for _ in range(len(references))] 60 | return len([_ for _ in results if _]) / len(results) 61 | 62 | 63 | def extract_answer(pred_str): 64 | if ('he answer is' in pred_str): 65 | pred = pred_str.split('he answer is')[-1].strip() 66 | if not is_digit(pred): 67 | pred_str = pred 68 | pattern = r'-?\d*\.?\d+' 69 | pred = re.findall(pattern, pred_str.replace(",", "")) 70 | if(len(pred) >= 1): 71 | pred = pred[-1] 72 | else: 73 | pred = '' 74 | else: # use the last number 75 | pattern = r'-?\d*\.?\d+' 76 | pred = re.findall(pattern, pred_str.replace(",", "")) 77 | if(len(pred) >= 1): 78 | pred = pred[-1] 79 | else: 80 | pred = '' 81 | 82 | # multiple line 83 | pred = pred.split("\n")[0] 84 | if pred != "" and pred[0] == ":": 85 | pred = pred[1:] 86 | if pred != "" and pred[-1] == ".": 87 | pred = pred[:-1] 88 | if pred != "" and pred[-1] == "/": 89 | pred = pred[:-1] 90 | pred = strip_string(pred) 91 | return pred 92 | 93 | def strip_string(string): 94 | string = str(string).strip() 95 | # linebreaks 96 | string = string.replace("\n", "") 97 | 98 | # right "." 
99 | string = string.rstrip(".") 100 | 101 | # remove inverse spaces 102 | string = string.replace("\\!", "") 103 | string = string.replace("\\ ", "") 104 | 105 | # replace \\ with \ 106 | string = string.replace("\\\\", "\\") 107 | string = string.replace("\\\\", "\\") 108 | 109 | # replace tfrac and dfrac with frac 110 | string = string.replace("tfrac", "frac") 111 | string = string.replace("dfrac", "frac") 112 | 113 | # remove \left and \right 114 | string = string.replace("\\left", "") 115 | string = string.replace("\\right", "") 116 | 117 | # Remove unit: miles, dollars if after is not none 118 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 119 | if _string != "" and _string != string: 120 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 121 | string = _string 122 | 123 | # Remove circ (degrees) 124 | string = string.replace("^{\\circ}", "") 125 | string = string.replace("^\\circ", "") 126 | 127 | # remove dollar signs 128 | string = string.replace("\\$", "") 129 | string = string.replace("$", "") 130 | 131 | string = string.replace("\\text", "") 132 | string = string.replace("x\\in", "") 133 | 134 | # remove percentage 135 | string = string.replace("\\%", "") 136 | string = string.replace(r"\%", "") 137 | string = string.replace("%", "") 138 | 139 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 140 | string = string.replace(" .", " 0.") 141 | string = string.replace("{.", "{0.") 142 | 143 | # cdot 144 | string = string.replace("\\cdot", "") 145 | 146 | # inf 147 | string = string.replace("infinity", "\\infty") 148 | if "\\infty" not in string: 149 | string = string.replace("inf", "\\infty") 150 | string = string.replace("+\\infty", "\\infty") 151 | 152 | # and 153 | string = string.replace("and", "") 154 | string = string.replace("\\mathbf", "") 155 | 156 | # use regex to remove \mbox{...} 157 | string = re.sub(r"\\mbox{.*?}", "", string) 158 | 159 | # quote 160 | string = string.replace("'", "") 161 | string = string.replace("\"", "") 162 | 163 | # i, j 164 | if "j" in string and "i" not in string: 165 | string = string.replace("j", "i") 166 | 167 | # replace a.000b where b is not number or b is end, with ab, use regex 168 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) 169 | string = re.sub(r"(\d+)\.0+$", r"\1", string) 170 | 171 | # if empty, return empty string 172 | if len(string) == 0: 173 | return string 174 | if string[0] == ".": 175 | string = "0" + string 176 | 177 | # to consider: get rid of e.g. "k = " or "q = " at beginning 178 | if len(string.split("=")) == 2: 179 | if len(string.split("=")[0]) <= 2: 180 | string = string.split("=")[1] 181 | 182 | string = _fix_sqrt(string) 183 | string = string.replace(" ", "") 184 | 185 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
Also does a/b --> \\frac{a}{b} 186 | string = _fix_fracs(string) 187 | 188 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 189 | string = _fix_a_slash_b(string) 190 | 191 | return string 192 | 193 | def _fix_fracs(string): 194 | substrs = string.split("\\frac") 195 | new_str = substrs[0] 196 | if len(substrs) > 1: 197 | substrs = substrs[1:] 198 | for substr in substrs: 199 | new_str += "\\frac" 200 | if len(substr) > 0 and substr[0] == "{": 201 | new_str += substr 202 | else: 203 | try: 204 | assert len(substr) >= 2 205 | except: 206 | return string 207 | a = substr[0] 208 | b = substr[1] 209 | if b != "{": 210 | if len(substr) > 2: 211 | post_substr = substr[2:] 212 | new_str += "{" + a + "}{" + b + "}" + post_substr 213 | else: 214 | new_str += "{" + a + "}{" + b + "}" 215 | else: 216 | if len(substr) > 2: 217 | post_substr = substr[2:] 218 | new_str += "{" + a + "}" + b + post_substr 219 | else: 220 | new_str += "{" + a + "}" + b 221 | string = new_str 222 | return string 223 | 224 | def _fix_a_slash_b(string): 225 | if len(string.split("/")) != 2: 226 | return string 227 | a = string.split("/")[0] 228 | b = string.split("/")[1] 229 | try: 230 | if "sqrt" not in a: 231 | a = int(a) 232 | if "sqrt" not in b: 233 | b = int(b) 234 | assert string == "{}/{}".format(a, b) 235 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 236 | return new_string 237 | except: 238 | return string 239 | 240 | def _fix_sqrt(string): 241 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 242 | return _string 243 | 244 | def is_digit(s): 245 | try: 246 | float(str(s).replace(",", "")) 247 | return True 248 | except ValueError: 249 | return False 250 | 251 | def math_equal(prediction: Union[bool, float, str], 252 | reference: Union[float, str], 253 | include_percentage: bool = True, 254 | is_close: bool = True, 255 | timeout: bool = False, 256 | ) -> bool: 257 | """ 258 | Exact match of math if and only if: 259 | 1. 
numerical equal: both can convert to float and are equal 260 | 2. symbolic equal: both can convert to sympy expression and are equal 261 | """ 262 | try: # 1. numerical equal 263 | if is_digit(prediction) and is_digit(reference): 264 | prediction = float(str(prediction).replace(",", "")) 265 | reference = float(str(reference).replace(",", "")) 266 | # number questions 267 | if include_percentage: 268 | gt_result = [reference / 100, reference, reference * 100] 269 | else: 270 | gt_result = [reference] 271 | for item in gt_result: 272 | try: 273 | if is_close: 274 | if isclose(item, prediction, rel_tol=1e-4): 275 | return True 276 | else: 277 | if item == prediction: 278 | return True 279 | except Exception: 280 | continue 281 | return False 282 | except: 283 | pass 284 | 285 | if not prediction and prediction not in [0, False]: 286 | return False 287 | 288 | # 2. symbolic equal 289 | reference = str(reference).strip() 290 | prediction = str(prediction).strip() 291 | 292 | ## deal with [], (), {} 293 | pred_str, ref_str = prediction, reference 294 | if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \ 295 | (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): 296 | pred_str = pred_str.strip("[]()") 297 | ref_str = ref_str.strip("[]()") 298 | for s in ['{', "}", "(", ")"]: 299 | ref_str = ref_str.replace(s, "") 300 | pred_str = pred_str.replace(s, "") 301 | if pred_str == ref_str: 302 | return True 303 | 304 | ## [a, b] vs. 
[c, d], return a==c and b==d 305 | if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \ 306 | (prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")): 307 | pred_parts = prediction[1:-1].split(",") 308 | ref_parts = reference[1:-1].split(",") 309 | if len(pred_parts) == len(ref_parts): 310 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 311 | return True 312 | 313 | # symbolic equal with sympy 314 | if timeout: 315 | if call_with_timeout(symbolic_equal_process, prediction, reference): 316 | return True 317 | else: 318 | if symbolic_equal(prediction, reference): 319 | return True 320 | 321 | return False 322 | 323 | 324 | def math_equal_process(param): 325 | return math_equal(param[-2], param[-1]) 326 | 327 | 328 | def symbolic_equal(a, b): 329 | def _parse(s): 330 | for f in [parse_latex, parse_expr]: 331 | try: 332 | return f(s) 333 | except: 334 | pass 335 | return s 336 | a = _parse(a) 337 | b = _parse(b) 338 | 339 | try: 340 | if simplify(a-b) == 0: 341 | return True 342 | except: 343 | pass 344 | 345 | try: 346 | if isclose(N(a), N(b), rel_tol=1e-3): 347 | return True 348 | except: 349 | pass 350 | return False 351 | 352 | 353 | def symbolic_equal_process(a, b, output_queue): 354 | result = symbolic_equal(a, b) 355 | output_queue.put(result) 356 | 357 | 358 | def call_with_timeout(func, *args, timeout=1, **kwargs): 359 | output_queue = multiprocessing.Queue() 360 | process_args = args + (output_queue,) 361 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 362 | process.start() 363 | process.join(timeout) 364 | 365 | if process.is_alive(): 366 | process.terminate() 367 | process.join() 368 | return False 369 | 370 | return output_queue.get() -------------------------------------------------------------------------------- 
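The numeric branch of `math_equal` above can be illustrated with a minimal, self-contained sketch. `numeric_match` is a hypothetical helper written for this note, not a function from the repository; it assumes both inputs are already numeric-looking and reproduces only the comma stripping, the optional percentage candidates, and the `isclose` comparison.

```python
from math import isclose


def numeric_match(prediction, reference, include_percentage=True, rel_tol=1e-4):
    """Hypothetical helper: simplified stand-in for the numeric branch of math_equal."""
    # Strip thousands separators before converting, as is_digit/math_equal do.
    pred = float(str(prediction).replace(",", ""))
    ref = float(str(reference).replace(",", ""))
    # With include_percentage, the reference is also tried scaled by 1/100 and 100,
    # so a prediction of 0.5 matches a reference of 50 (i.e. "50%").
    candidates = [ref / 100, ref, ref * 100] if include_percentage else [ref]
    return any(isclose(c, pred, rel_tol=rel_tol) for c in candidates)


print(numeric_match("1,234", 1234))                       # True
print(numeric_match(0.5, 50))                             # True
print(numeric_match(0.5, 50, include_percentage=False))   # False
```

Because the percentage candidates accept answers that differ by a factor of 100, stricter scoring would presumably call `math_equal` with `include_percentage=False`.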
/lm_eval/tasks/human_eval.py: -------------------------------------------------------------------------------- 1 | """Evaluating Large Language Models Trained on Code 2 | https://arxiv.org/abs/2107.03374 3 | 4 | The HumanEval dataset released by OpenAI includes 164 programming problems with a function signature, 5 | docstring, body, and several unit tests. 6 | They were handwritten to ensure not to be included in the training set of code generation models. 7 | 8 | Homepage: https://github.com/openai/human-eval 9 | """ 10 | 11 | import re 12 | from evaluate import load 13 | from lm_eval.base import Task 14 | import pandas as pd 15 | import subprocess 16 | import os 17 | 18 | _CITATION = """ 19 | @misc{chen2021evaluating, 20 | title={Evaluating Large Language Models Trained on Code}, 21 | author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. 
Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba}, 22 | year={2021}, 23 | eprint={2107.03374}, 24 | archivePrefix={arXiv}, 25 | primaryClass={cs.LG} 26 | } 27 | """ 28 | 29 | def run_pylint(i, generation): 30 | file_name = f"file_{i}_tmp.py_" 31 | with open(file_name, "w", encoding="utf-8") as file: 32 | file.write(generation[i][0]) 33 | command = f"pylint {file_name} --errors-only" 34 | result = subprocess.run(command, capture_output=True, text=True, shell=True) 35 | os.remove(file_name) 36 | return result 37 | 38 | class HumanEval(Task): 39 | """A task represents an entire benchmark including its dataset, problems, 40 | answers, generation settings and evaluation methods. 41 | """ 42 | 43 | DATASET_PATH = "data/human_eval" 44 | 45 | def __init__(self, postprocessed_output_path): 46 | self.postprocessed_output_path = postprocessed_output_path 47 | 48 | super().__init__( 49 | requires_execution=True, 50 | ) 51 | 52 | def get_dataset(self): 53 | """Returns dataset for the task or an iterable of any object, that get_prompt can handle""" 54 | dataset = self.dataset["test"] 55 | # the wrong split of HumanEval can be loaded with old datasets cache 56 | assert ( 57 | len(dataset) == 164 58 | ), "please ensure you have the latest version of HumanEval dataset, try deleting its old cache" 59 | return dataset 60 | 61 | 62 | def get_reference(self, doc): 63 | """Builds the reference solution for the doc (sample from the test dataset).""" 64 | test_func = doc["test"] 65 | entry_point = f"check({doc['entry_point']})" 66 | return "\n" + test_func + "\n" + entry_point 67 | 68 | 69 | def postprocess_generation(self, generation, prefix): 70 | """Defines the postprocessing for a LM generation.
71 | :param generation: str 72 | code generation from LM 73 | :param idx: int 74 | index of doc in the dataset to which the generation belongs 75 | (not used for Humaneval-Task) 76 | """ 77 | completion = generation[0] 78 | completion = completion.replace("\r", "") 79 | if '```python' in completion: 80 | def_line = completion.index('```python') 81 | completion = completion[def_line:].strip() 82 | completion = completion.replace('```python', '') 83 | try: 84 | next_line = completion.index('```') 85 | completion = completion[:next_line].strip() 86 | except: 87 | print(generation[0]) 88 | print("=" * 50 + "\n") 89 | 90 | if "\ndef" in completion: 91 | next_line = completion.index("\ndef") 92 | completion = completion[:next_line].strip() 93 | 94 | if "__name__" in completion: 95 | next_line = completion.index('__name__') 96 | completion = completion[:next_line].strip()[:-2] 97 | 98 | if "# Example usage" in completion: 99 | next_line = completion.index('# Example usage') 100 | completion = completion[:next_line].strip() 101 | 102 | return [prefix + " " + completion.strip() + "\n"] 103 | 104 | def process_results(self, generations, references): 105 | """Takes the list of LM generations and evaluates them against ground truth references, 106 | returning the metric for the generations. 
107 | :param generations: list(list(str)) 108 | list of lists containing generations 109 | :param references: list(str) 110 | list of str containing references 111 | """ 112 | generations = [self.postprocess_generation(generations[_], self.dataset['test']['prompt'][_]) for _ in range(len(generations))] 113 | if self.postprocessed_output_path: 114 | postprocessed_output = pd.DataFrame() 115 | postprocessed_output['results'] = generations 116 | postprocessed_output.to_json(self.postprocessed_output_path, orient='records', lines=True) 117 | 118 | code_metric = load("code_eval") 119 | results, _ = code_metric.compute( 120 | references=references, 121 | predictions=generations, 122 | ) 123 | return results['pass@1'] 124 | -------------------------------------------------------------------------------- /lm_eval/tasks/mbpp.py: -------------------------------------------------------------------------------- 1 | """Program Synthesis with Large Language Models 2 | https://arxiv.org/abs/2108.07732 3 | 4 | The benchmark consists of around 1,000 crowd-sourced Python programming problems, 5 | designed to be solvable by entry level programmers, covering programming fundamentals, 6 | standard library functionality, and so on. Each problem consists of a task description, 7 | code solution and 3 automated test cases. As described in the paper, a subset of the data 8 | has been hand-verified by the authors.
9 | 10 | Homepage:: https://github.com/google-research/google-research/tree/master/mbpp 11 | """ 12 | 13 | import re 14 | from evaluate import load 15 | from lm_eval.base import Task 16 | import pandas as pd 17 | import subprocess 18 | import os 19 | 20 | _CITATION = """ 21 | @article{austin2021program, 22 | title={Program Synthesis with Large Language Models}, 23 | author={Austin, Jacob and Odena, Augustus and Nye, Maxwell and Bosma, Maarten and Michalewski, Henryk and Dohan, David and Jiang, Ellen and Cai, Carrie and Terry, Michael and Le, Quoc and others}, 24 | journal={arXiv preprint arXiv:2108.07732}, 25 | year={2021} 26 | } 27 | """ 28 | 29 | def run_pylint(i, generation): 30 | file_name = f"file_{i}_tmp.py" 31 | with open(file_name, "w", encoding="utf-8") as file: 32 | file.write(generation[i][0]) 33 | command = f"pylint {file_name} --errors-only" 34 | result = subprocess.run(command, capture_output=True, text=True, shell=True) 35 | os.remove(file_name) 36 | return result 37 | 38 | 39 | class MBPP(Task): 40 | """A task represents an entire benchmark including its dataset, problems, 41 | answers, generation settings and evaluation methods. 
42 | """ 43 | 44 | DATASET_PATH = "data/mbpp" 45 | 46 | def __init__(self, postprocessed_output_path): 47 | self.postprocessed_output_path = postprocessed_output_path 48 | super().__init__( 49 | stop_words=["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/", "\n```"], 50 | requires_execution=True, 51 | ) 52 | 53 | def get_dataset(self): 54 | """Returns dataset for the task or an iterable of any object, that get_prompt can handle""" 55 | dataset = self.dataset["test"] 56 | # the wrong split of mbpp can be loaded with old datasets cache 57 | assert ( 58 | len(dataset) == 500 59 | ), "please ensure you have the latest version of MBPP dataset, try deleting its old cache" 60 | return dataset 61 | 62 | def get_reference(self, doc): 63 | """Builds the reference solution for the doc (sample from the test dataset).""" 64 | return "\n".join(doc["test_list"]) 65 | 66 | def postprocess_generation(self, generation): 67 | """Defines the postprocessing for a LM generation. 68 | :param generation: str 69 | code generation from LM 70 | :param idx: int 71 | index of doc in the dataset to which the generation belongs 72 | (not used for Humaneval-Task) 73 | """ 74 | completion = generation[0] 75 | completion = completion.replace("\r", "") 76 | if "[DONE]" in completion: 77 | completion = completion.split("[DONE]")[0] 78 | main_func_pattern = r'def.*?\n[^\n\s#]' 79 | main_func_pattern_result = re.search(main_func_pattern, completion, re.DOTALL) 80 | if main_func_pattern_result: 81 | completion = main_func_pattern_result.group(0)[:-1] 82 | elif "def " in completion: 83 | completion = "def " +completion.split("def ")[-1] 84 | else: 85 | print(generation[0]) 86 | print("=" * 50 + "\n") 87 | 88 | if '```python' in completion: 89 | def_line = completion.index('```python') 90 | completion = completion[def_line:].strip() 91 | completion = completion.replace('```python', '') 92 | try: 93 | next_line = completion.index('```') 94 | completion = completion[:next_line].strip() 95 | except: 
96 | print(completion) 97 | print("=" * 50 + "\n") 98 | if "__name__" in completion: 99 | next_line = completion.index('__name__') 100 | completion = completion[:next_line].strip()[:-2] 101 | 102 | if "# Example usage" in completion: 103 | next_line = completion.index('# Example usage') 104 | completion = completion[:next_line].strip() 105 | return [completion] 106 | 107 | def process_results(self, generations, references): 108 | """Takes the list of LM generations and evaluates them against ground truth references, 109 | returning the metric for the generations. 110 | :param generations: list(list(str)) 111 | list of lists containing generations 112 | :param references: list(str) 113 | list of str containing references 114 | """ 115 | postprocessed_generations = [self.postprocess_generation(generations[_]) for _ in range(len(generations))] 116 | if self.postprocessed_output_path: 117 | postprocessed_output = pd.DataFrame() 118 | postprocessed_output['results'] = postprocessed_generations 119 | postprocessed_output.to_json(self.postprocessed_output_path, orient='records', lines=True) 120 | 121 | code_metric = load("code_eval") 122 | results, _ = code_metric.compute( 123 | references=references, 124 | predictions=postprocessed_generations, 125 | ) 126 | return results['pass@1'] 127 | -------------------------------------------------------------------------------- /lm_eval/tasks/strategy_qa.py: -------------------------------------------------------------------------------- 1 | import re 2 | from evaluate import load 3 | from lm_eval.base import Task 4 | import pandas as pd 5 | 6 | class StrategyQA(Task): 7 | """A task represents an entire benchmark including its dataset, problems, 8 | answers, generation settings and evaluation methods.
9 | """ 10 | 11 | DATASET_PATH = "data/strategy_qa" 12 | 13 | def __init__(self, postprocessed_output_path, sft): 14 | self.postprocessed_output_path = postprocessed_output_path 15 | self.sft = sft 16 | super().__init__( 17 | requires_execution=True, 18 | ) 19 | 20 | def get_dataset(self): 21 | """Returns dataset for the task or an iterable of any object, that get_prompt can handle""" 22 | dataset = self.dataset["test"] 23 | # the wrong split of StrategyQA can be loaded with old datasets cache 24 | assert ( 25 | len(dataset) == 2286 26 | ), "please ensure you have the latest version of StrategyQA dataset, try deleting its old cache" 27 | return dataset 28 | 29 | def get_reference(self, doc): 30 | """Builds the reference solution for the doc (sample from the test dataset).""" 31 | return "".join(doc["label"]) 32 | 33 | def postprocess_generation(self, generation): 34 | """Defines the postprocessing for a LM generation. 35 | :param generation: str 36 | code generation from LM 37 | :param idx: int 38 | index of doc in the dataset to which the generation belongs 39 | """ 40 | answer_key = "None" 41 | answer_begin_hints = ["answer is", "answer to", "answer choice is", "i would choose", "answer would be", "answer seems to be", "the correct answer is", "answer to the question is"] 42 | answer_end_hints = ["is correct", "is the best choice", "is the correct answer", "is the correct choice", "answer choice is correct"] 43 | completion = generation[0].lower().strip() 44 | if "\n\n" in completion: 45 | completion = completion.split("\n\n")[0] 46 | 47 | matched_begin_hints = [_ for _ in answer_begin_hints if _ in completion] 48 | matched_end_hints = [_ for _ in answer_end_hints if _ in completion] 49 | if len(matched_begin_hints) > 0: 50 | completion = completion.split(matched_begin_hints[-1])[-1] 51 | elif len(matched_end_hints) > 0: 52 | completion = completion.split(matched_end_hints[-1])[0] 53 | pattern = r'( yes | no )' 54 | completion = " " +
completion.replace(".", " ").replace(",", " ").replace(";", " ") + " " 55 | matches = re.findall(pattern, completion) 56 | 57 | if matches: 58 | return matches[-1].strip() 59 | else: 60 | print("=" * 25 + "No clear yes or no results" + "=" * 25 ) 61 | if "\n\n" in generation[0].strip(): 62 | print(generation[0].strip().split("\n\n")[0]) 63 | else: 64 | print(generation[0].strip()) 65 | print("=" * 50 + "=" * len("No clear yes or no results") + "\n") 66 | return "no" 67 | 68 | 69 | def process_results(self, generations, references): 70 | """Takes the list of LM generations and evaluates them against ground truth references, 71 | returning the metric for the generations. 72 | :param generations: list(list(str)) 73 | list of lists containing generations 74 | :param references: list(str) 75 | list of str containing references 76 | """ 77 | generations = [self.postprocess_generation(_) for _ in generations] 78 | if self.postprocessed_output_path: 79 | postprocessed_output = pd.DataFrame() 80 | postprocessed_output['results'] = generations 81 | postprocessed_output.to_json(self.postprocessed_output_path, orient='records', lines=True) 82 | cnt = 0 83 | for i in range(len(generations)): 84 | if generations[i] == "None": 85 | cnt += 1 86 | acc_metric = load("exact_match") 87 | results = acc_metric.compute( 88 | references=references, 89 | predictions=generations, 90 | ) 91 | results["match_template"] = 1 - cnt / len(generations) 92 | return results['exact_match'] 93 | -------------------------------------------------------------------------------- /modeling_models/Mixtral/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "MixtralForCausalLM" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 4096, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 14336, 12 | "max_position_embeddings": 32768, 13 | "model_type":
"mixtral", 14 | "num_attention_heads": 32, 15 | "num_experts_per_tok": 2, 16 | "num_hidden_layers": 32, 17 | "num_key_value_heads": 8, 18 | "num_local_experts": 8, 19 | "output_router_logits": false, 20 | "rms_norm_eps": 1e-05, 21 | "rope_theta": 1000000.0, 22 | "router_aux_loss_coef": 0.02, 23 | "sliding_window": null, 24 | "tie_word_embeddings": false, 25 | "torch_dtype": "bfloat16", 26 | "transformers_version": "4.36.0.dev0", 27 | "use_cache": true, 28 | "vocab_size": 32000, 29 | "dynamic_expert_routing_threshold": 1.0, 30 | "routed_tok": [ 31 | 0, 32 | 1 33 | ] 34 | } -------------------------------------------------------------------------------- /modeling_models/Mixtral/configuration_mixtral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Mixtral model configuration""" 16 | 17 | from ...configuration_utils import PretrainedConfig 18 | from ...utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json", 25 | } 26 | 27 | 28 | class MixtralConfig(PretrainedConfig): 29 | r""" 30 | This is the configuration class to store the configuration of a [`MixtralModel`]. 
It is used to instantiate a 31 | Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration 32 | with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1. 33 | 34 | [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B) 35 | [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1) 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the 44 | `input_ids` passed when calling [`MixtralModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 14336): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer decoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer decoder. 53 | num_key_value_heads (`int`, *optional*, defaults to 8): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details check out [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf).
If it is not specified, will default to `8`. 60 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 61 | The non-linear activation function (function or string) in the decoder. 62 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 63 | The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention 64 | allows sequences of up to 4096*32 tokens. 65 | initializer_range (`float`, *optional*, defaults to 0.02): 66 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 67 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 68 | The epsilon used by the rms normalization layers. 69 | use_cache (`bool`, *optional*, defaults to `True`): 70 | Whether or not the model should return the last key/values attentions (not used by all models). Only 71 | relevant if `config.is_decoder=True`. 72 | pad_token_id (`int`, *optional*): 73 | The id of the padding token. 74 | bos_token_id (`int`, *optional*, defaults to 1): 75 | The id of the "beginning-of-sequence" token. 76 | eos_token_id (`int`, *optional*, defaults to 2): 77 | The id of the "end-of-sequence" token. 78 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 79 | Whether the model's input and output word embeddings should be tied. 80 | rope_theta (`float`, *optional*, defaults to 1000000.0): 81 | The base period of the RoPE embeddings. 82 | sliding_window (`int`, *optional*): 83 | Sliding window attention window size. If not specified, will default to `4096`. 84 | attention_dropout (`float`, *optional*, defaults to 0.0): 85 | The dropout ratio for the attention probabilities. 86 | num_experts_per_tok (`int`, *optional*, defaults to 2): 87 | The number of experts to route per token; can also be interpreted as the `top-k` routing 88 | parameter 89 | num_local_experts (`int`, *optional*, defaults to 8): 90 | Number of experts per Sparse MLP layer.
91 | output_router_logits (`bool`, *optional*, defaults to `False`): 92 | Whether or not the router logits should be returned by the model. Enabling this will also 93 | allow the model to output the auxiliary loss. See [here]() for more details 94 | router_aux_loss_coef (`float`, *optional*, defaults to 0.001): 95 | The aux loss factor for the total loss. 96 | 97 | ```python 98 | >>> from transformers import MixtralModel, MixtralConfig 99 | 100 | >>> # Initializing a Mixtral 7B style configuration 101 | >>> configuration = MixtralConfig() 102 | 103 | >>> # Initializing a model from the Mixtral 7B style configuration 104 | >>> model = MixtralModel(configuration) 105 | 106 | >>> # Accessing the model configuration 107 | >>> configuration = model.config 108 | ```""" 109 | 110 | model_type = "mixtral" 111 | keys_to_ignore_at_inference = ["past_key_values"] 112 | 113 | def __init__( 114 | self, 115 | vocab_size=32000, 116 | hidden_size=4096, 117 | intermediate_size=14336, 118 | num_hidden_layers=32, 119 | num_attention_heads=32, 120 | num_key_value_heads=8, 121 | hidden_act="silu", 122 | max_position_embeddings=4096 * 32, 123 | initializer_range=0.02, 124 | rms_norm_eps=1e-5, 125 | use_cache=True, 126 | pad_token_id=None, 127 | bos_token_id=1, 128 | eos_token_id=2, 129 | tie_word_embeddings=False, 130 | rope_theta=1e6, 131 | sliding_window=None, 132 | attention_dropout=0.0, 133 | num_experts_per_tok=2, 134 | num_local_experts=8, 135 | output_router_logits=False, 136 | router_aux_loss_coef=0.001, 137 | routed_tok=[0, 1], 138 | dynamic_expert_routing_threshold=1.0, 139 | **kwargs, 140 | ): 141 | self.vocab_size = vocab_size 142 | self.max_position_embeddings = max_position_embeddings 143 | self.hidden_size = hidden_size 144 | self.intermediate_size = intermediate_size 145 | self.num_hidden_layers = num_hidden_layers 146 | self.num_attention_heads = num_attention_heads 147 | self.sliding_window = sliding_window 148 | 149 | # for backward compatibility 150 | if
num_key_value_heads is None: 151 | num_key_value_heads = num_attention_heads 152 | 153 | self.num_key_value_heads = num_key_value_heads 154 | self.hidden_act = hidden_act 155 | self.initializer_range = initializer_range 156 | self.rms_norm_eps = rms_norm_eps 157 | self.use_cache = use_cache 158 | self.rope_theta = rope_theta 159 | self.attention_dropout = attention_dropout 160 | 161 | self.num_experts_per_tok = num_experts_per_tok 162 | self.num_local_experts = num_local_experts 163 | self.output_router_logits = output_router_logits 164 | self.router_aux_loss_coef = router_aux_loss_coef 165 | self.routed_tok = routed_tok 166 | self.dynamic_expert_routing_threshold = dynamic_expert_routing_threshold 167 | super().__init__( 168 | pad_token_id=pad_token_id, 169 | bos_token_id=bos_token_id, 170 | eos_token_id=eos_token_id, 171 | tie_word_embeddings=tie_word_embeddings, 172 | **kwargs, 173 | ) 174 | -------------------------------------------------------------------------------- /modeling_models/deepseek-moe/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "DeepseekForCausalLM" 4 | ], 5 | "attention_bias": false, 6 | "attention_dropout": 0.0, 7 | "auto_map": { 8 | "AutoConfig": "configuration_deepseek.DeepseekConfig", 9 | "AutoModel": "modeling_deepseek.DeepseekModel", 10 | "AutoModelForCausalLM": "modeling_deepseek.DeepseekForCausalLM" 11 | }, 12 | "aux_loss_alpha": 0.001, 13 | "bos_token_id": 100000, 14 | "eos_token_id": 100001, 15 | "first_k_dense_replace": 1, 16 | "hidden_act": "silu", 17 | "hidden_size": 2048, 18 | "initializer_range": 0.02, 19 | "intermediate_size": 10944, 20 | "max_position_embeddings": 4096, 21 | "model_type": "deepseek", 22 | "moe_intermediate_size": 1408, 23 | "moe_layer_freq": 1, 24 | "n_routed_experts": 64, 25 | "n_shared_experts": 2, 26 | "norm_topk_prob": false, 27 | "num_attention_heads": 16, 28 | "num_experts_per_tok": 6, 29 | "num_hidden_layers": 28, 30 | 
"num_key_value_heads": 16, 31 | "output_router_logits": false, 32 | "pretraining_tp": 1, 33 | "rms_norm_eps": 1e-06, 34 | "rope_scaling": null, 35 | "rope_theta": 10000, 36 | "routed_tok": [ 37 | 0, 38 | 1, 39 | 2, 40 | 3, 41 | 4, 42 | 5 43 | ], 44 | "scoring_func": "softmax", 45 | "seq_aux": true, 46 | "tie_word_embeddings": false, 47 | "torch_dtype": "bfloat16", 48 | "transformers_version": "4.39.3", 49 | "use_cache": true, 50 | "dynamic_expert_routing_threshold": 1.0, 51 | "vocab_size": 102400 52 | } 53 | -------------------------------------------------------------------------------- /modeling_models/deepseek-moe/configuration_deepseek.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | from transformers.utils import logging 3 | 4 | logger = logging.get_logger(__name__) 5 | 6 | DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 7 | class DeepseekConfig(PretrainedConfig): 8 | r""" 9 | This is the configuration class to store the configuration of a [`DeepseekModel`]. It is used to instantiate an DeepSeek 10 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 11 | defaults will yield a similar configuration to that of the DeepSeek-7B. 12 | 13 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 14 | documentation from [`PretrainedConfig`] for more information. 15 | 16 | 17 | Args: 18 | vocab_size (`int`, *optional*, defaults to 102400): 19 | Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the 20 | `inputs_ids` passed when calling [`DeepseekModel`] 21 | hidden_size (`int`, *optional*, defaults to 4096): 22 | Dimension of the hidden representations. 23 | intermediate_size (`int`, *optional*, defaults to 11008): 24 | Dimension of the MLP representations. 
25 | moe_intermediate_size (`int`, *optional*, defaults to 1407): 26 | Dimension of the MoE representations. 27 | num_hidden_layers (`int`, *optional*, defaults to 30): 28 | Number of hidden layers in the Transformer decoder. 29 | num_attention_heads (`int`, *optional*, defaults to 32): 30 | Number of attention heads for each attention layer in the Transformer decoder. 31 | n_shared_experts (`int`, *optional*, defaults to None): 32 | Number of shared experts, None means dense model. 33 | n_routed_experts (`int`, *optional*, defaults to None): 34 | Number of routed experts, None means dense model. 35 | num_experts_per_tok (`int`, *optional*, defaults to None): 36 | Number of selected experts, None means dense model. 37 | moe_layer_freq (`int`, *optional*, defaults to 1): 38 | The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. 39 | first_k_dense_replace (`int`, *optional*, defaults to 0): 40 | Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head). 41 | \--k dense layers--/ 42 | norm_topk_prob (`bool`, *optional*, defaults to False): 43 | Whether to normalize the weights of the routed experts. 44 | scoring_func (`str`, *optional*, defaults to 'softmax'): 45 | Method of computing expert weights. 46 | aux_loss_alpha (`float`, *optional*, defaults to 0.001): 47 | Auxiliary loss weight coefficient. 48 | seq_aux (`bool`, *optional*, defaults to True): 49 | Whether to compute the auxiliary loss for each individual sample. 50 | num_key_value_heads (`int`, *optional*): 51 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 52 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 53 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. 
When 54 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 55 | by meanpooling all the original heads within that group. For more details check out [this 56 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 57 | `num_attention_heads`. 58 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 59 | The non-linear activation function (function or string) in the decoder. 60 | max_position_embeddings (`int`, *optional*, defaults to 2048): 61 | The maximum sequence length that this model might ever be used with. 62 | initializer_range (`float`, *optional*, defaults to 0.02): 63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 64 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 65 | The epsilon used by the rms normalization layers. 66 | use_cache (`bool`, *optional*, defaults to `True`): 67 | Whether or not the model should return the last key/values attentions (not used by all models). Only 68 | relevant if `config.is_decoder=True`. 69 | pad_token_id (`int`, *optional*): 70 | Padding token id. 71 | bos_token_id (`int`, *optional*, defaults to 100000): 72 | Beginning of stream token id. 73 | eos_token_id (`int`, *optional*, defaults to 100001): 74 | End of stream token id. 75 | pretraining_tp (`int`, *optional*, defaults to 1): 76 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 77 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 78 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 79 | issue](https://github.com/pytorch/pytorch/issues/76232). 80 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 81 | Whether to tie weight embeddings. 82 | rope_theta (`float`, *optional*, defaults to 10000.0): 83 | The base period of the RoPE embeddings. 
84 | rope_scaling (`Dict`, *optional*): 85 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 86 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 87 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 88 | `max_position_embeddings` to the expected new maximum. 89 | attention_bias (`bool`, *optional*, defaults to `False`): 90 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 91 | attention_dropout (`float`, *optional*, defaults to 0.0): 92 | The dropout ratio for the attention probabilities. 93 | 94 | ```python 95 | >>> from transformers import DeepseekModel, DeepseekConfig 96 | 97 | >>> # Initializing a Deepseek deepseek-7b style configuration 98 | >>> configuration = DeepseekConfig() 99 | >>> model = DeepseekModel(configuration) 100 | >>> # Accessing the model configuration 101 | >>> configuration = model.config 102 | ```""" 103 | 104 | model_type = "deepseek" 105 | keys_to_ignore_at_inference = ["past_key_values"] 106 | 107 | def __init__( 108 | self, 109 | vocab_size=102400, 110 | hidden_size=4096, 111 | intermediate_size=11008, 112 | moe_intermediate_size = 1407, 113 | num_hidden_layers=30, 114 | num_attention_heads=32, 115 | num_key_value_heads=32, 116 | n_shared_experts = None, 117 | n_routed_experts = None, 118 | num_experts_per_tok = None, 119 | moe_layer_freq = 1, 120 | first_k_dense_replace = 0, 121 | norm_topk_prob = False, 122 | scoring_func = 'softmax', 123 | aux_loss_alpha = 0.001, 124 | seq_aux = True, 125 | hidden_act="silu", 126 | max_position_embeddings=2048, 127 | initializer_range=0.02, 128 | rms_norm_eps=1e-6, 129 | use_cache=True, 130 | pad_token_id=None, 131 | bos_token_id=100000, 132 | eos_token_id=100001, 133 | pretraining_tp=1, 134 | tie_word_embeddings=False, 135 | rope_theta=10000.0, 136 | rope_scaling=None, 137 | attention_bias=False, 138 | 
attention_dropout=0.0, 139 | routed_tok=[0, 1, 2, 3, 4, 5], 140 | output_router_logits=False, 141 | dynamic_expert_routing_threshold=1.0, 142 | **kwargs, 143 | ): 144 | self.vocab_size = vocab_size 145 | self.max_position_embeddings = max_position_embeddings 146 | self.hidden_size = hidden_size 147 | self.intermediate_size = intermediate_size 148 | self.moe_intermediate_size = moe_intermediate_size 149 | self.num_hidden_layers = num_hidden_layers 150 | self.num_attention_heads = num_attention_heads 151 | self.n_shared_experts = n_shared_experts 152 | self.n_routed_experts = n_routed_experts 153 | self.num_experts_per_tok = num_experts_per_tok 154 | self.moe_layer_freq = moe_layer_freq 155 | self.first_k_dense_replace = first_k_dense_replace 156 | self.norm_topk_prob = norm_topk_prob 157 | self.scoring_func = scoring_func 158 | self.aux_loss_alpha = aux_loss_alpha 159 | self.seq_aux = seq_aux 160 | # for backward compatibility 161 | if num_key_value_heads is None: 162 | num_key_value_heads = num_attention_heads 163 | 164 | self.num_key_value_heads = num_key_value_heads 165 | self.hidden_act = hidden_act 166 | self.initializer_range = initializer_range 167 | self.rms_norm_eps = rms_norm_eps 168 | self.pretraining_tp = pretraining_tp 169 | self.use_cache = use_cache 170 | self.rope_theta = rope_theta 171 | self.rope_scaling = rope_scaling 172 | self._rope_scaling_validation() 173 | self.attention_bias = attention_bias 174 | self.attention_dropout = attention_dropout 175 | self.routed_tok = routed_tok 176 | self.output_router_logits = output_router_logits 177 | self.dynamic_expert_routing_threshold = dynamic_expert_routing_threshold 178 | 179 | super().__init__( 180 | pad_token_id=pad_token_id, 181 | bos_token_id=bos_token_id, 182 | eos_token_id=eos_token_id, 183 | tie_word_embeddings=tie_word_embeddings, 184 | **kwargs, 185 | ) 186 | 187 | def _rope_scaling_validation(self): 188 | """ 189 | Validate the `rope_scaling` configuration. 
190 | """ 191 | if self.rope_scaling is None: 192 | return 193 | 194 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 195 | raise ValueError( 196 | "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " 197 | f"got {self.rope_scaling}" 198 | ) 199 | rope_scaling_type = self.rope_scaling.get("type", None) 200 | rope_scaling_factor = self.rope_scaling.get("factor", None) 201 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 202 | raise ValueError( 203 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 204 | ) 205 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 206 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") --------------------------------------------------------------------------------
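Note on `_rope_scaling_validation` above: the rules it enforces can be tried out without installing `transformers`. The sketch below is a standalone re-statement of those checks, not code from this repository; the names `validate_rope_scaling` and `is_valid_rope_scaling` are hypothetical helpers introduced only for illustration. One consequence worth noticing is that an integer factor such as `2` is rejected, because the check requires a `float`.

```python
def validate_rope_scaling(rope_scaling):
    """Standalone mirror of the checks in DeepseekConfig._rope_scaling_validation.

    Hypothetical helper for illustration; raises ValueError on an invalid value.
    """
    if rope_scaling is None:
        return  # no scaling configured, nothing to validate

    # Must be a dict with exactly two entries ("type" and "factor").
    if not isinstance(rope_scaling, dict) or len(rope_scaling) != 2:
        raise ValueError(
            "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
            f"got {rope_scaling}"
        )

    scaling_type = rope_scaling.get("type")
    scaling_factor = rope_scaling.get("factor")

    if scaling_type not in ("linear", "dynamic"):
        raise ValueError(
            f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {scaling_type}"
        )

    # The factor must be a float strictly greater than 1; ints are rejected.
    if not isinstance(scaling_factor, float) or scaling_factor <= 1.0:
        raise ValueError(
            f"`rope_scaling`'s factor field must be a float > 1, got {scaling_factor}"
        )


def is_valid_rope_scaling(rope_scaling):
    """Return True if the value passes validation, False otherwise."""
    try:
        validate_rope_scaling(rope_scaling)
        return True
    except ValueError:
        return False


if __name__ == "__main__":
    candidates = [
        None,
        {"type": "linear", "factor": 2.0},
        {"type": "linear", "factor": 2},    # int factor
        {"type": "yarn", "factor": 2.0},    # unsupported type
        {"type": "dynamic", "factor": 1.0}, # factor not > 1
    ]
    for candidate in candidates:
        print(candidate, "->", is_valid_rope_scaling(candidate))
```

When writing a `rope_scaling` entry into `config.json`, this suggests spelling the factor with a decimal point (for example `2.0`) so that it is loaded as a float.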