├── LICENSE ├── Makefile ├── Readme.md ├── cross_val.py ├── main.py ├── mrr.py ├── prme ├── __init__.py ├── dataio.py ├── mrr.pyx ├── myrandom │ ├── __init__.py │ ├── random.pxd │ ├── random.pyx │ ├── randomkit.c │ ├── randomkit.h │ └── tests │ │ └── __init__.py └── prme.pyx └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Flavio Figueiredo 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of fpmc nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | 29 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Simple makefile 2 | 3 | PYTHON ?= python 4 | NOSETESTS ?= nosetests 5 | 6 | all: 7 | $(PYTHON) setup.py build_ext --inplace 8 | 9 | clean: 10 | rm -rf build/ 11 | find . -name "*.pyc" | xargs rm -f 12 | find . -name "*.c" | grep -v randomkit.c | xargs rm -f 13 | find . -name "*.so" | xargs rm -f 14 | 15 | test: 16 | $(NOSETESTS) 17 | 18 | trailing-spaces: 19 | find -name "*.py" | xargs sed 's/^M$$//' 20 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | PRME 2 | ---- 3 | 4 | Python/Cython implementation of the "Personalized Ranking Metric Embedding for 5 | Next New POI Recommendation" paper. 6 | 7 | Notes 8 | ----- 9 | 10 | This is the PRME model from the paper, not PRME-G. It should be simple enough 11 | to adapt the code for PRME-G. 12 | 13 | Dependencies for library 14 | ------------------------ 15 | * Cython 16 | * Numpy 17 | * Pandas 18 | 19 | How to install 20 | -------------- 21 | 22 | Clone the repo 23 | 24 | :: 25 | 26 | $ git clone https://github.com/flaviovdf/prme.git 27 | 28 | Make sure you have cython and numpy. If not, run as root (or use your distro's package manager) 29 | 30 | :: 31 | 32 | $ pip install numpy 33 | 34 | :: 35 | 36 | $ pip install Cython 37 | 38 | Install 39 | 40 | :: 41 | 42 | $ python setup.py install 43 | 44 | Run the main script or the cross_val script: 45 | 46 | $ python main.py data_file num_latent_factors model.h5 47 | 48 | This will read the data_file, decompose with num_latent_factors and save 49 | the model under the filename model.h5 50 | 51 | The model is a pandas HDFStore. 
def main():
    """Grid-search PRME hyper-parameters and save the selected model.

    Command line arguments:
      trace_fpath  tab separated trace file (``dt user from to`` per line).
      num_topics   number of latent dimensions passed to ``learn``.
      model_fpath  output path handed to ``dataio.save_model``.
      --leaveout   fraction of the trace (taken from the end of the file)
                   that is kept out of training; defaults to 0.3.  Values
                   >= 1 leave nothing to train on, so the script exits.

    Every combination of learning rate, regularization, alpha and tau in
    the grids below is trained with ``learn``; the run with the largest
    ``cost_val`` is kept, annotated with the total wall-clock time and
    saved.

    Raises:
        RuntimeError: if no combination yields a validation cost that can
            be compared (for instance when every cost is NaN).
    """
    # Local import: only this entry point needs it.
    import itertools

    parser = argparse.ArgumentParser()
    parser.add_argument('trace_fpath', help='The trace to learn topics from',
                        type=str)
    parser.add_argument('num_topics', help='The number of topics to learn',
                        type=int)
    parser.add_argument('model_fpath',
                        help='The name of the model file (a h5 file)',
                        type=str)
    parser.add_argument('--leaveout',
                        help='The number of transitions to leave for test',
                        type=float, default=0.3)

    args = parser.parse_args()
    started = time.mktime(time.localtime())

    with open(args.trace_fpath) as trace_file:
        num_lines = sum(1 for _ in trace_file)

    from_ = 0
    if args.leaveout > 0:
        leave_out = min(1, args.leaveout)
        if leave_out == 1:
            print('Leave out is 1 (100%), nothing todo')
            return
        # Train on the leading lines only; the tail is reserved for testing.
        to = int(num_lines - num_lines * leave_out)
    else:
        to = np.inf

    rates = [0.0001, 0.001, 0.01]
    regs = [0.00001, 0.0001, 0.001, 0.01]
    alphas = [0.25, 0.5, 0.75]
    taus = [0, 60 * 60, 12 * 60 * 60, 24 * 60 * 60]

    # NOTE(review): the model with the *largest* cost_val wins. That is only
    # right if sgd reports a score where bigger is better -- confirm against
    # the sgd implementation.
    max_cost = float('-inf')
    best_model = None

    # itertools.product iterates in the same order as the equivalent nested
    # loops (the right-most grid varies fastest).
    for rate, reg, alpha, tau in itertools.product(rates, regs, alphas, taus):
        rv = learn(args.trace_fpath, args.num_topics, rate,
                   reg, alpha, tau, from_, to)
        cost_val = rv['cost_val'][0]
        if cost_val > max_cost:
            max_cost = cost_val
            best_model = rv
        print(max_cost)

    if best_model is None:
        # A NaN cost never compares greater than -inf, so nothing was kept.
        raise RuntimeError('No hyper-parameter combination produced a '
                           'comparable validation cost')

    ended = time.mktime(time.localtime())
    best_model['training_time'] = np.array([ended - started])
    dataio.save_model(args.model_fpath, best_model)
    print('Learning took', ended - started, 'seconds')


if __name__ == '__main__':
    main()
parser.add_argument('--regularization', help='The regularization', \ 29 | type=float, default=0.03) 30 | parser.add_argument('--alpha', help='Value for the alpha parameter', \ 31 | type=float, default=0.02) 32 | parser.add_argument('--tau', help='Value for the tau parameter', \ 33 | type=float, default=3 * 60 * 60) 34 | 35 | args = parser.parse_args() 36 | started = time.mktime(time.localtime()) 37 | num_lines = 0 38 | with open(args.trace_fpath) as trace_file: 39 | num_lines = sum(1 for _ in trace_file) 40 | 41 | if args.leaveout > 0: 42 | leave_out = min(1, args.leaveout) 43 | if leave_out == 1: 44 | print('Leave out is 1 (100%), nothing todo') 45 | return 46 | from_ = 0 47 | to = int(num_lines - num_lines * leave_out) 48 | else: 49 | from_ = 0 50 | to = np.inf 51 | 52 | print('Learning') 53 | rv = learn(args.trace_fpath, args.num_topics, args.learning_rate, \ 54 | args.regularization, args.alpha, args.tau, from_, to) 55 | ended = time.mktime(time.localtime()) 56 | rv['training_time'] = np.array([ended - started]) 57 | dataio.save_model(args.model_fpath, rv) 58 | print('Learning took', ended - started, 'seconds') 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /mrr.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf8 2 | from __future__ import division, print_function 3 | 4 | from prme import mrr 5 | 6 | import pandas as pd 7 | import plac 8 | import numpy as np 9 | 10 | def main(model, out_fpath): 11 | store = pd.HDFStore(model) 12 | 13 | from_ = store['from_'][0][0] 14 | to = store['to'][0][0] 15 | assert from_ == 0 16 | 17 | trace_fpath = store['trace_fpath'][0][0] 18 | 19 | XP_hk = store['XP_hk'].values 20 | XP_ok = store['XP_ok'].values 21 | XG_ok = store['XG_ok'].values 22 | alpha = store['alpha'].values[0][0] 23 | tau = store['tau'].values[0][0] 24 | 25 | hyper2id = dict(store['hyper2id'].values) 26 | obj2id = 
def learn(trace_fpath, nk, rate, regularization, alpha, tau,
          from_=0, to=np.inf, validation=0.1):
    """Train a PRME model on a slice of a trace file.

    Parameters
    ----------
    trace_fpath : str
        Path of the tab separated trace file.
    nk : int
        Number of latent dimensions of every embedding matrix.
    rate, regularization, alpha, tau : float
        Hyper-parameters forwarded unchanged to ``sgd``.
    from_, to : int or float
        Line range ``[from_, to)`` of the trace handed to
        ``dataio.initialize_trace``.
    validation : float
        Fraction of the loaded transitions, taken from the end, that is
        held out as validation data.

    Returns
    -------
    dict
        The hyper-parameters, the train/validation costs reported by
        ``sgd``, the two name -> id mappings and the three embedding
        matrices ``XG_ok``, ``XP_ok`` and ``XP_hk``.
    """
    dts, Trace, seen, hyper2id, obj2id = dataio.initialize_trace(
        trace_fpath, from_, to)
    no = len(obj2id)
    nh = len(hyper2id)

    validation_from = int(len(dts) - len(dts) * validation)
    print('Using first %d of %d as train, rest is validation'
          % (validation_from, len(dts)))

    # Shuffle only the training prefix; the validation tail keeps its order.
    rnd_idx = np.arange(validation_from)
    np.random.shuffle(rnd_idx)

    dts_train = np.asanyarray(dts[:validation_from][rnd_idx],
                              dtype='f8', order='C')
    Trace_train = np.asanyarray(Trace[:validation_from][rnd_idx],
                                dtype='i4', order='C')

    dts_val = np.asanyarray(dts[validation_from:], dtype='f8', order='C')
    Trace_val = np.asanyarray(Trace[validation_from:], dtype='i4', order='C')

    XG_ok = np.random.normal(0, 0.01, (no, nk))
    XP_ok = np.random.normal(0, 0.01, (no, nk))
    XP_hk = np.random.normal(0, 0.01, (nh, nk))

    # Train on the shuffled training split only. Passing the complete
    # ``dts``/``Trace`` here (as the code used to) also fits the validation
    # tail, which makes ``cost_val`` meaningless for model selection.
    # The matrices are returned below without being reassigned, so sgd is
    # presumably updating them in place -- verify against prme.pyx.
    cost_train, cost_val = sgd(dts_train, Trace_train, XG_ok, XP_ok, XP_hk,
                               seen, rate, regularization, alpha, tau,
                               dts_val, Trace_val)

    rv = {
        'num_topics': np.asarray([nk]),
        'trace_fpath': np.asarray([os.path.abspath(trace_fpath)]),
        'rate': np.asarray([rate]),
        'regularization': np.asarray([regularization]),
        'alpha': np.asarray([alpha]),
        'tau': np.asarray([tau]),
        'from_': np.asarray([from_]),
        'cost_train': np.asarray([cost_train]),
        'cost_val': np.asarray([cost_val]),
        'to': np.asarray([to]),
        'hyper2id': hyper2id,
        'obj2id': obj2id,
        'XG_ok': XG_ok,
        'XP_ok': XP_ok,
        'XP_hk': XP_hk,
    }
    return rv
def compute(double[::1] dts, int[:, ::1] HSDs, double[:, ::1] XP_hk, \
        double[:, ::1] XP_ok, double[:, ::1] XG_ok, double alpha, \
        double tau):
    """Return the reciprocal rank of the true destination of every query.

    Row ``i`` of ``HSDs`` holds the ids ``(h, s, d)`` and ``dts[i]`` the
    matching time gap.  For every candidate object ``c`` the score is

        a * ||XP_hk[h] - XP_ok[c]||^2 + (1 - a) * ||XG_ok[s] - XG_ok[c]||^2

    where ``a`` is ``alpha``, or ``1.0`` when ``dts[i] > tau``.  The rank of
    ``d`` is the number of candidates whose score is <= the score of ``d``
    (``d`` itself included), and ``1 / rank`` is stored for the row.

    Returns a new numpy array with one reciprocal rank per row of ``HSDs``.
    """
    cdef double[::1] aux = np.zeros(XP_ok.shape[0], dtype='d')
    cdef double[::1] rrs = np.zeros(HSDs.shape[0], dtype='d')
    cdef int i, h, s, d, candidate_d, k
    cdef double dt, alpha_to_use, target, rank

    for i in xrange(HSDs.shape[0]):
        dt = dts[i]

        if dt > tau:
            alpha_to_use = 1.0
        else:
            alpha_to_use = alpha

        h = HSDs[i, 0]
        s = HSDs[i, 1]
        d = HSDs[i, 2]
        for candidate_d in prange(XP_ok.shape[0], schedule='static', nogil=True):
            aux[candidate_d] = 0.0

        for k in xrange(XP_ok.shape[1]):
            # Each iteration only touches its own aux element, so both
            # terms can share one parallel loop; the per-element addition
            # order (personal term, then sequential term) is unchanged.
            for candidate_d in prange(XP_ok.shape[0], schedule='static', nogil=True):
                aux[candidate_d] += alpha_to_use * \
                        (XP_hk[h, k] - XP_ok[candidate_d, k]) ** 2
                aux[candidate_d] += (1 - alpha_to_use) * \
                        (XG_ok[s, k] - XG_ok[candidate_d, k]) ** 2

        # Count into a scalar: an in-place update of a plain variable inside
        # prange is compiled as a reduction, whereas ``rrs[i] += 1`` is an
        # unsynchronised write to shared memory that loses updates when the
        # module is built with OpenMP.
        target = aux[d]
        rank = 0.0
        for candidate_d in prange(XP_ok.shape[0], schedule='static', nogil=True):
            if aux[candidate_d] <= target:
                rank += 1
        rrs[i] = 1.0 / rank

    return np.array(rrs)
cdef class RNG:
    """Owner of one heap allocated randomkit ``rk_state``.

    The state is allocated and seeded in ``__cinit__`` and released in
    ``__dealloc__``.  ``set_seed`` and ``rand`` are ``nogil`` C-level
    methods that forward to ``rk_seed`` and ``rk_double``.
    """

    def __cinit__(self):
        cdef unsigned long *seedptr
        cdef object seed

        # malloc returns ``void *``; the explicit cast is required to store
        # it in the ``rk_state *`` attribute.
        self.rng_state = <rk_state *> malloc(sizeof(rk_state))
        if self.rng_state == NULL:
            raise MemoryError()

        # Take sizeof(unsigned long) bytes from the OS and reinterpret the
        # bytes buffer as one unsigned long to use as the initial seed.
        seed = os.urandom(sizeof(unsigned long))
        seedptr = <unsigned long *>(<char *> seed)
        self.set_seed(seedptr[0])

    def __dealloc__(self):
        if self.rng_state != NULL:
            free(self.rng_state)
            self.rng_state = NULL

    cdef void set_seed(self, unsigned long seed) nogil:
        # Re-initialise the generator state from ``seed``.
        rk_seed(seed, self.rng_state)

    cdef double rand(self) nogil:
        # Next double drawn by rk_double from this generator's state.
        return rk_double(self.rng_state)

# Module wide generator backing the free functions below.
cdef RNG _global_rng = RNG()

cdef void set_seed(unsigned long seed) nogil:
    # Seed the module wide generator.
    _global_rng.set_seed(seed)

cdef double rand() nogil:
    # Draw from the module wide generator.
    return _global_rng.rand()
Redistributions in binary form must reproduce the above copyright 20 | * notice, this list of conditions and the following disclaimer in the 21 | * documentation and/or other materials provided with the distribution. 22 | * 23 | * 3. The names of its contributors may not be used to endorse or promote 24 | * products derived from this software without specific prior written 25 | * permission. 26 | * 27 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 28 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 29 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 30 | * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 31 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 32 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 33 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 34 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 35 | * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 36 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 37 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 38 | * 39 | * Original algorithm for the implementation of rk_interval function from 40 | * Richard J. Wagner's implementation of the Mersenne Twister RNG, optimised by 41 | * Magnus Jonsson. 42 | * 43 | * Constants used in the rk_double implementation by Isaku Wada. 
44 | * 45 | * Permission is hereby granted, free of charge, to any person obtaining a 46 | * copy of this software and associated documentation files (the 47 | * "Software"), to deal in the Software without restriction, including 48 | * without limitation the rights to use, copy, modify, merge, publish, 49 | * distribute, sublicense, and/or sell copies of the Software, and to 50 | * permit persons to whom the Software is furnished to do so, subject to 51 | * the following conditions: 52 | * 53 | * The above copyright notice and this permission notice shall be included 54 | * in all copies or substantial portions of the Software. 55 | * 56 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 57 | * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 58 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 59 | * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 60 | * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 61 | * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 62 | * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 63 | */ 64 | 65 | /* static char const rcsid[] = 66 | "@(#) $Jeannot: randomkit.c,v 1.28 2005/07/21 22:14:09 js Exp $"; */ 67 | #include 68 | #include 69 | #include 70 | #include 71 | #include 72 | #include 73 | 74 | #ifdef _WIN32 75 | /* 76 | * Windows 77 | * XXX: we have to use this ugly defined(__GNUC__) because it is not easy to 78 | * detect the compiler used in distutils itself 79 | */ 80 | #if (defined(__GNUC__) && defined(NPY_NEEDS_MINGW_TIME_WORKAROUND)) 81 | 82 | /* 83 | * FIXME: ideally, we should set this to the real version of MSVCRT. 
/*
 * Initialise *state from an integer seed.
 *
 * Only the low 32 bits of seed are used.  Every word of state->key is
 * filled from the recurrence below, pos is set to RK_STATE_LEN and the
 * cached gaussian / binomial flags are cleared.
 */
void
rk_seed(unsigned long seed, rk_state *state)
{
    int pos;
    seed &= 0xffffffffUL;

    /* Knuth's PRNG as used in the Mersenne Twister reference implementation */
    for (pos = 0; pos < RK_STATE_LEN; pos++) {
        state->key[pos] = seed;
        seed = (1812433253UL * (seed ^ (seed >> 30)) + pos + 1) & 0xffffffffUL;
    }
    state->pos = RK_STATE_LEN;
    state->gauss = 0;
    state->has_gauss = 0;
    state->has_binomial = 0;
}

/* Thomas Wang 32 bits integer hash function */
/* Returns the mixed value of key; the argument is taken by value. */
unsigned long
rk_hash(unsigned long key)
{
    key += ~(key << 15);
    key ^= (key >> 10);
    key += (key << 3);
    key ^= (key >> 6);
    key += ~(key << 11);
    key ^= (key >> 16);
    return key;
}
/*
 * Seed *state without a caller supplied seed.
 *
 * First tries to fill the whole key array through rk_devfill (non-strong
 * source).  On success the top bit of key[0] is forced on, the bookkeeping
 * fields are reset, every word is masked to 32 bits and RK_NOERR is
 * returned.  Otherwise the state is seeded through rk_seed from hashed
 * process id / time / clock values and RK_ENODEV is returned.
 */
rk_error
rk_randomseed(rk_state *state)
{
#ifndef _WIN32
    struct timeval tv;
#else
    struct _timeb tv;
#endif
    int i;

    if (rk_devfill(state->key, sizeof(state->key), 0) == RK_NOERR) {
        /* ensures non-zero key */
        state->key[0] |= 0x80000000UL;
        state->pos = RK_STATE_LEN;
        state->gauss = 0;
        state->has_gauss = 0;
        state->has_binomial = 0;

        for (i = 0; i < 624; i++) {
            state->key[i] &= 0xffffffffUL;
        }
        return RK_NOERR;
    }

#ifndef _WIN32
    gettimeofday(&tv, NULL);
    rk_seed(rk_hash(getpid()) ^ rk_hash(tv.tv_sec) ^ rk_hash(tv.tv_usec)
            ^ rk_hash(clock()), state);
#else
    _FTIME(&tv);
    rk_seed(rk_hash(tv.time) ^ rk_hash(tv.millitm) ^ rk_hash(clock()), state);
#endif

    return RK_ENODEV;
}

/* Magic Mersenne Twister constants */
#define N 624
#define M 397
#define MATRIX_A 0x9908b0dfUL
#define UPPER_MASK 0x80000000UL
#define LOWER_MASK 0x7fffffffUL

/* Slightly optimised reference implementation of the Mersenne Twister */
/*
 * Returns the next tempered word.  When pos has reached RK_STATE_LEN the
 * whole key array is regenerated first and pos restarts at 0.
 */
unsigned long
rk_random(rk_state *state)
{
    unsigned long y;

    if (state->pos == RK_STATE_LEN) {
        int i;

        for (i = 0; i < N - M; i++) {
            y = (state->key[i] & UPPER_MASK) | (state->key[i+1] & LOWER_MASK);
            state->key[i] = state->key[i+M] ^ (y>>1) ^ (-(y & 1) & MATRIX_A);
        }
        for (; i < N - 1; i++) {
            y = (state->key[i] & UPPER_MASK) | (state->key[i+1] & LOWER_MASK);
            state->key[i] = state->key[i+(M-N)] ^ (y>>1) ^ (-(y & 1) & MATRIX_A);
        }
        /* last word wraps around to key[0] */
        y = (state->key[N - 1] & UPPER_MASK) | (state->key[0] & LOWER_MASK);
        state->key[N - 1] = state->key[M - 1] ^ (y >> 1) ^ (-(y & 1) & MATRIX_A);

        state->pos = 0;
    }
    y = state->key[state->pos++];

    /* Tempering */
    y ^= (y >> 11);
    y ^= (y << 7) & 0x9d2c5680UL;
    y ^= (y << 15) & 0xefc60000UL;
    y ^= (y >> 18);

    return y;
}

/* Returns rk_ulong(state) shifted right by one bit, as a long. */
long
rk_long(rk_state *state)
{
    return rk_ulong(state) >> 1;
}

/*
 * Returns one rk_random draw when unsigned long is at most 32 bits wide,
 * otherwise two draws combined into the high and low 32 bits.
 * NOTE(review): C leaves the evaluation order of the two operands of `|`
 * unspecified, so which draw lands in the high half may depend on the
 * compiler.
 */
unsigned long
rk_ulong(rk_state *state)
{
#if ULONG_MAX <= 0xffffffffUL
    return rk_random(state);
#else
    return (rk_random(state) << 32) | (rk_random(state));
#endif
}

/*
 * Returns a value in [0, max] by rejection sampling: draws are masked with
 * the smallest all-ones mask covering max and retried while they exceed
 * max.  Returns 0 immediately when max is 0.
 */
unsigned long
rk_interval(unsigned long max, rk_state *state)
{
    unsigned long mask = max, value;

    if (max == 0) {
        return 0;
    }
    /* Smallest bit mask >= max */
    mask |= mask >> 1;
    mask |= mask >> 2;
    mask |= mask >> 4;
    mask |= mask >> 8;
    mask |= mask >> 16;
#if ULONG_MAX > 0xffffffffUL
    mask |= mask >> 32;
#endif

    /* Search a random value in [0..mask] <= max */
#if ULONG_MAX > 0xffffffffUL
    if (max <= 0xffffffffUL) {
        while ((value = (rk_random(state) & mask)) > max);
    }
    else {
        while ((value = (rk_ulong(state) & mask)) > max);
    }
#else
    while ((value = (rk_ulong(state) & mask)) > max);
#endif
    return value;
}

/*
 * Builds a double from two draws: a (draw >> 5) is scaled by 2^26, b
 * (draw >> 6) is added and the sum is divided by 2^53.  With 32-bit draws
 * this is 27 + 26 = 53 random bits, i.e. a value in [0, 1).
 */
double
rk_double(rk_state *state)
{
    /* shifts : 67108864 = 0x4000000, 9007199254740992 = 0x20000000000000 */
    long a = rk_random(state) >> 5, b = rk_random(state) >> 6;
    return (a * 67108864.0 + b) / 9007199254740992.0;
}

/*
 * Fills size bytes at buffer from the generator.  Whole groups of four
 * bytes take one draw each, least significant byte first; a remainder of
 * one to three bytes takes the low bytes of one further draw.
 */
void
rk_fill(void *buffer, size_t size, rk_state *state)
{
    unsigned long r;
    unsigned char *buf = buffer;

    for (; size >= 4; size -= 4) {
        r = rk_random(state);
        *(buf++) = r & 0xFF;
        *(buf++) = (r >> 8) & 0xFF;
        *(buf++) = (r >> 16) & 0xFF;
        *(buf++) = (r >> 24) & 0xFF;
    }

    if (!size) {
        return;
    }
    r = rk_random(state);
    for (; size; r >>= 8, size --) {
        *(buf++) = (unsigned char)(r & 0xFF);
    }
}

/*
 * Fills size bytes at buffer from the operating system.
 *
 * Non-Windows: reads from RK_DEV_RANDOM when strong is non-zero, otherwise
 * from RK_DEV_URANDOM.  Windows: uses CryptGenRandom unless RK_NO_WINCRYPT
 * is defined (strong is not consulted on that path).
 * Returns RK_NOERR on success and RK_ENODEV in every other case.
 */
rk_error
rk_devfill(void *buffer, size_t size, int strong)
{
#ifndef _WIN32
    FILE *rfile;
    int done;

    if (strong) {
        rfile = fopen(RK_DEV_RANDOM, "rb");
    }
    else {
        rfile = fopen(RK_DEV_URANDOM, "rb");
    }
    if (rfile == NULL) {
        return RK_ENODEV;
    }
    /* one item of `size` bytes: done is 1 only if everything was read */
    done = fread(buffer, size, 1, rfile);
    fclose(rfile);
    if (done) {
        return RK_NOERR;
    }
#else

#ifndef RK_NO_WINCRYPT
    HCRYPTPROV hCryptProv;
    BOOL done;

    if (!CryptAcquireContext(&hCryptProv, NULL, NULL, PROV_RSA_FULL,
            CRYPT_VERIFYCONTEXT) || !hCryptProv) {
        return RK_ENODEV;
    }
    done = CryptGenRandom(hCryptProv, size, (unsigned char *)buffer);
    CryptReleaseContext(hCryptProv, 0);
    if (done) {
        return RK_NOERR;
    }
#endif

#endif
    return RK_ENODEV;
}

/*
 * Like rk_devfill, but falls back to rk_fill (the generator in *state)
 * when the device fill fails.  The rk_devfill result is returned either
 * way, so a non-zero return still means the buffer was filled.
 */
rk_error
rk_altfill(void *buffer, size_t size, int strong, rk_state *state)
{
    rk_error err;

    err = rk_devfill(buffer, size, strong);
    if (err) {
        rk_fill(buffer, size, state);
    }
    return err;
}

/*
 * Returns a gaussian deviate.  Deviates are produced in pairs: one is
 * returned and the other is kept in state->gauss (flagged by has_gauss)
 * and handed out by the next call.
 */
double
rk_gauss(rk_state *state)
{
    if (state->has_gauss) {
        const double tmp = state->gauss;
        state->gauss = 0;
        state->has_gauss = 0;
        return tmp;
    }
    else {
        double f, x1, x2, r2;

        /* draw (x1, x2) in [-1, 1)^2 until it lies strictly inside the
         * unit circle and is not the origin */
        do {
            x1 = 2.0*rk_double(state) - 1.0;
            x2 = 2.0*rk_double(state) - 1.0;
            r2 = x1*x1 + x2*x2;
        }
        while (r2 >= 1.0 || r2 == 0.0);

        /* Box-Muller transform */
        f = sqrt(-2.0*log(r2)/r2);
        /* Keep for next call */
        state->gauss = f*x1;
        state->has_gauss = 1;
        return f*x2;
    }
}
--------------------------------------------------------------------------------
/* Random kit 1.3 */

/*
 * Copyright (c) 2003-2005, Jean-Sebastien Roy (js@jeannot.org)
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be included
 * in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

/* @(#) $Jeannot: randomkit.h,v 1.24 2005/07/21 22:14:09 js Exp $ */

/*
 * Typical use:
 *
 * {
 *  rk_state state;
 *  unsigned long seed = 1, random_value;
 *
 *  rk_seed(seed, &state); // Initialize the RNG
 *  ...
 *  random_value = rk_random(&state); // Generate random values in [0..RK_MAX]
 * }
 *
 * Instead of rk_seed, you can use rk_randomseed which will get a random seed
 * from /dev/urandom (or the clock, if /dev/urandom is unavailable):
 *
 * {
 *  rk_state state;
 *  unsigned long random_value;
 *
 *  rk_randomseed(&state); // Initialize the RNG with a random seed
 *  ...
 *  random_value = rk_random(&state); // Generate random values in [0..RK_MAX]
 * }
 */

/*
 * Useful macro:
 *  RK_DEV_RANDOM: the device used for random seeding.
 *                 defaults to "/dev/urandom"
 */

/* size_t is used by the rk_fill / rk_devfill / rk_altfill prototypes below */
#include <stddef.h>

#ifndef _RANDOMKIT_
#define _RANDOMKIT_

#define RK_STATE_LEN 624

typedef struct rk_state_
{
    unsigned long key[RK_STATE_LEN];
    int pos;
    int has_gauss; /* !=0: gauss contains a gaussian deviate */
    double gauss;

    /* The rk_state structure has been extended to store the following
     * information for the binomial generator. If the input values of n or p
     * are different than nsave and psave, then the other parameters will be
     * recomputed. RTK 2005-09-02 */

    int has_binomial; /* !=0: following parameters initialized for
                              binomial */
    double psave;
    long nsave;
    double r;
    double q;
    double fm;
    long m;
    double p1;
    double xm;
    double xl;
    double xr;
    double c;
    double laml;
    double lamr;
    double p2;
    double p3;
    double p4;

}
rk_state;

typedef enum {
    RK_NOERR = 0, /* no error */
    RK_ENODEV = 1, /* no RK_DEV_RANDOM device */
    RK_ERR_MAX = 2
} rk_error;

/* error strings */
extern char *rk_strerror[RK_ERR_MAX];

/* Maximum generated random value */
#define RK_MAX 0xFFFFFFFFUL

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Initialize the RNG state using the given seed.
 */
extern void rk_seed(unsigned long seed, rk_state *state);

/*
 * Initialize the RNG state using a random seed.
 * Uses /dev/random or, when unavailable, the clock (see randomkit.c).
 * Returns RK_NOERR when no error occurs.
 * Returns RK_ENODEV when the use of RK_DEV_RANDOM failed (for example because
 * there is no such device). In this case, the RNG was initialized using the
 * clock.
 */
extern rk_error rk_randomseed(rk_state *state);

/*
 * Returns a random unsigned long between 0 and RK_MAX inclusive
 */
extern unsigned long rk_random(rk_state *state);

/*
 * Returns a random long between 0 and LONG_MAX inclusive
 */
extern long rk_long(rk_state *state);

/*
 * Returns a random unsigned long between 0 and ULONG_MAX inclusive
 */
extern unsigned long rk_ulong(rk_state *state);

/*
 * Returns a random unsigned long between 0 and max inclusive.
 */
extern unsigned long rk_interval(unsigned long max, rk_state *state);

/*
 * Returns a random double between 0.0 and 1.0, 1.0 excluded.
 */
extern double rk_double(rk_state *state);

/*
 * fill the buffer with size random bytes
 */
extern void rk_fill(void *buffer, size_t size, rk_state *state);

/*
 * fill the buffer with random bytes from the random device
 * Returns RK_ENODEV if the device is unavailable, or RK_NOERR if it is
 * On Unix, if strong is nonzero, RK_DEV_RANDOM is used. If not, RK_DEV_URANDOM
 * is used instead. This parameter has no effect on Windows.
 * Warning: on most unixes RK_DEV_RANDOM will wait for enough entropy to answer
 * which can take a very long time on quiet systems.
 */
extern rk_error rk_devfill(void *buffer, size_t size, int strong);

/*
 * fill the buffer using rk_devfill if the random device is available and using
 * rk_fill if it is not
 * parameters have the same meaning as rk_fill and rk_devfill
 * Returns RK_ENODEV if the device is unavailable, or RK_NOERR if it is
 */
extern rk_error rk_altfill(void *buffer, size_t size, int strong,
                           rk_state *state);

/*
 * return a random gaussian deviate with variance unity and zero mean.
 */
extern double rk_gauss(rk_state *state);

#ifdef __cplusplus
}
#endif

#endif /* _RANDOMKIT_ */

--------------------------------------------------------------------------------
/prme/myrandom/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flaviovdf/prme/fe518f6cc0648ce78671039f8af844f33bd6a9b4/prme/myrandom/tests/__init__.py
--------------------------------------------------------------------------------
/prme/prme.pyx:
--------------------------------------------------------------------------------
#-*- coding: utf8
# cython: boundscheck = False
# cython: cdivision = True
# cython: initializedcheck = False
# cython: nonecheck = False
# cython: wraparound = False
from __future__ import print_function, division

import numpy as np

from prme.myrandom.random cimport rand

cdef extern from 'math.h':
    inline double exp(double)
    inline double log(double)

#Logistic function: 1 / (1 + e^-z)
cdef inline double sigma(double z):
    return 1.0 / (1 + exp(-z))

#Weighted squared distance used to score destination d for the pair (h, s):
#   alpha * |XP_ok[d] - XP_hk[h]|^2 + (1 - alpha) * |XG_ok[d] - XG_ok[s]|^2
#The loop bound is XG_ok.shape[1]; this assumes XP_ok and XP_hk have the same
#number of columns (bounds checking is disabled for this module).
cdef inline double compute_dist(int h, int s, int d, double alpha, \
        double[:, ::1] XG_ok, double[:, ::1] XP_ok, double[:, ::1] XP_hk):

    cdef double dp_ho = 0.0
    cdef double ds_oo = 0.0
    cdef int k
    for k in range(XG_ok.shape[1]):
        dp_ho += alpha * ((XP_ok[d, k] - XP_hk[h, k]) ** 2)
        ds_oo += (1 - alpha) * ((XG_ok[d, k] - XG_ok[s, k]) ** 2)
    return dp_ho + ds_oo

#Adds the vector `update` to row `row` of X, in place.
cdef inline void update(int row, double[:, ::1] X, double[::1] update):
    cdef int k
    for k in range(X.shape[1]):
        X[row, k] += update[k]

#One stochastic gradient pass over every row of Trace, updating XG_ok, XP_ok
#and XP_hk in place.
#
#Each row i of Trace is read as (h, s, d_old) and dts[i] as its time gap. For
#each row a destination d_new that is not in seen[h, s] is sampled uniformly,
#z = dist(d_new) - dist(d_old) is computed with compute_dist, and every
#affected row is moved by
#   rate * ((1 - sigma(z)) * dz/dTheta - 2 * regularization * Theta).
#When dt > tau the distance only uses the XP matrices (alpha is forced to
#1.0) and the XG matrix is left untouched.
cdef void do_iter(double[::1] dts, int[:, ::1] Trace, double[:, ::1] XG_ok, \
        double[:, ::1] XP_ok, double[:, ::1] XP_hk, dict seen, \
        double rate, double regularization, double alpha, double tau):

    cdef int i, k, h, s, d_old, d_new
    cdef int num_objects = XP_ok.shape[0]
    cdef int num_factors = XP_ok.shape[1]
    cdef double dt, z, weight, scale_p, scale_g
    cdef double l2 = 2 * regularization

    cdef double[::1] update_XP_h = np.zeros(num_factors, dtype='d')
    cdef double[::1] update_XP_dnew = np.zeros(num_factors, dtype='d')
    cdef double[::1] update_XP_dold = np.zeros(num_factors, dtype='d')

    cdef double[::1] update_XG_s = np.zeros(num_factors, dtype='d')
    cdef double[::1] update_XG_dnew = np.zeros(num_factors, dtype='d')
    cdef double[::1] update_XG_dold = np.zeros(num_factors, dtype='d')

    cdef set seen_hs
    cdef double alpha_to_use

    for i in range(Trace.shape[0]):
        dt = dts[i]
        h = Trace[i, 0]
        s = Trace[i, 1]
        d_old = Trace[i, 2]

        if dt > tau:
            alpha_to_use = 1.0
        else:
            alpha_to_use = alpha

        seen_hs = seen[h, s]
        #With every object already seen there is no negative example to draw
        #and the rejection loop below would never terminate.
        if len(seen_hs) >= num_objects:
            continue

        #rand() is expected to return a double in [0, 1) -- see
        #prme/myrandom/random.pyx; the cast truncates it to a row index.
        d_new = <int> (num_objects * rand())
        while d_new in seen_hs:
            d_new = <int> (num_objects * rand())

        z = compute_dist(h, s, d_new, alpha_to_use, \
                XG_ok, XP_ok, XP_hk)
        z -= compute_dist(h, s, d_old, alpha_to_use, \
                XG_ok, XP_ok, XP_hk)
        weight = 1 - sigma(z)
        scale_p = 2 * alpha_to_use
        scale_g = 2 * (1 - alpha_to_use)

        #All three steps are computed from the current values before any of
        #them is applied.
        for k in range(num_factors):
            #1. XP_h
            update_XP_h[k] = rate * (weight * \
                    ((XP_ok[d_old, k] - XP_ok[d_new, k]) * scale_p) - \
                    (l2 * XP_hk[h, k]))

            #2. XP_o(d_new)
            update_XP_dnew[k] = rate * (weight * \
                    ((XP_ok[d_new, k] - XP_hk[h, k]) * scale_p) - \
                    (l2 * XP_ok[d_new, k]))

            #3. XP_o(d_old)
            update_XP_dold[k] = rate * (weight * \
                    ((-(XP_ok[d_old, k] - XP_hk[h, k])) * scale_p) - \
                    (l2 * XP_ok[d_old, k]))

        update(h, XP_hk, update_XP_h)
        update(d_new, XP_ok, update_XP_dnew)
        update(d_old, XP_ok, update_XP_dold)

        if dt <= tau:
            for k in range(num_factors):
                #4. XG_o(s)
                update_XG_s[k] = rate * (weight * \
                        ((XG_ok[d_old, k] - XG_ok[d_new, k]) * scale_g) - \
                        (l2 * XG_ok[s, k]))

                #5. XG_o(d_new)
                update_XG_dnew[k] = rate * (weight * \
                        ((XG_ok[d_new, k] - XG_ok[s, k]) * scale_g) - \
                        (l2 * XG_ok[d_new, k]))

                #6. XG_o(d_old)
                update_XG_dold[k] = rate * (weight * \
                        ((-(XG_ok[d_old, k] - XG_ok[s, k])) * scale_g) - \
                        (l2 * XG_ok[d_old, k]))

            update(s, XG_ok, update_XG_s)
            update(d_new, XG_ok, update_XG_dnew)
            update(d_old, XG_ok, update_XG_dold)

def compute_cost(double[::1] dts, int[:, ::1] Trace, double[:, ::1] XG_ok, \
        double[:, ::1] XP_ok, double[:, ::1] XP_hk, dict seen, double rate, \
        double regularization, double alpha, double tau, int num_examples=-1, \
        int num_candidates=-1):
    '''
    Sums log(sigma(dist(d_new) - dist(d_old))) over examples and candidates.

    Each row of Trace is read as (h, s, d_old) and every candidate object
    plays the role of d_new. As in do_iter, alpha is replaced by 1.0 for
    examples whose time gap is above tau.

    Parameters
    ----------
    dts, Trace : time gaps and (h, s, d_old) rows, one per example.
    XG_ok, XP_ok, XP_hk : the embedding matrices passed to compute_dist.
    seen, rate, regularization : accepted for signature compatibility with
        sgd / do_iter; not used by this function.
    alpha, tau : distance weight and time-gap threshold.
    num_examples : when > 0, only this many randomly chosen rows are used.
    num_candidates : when > 0, only this many randomly chosen objects are
        used as d_new.

    Returns
    -------
    The accumulated cost as a float.
    '''
    cdef int i, j, h, s, d_old, d_new
    cdef double dt, z, dist_old
    cdef bint far_apart

    cdef double alpha_to_use
    cdef double cost = 0.0
    cdef double curr_cost = 0.0
    cdef dict precomputed = {}
    cdef tuple key

    #Shuffle the numpy arrays and only then take the typed views
    idx_array = np.arange(Trace.shape[0], dtype='i4')
    if num_examples > 0:
        np.random.shuffle(idx_array)
        idx_array = idx_array[:num_examples]
    cdef int[::1] idx = idx_array

    candidates_array = np.arange(XG_ok.shape[0], dtype='i4')
    if num_candidates > 0:
        np.random.shuffle(candidates_array)
        candidates_array = candidates_array[:num_candidates]
    cdef int[::1] candidates = candidates_array

    for i in range(idx.shape[0]):
        dt = dts[idx[i]]
        h = Trace[idx[i], 0]
        s = Trace[idx[i], 1]
        d_old = Trace[idx[i], 2]

        #Same threshold test as do_iter
        far_apart = dt > tau

        #The cost depends on d_old as well, so it has to be part of the key
        key = (h, s, d_old, far_apart)
        if key in precomputed:
            cost += precomputed[key]
            continue

        if far_apart:
            alpha_to_use = 1.0
        else:
            alpha_to_use = alpha

        #The d_old distance does not change across candidates
        dist_old = compute_dist(h, s, d_old, alpha_to_use, \
                XG_ok, XP_ok, XP_hk)

        curr_cost = 0.0
        for j in range(candidates.shape[0]):
            d_new = candidates[j]
            z = compute_dist(h, s, d_new, alpha_to_use, \
                    XG_ok, XP_ok, XP_hk)
            z -= dist_old
            curr_cost += log(sigma(z))

        precomputed[key] = curr_cost
        cost += curr_cost
    return cost

def sgd(double[::1] dts, int[:, ::1] Trace, double[:, ::1] XG_ok, \
        double[:, ::1] XP_ok, double[:, ::1] XP_hk, dict seen, \
        double rate, double regularization, double alpha, double tau, \
        double[::1] dts_val, int[:, ::1] Trace_val, int num_iter=1000):
    '''
    Runs `num_iter` passes of do_iter over the training trace and reports
    the final costs.

    The embedding matrices XG_ok, XP_ok and XP_hk are updated in place.

    Parameters
    ----------
    dts, Trace : training time gaps and (h, s, d_old) rows.
    XG_ok, XP_ok, XP_hk : embedding matrices to be learned.
    seen : dict mapping (h, s) to the set of objects that must not be drawn
        as negative examples.
    rate, regularization, alpha, tau : learning rate, L2 weight, distance
        weight and time-gap threshold.
    dts_val, Trace_val : validation time gaps and rows.
    num_iter : number of passes over the training trace (default 1000).

    Returns
    -------
    (cost_train, cost_val), each estimated by compute_cost on at most 1000
    examples and 1000 candidates.
    '''
    cdef int i
    for i in range(num_iter):
        do_iter(dts, Trace, XG_ok, XP_ok, XP_hk, seen, rate, \
                regularization, alpha, tau)

    #The training cost must come from the training data
    cost_train = compute_cost(dts, Trace, XG_ok, XP_ok, XP_hk, \
            seen, rate, regularization, alpha, tau, 1000, 1000)
    cost_val = compute_cost(dts_val, Trace_val, XG_ok, XP_ok, XP_hk, \
            seen, rate, regularization, alpha, tau, 1000, 1000)
    return cost_train, cost_val

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8
'''Setup script'''
from __future__ import division, print_function

import glob
import numpy
import os
import sys

try:
    from distutils.core import setup
    from distutils.extension import Extension
except ImportError:
    #distutils is no longer part of the standard library on recent Pythons
    from setuptools import setup
    from setuptools import Extension
from Cython.Distutils import build_ext

SOURCE = '.'
os.chdir(SOURCE)

#The original note marks these as meant for mac envs (for openmp), so they
#are only applied there, and never override a compiler chosen by the caller.
if sys.platform == 'darwin':
    os.environ.setdefault("CC", "gcc-5")
    os.environ.setdefault("CXX", "gcc-5")

if sys.version_info[:2] < (2, 7):
    print('Requires Python version 2.7 or later (%d.%d detected).' %
          sys.version_info[:2])
    sys.exit(-1)

def get_packages():
    '''Returns 'prme' plus every package found in its sub directories'''

    root = 'prme'
    packages = [root]

    for dir_path, dir_names, _ in os.walk(root):
        #Prune in place so that os.walk does not descend into hidden or
        #byte-code cache folders
        dir_names[:] = sorted(name for name in dir_names
                              if not name.startswith('.')
                              and name != '__pycache__')
        for name in dir_names:
            sub_dir = os.path.join(dir_path, name)
            packages.append(sub_dir.replace(os.sep, '.'))

    return packages

def get_extensions():
    '''Builds one Extension for each .pyx file found in the packages'''

    extensions = []
    include_dirs = ['prme/myrandom/', numpy.get_include()]

    for pkg in get_packages():
        pkg_folder = pkg.replace('.', '/')
        for pyx in sorted(glob.glob(os.path.join(pkg_folder, '*.pyx'))):
            #Only the extension is swapped / dropped; folder or file names
            #that happen to contain 'pyx' are left alone
            stem = os.path.splitext(pyx)[0]
            pxd = stem + '.pxd'
            module = stem.replace(os.sep, '.').replace('/', '.')

            if os.path.exists(pxd):
                ext_files = [pyx, pxd]
            else:
                ext_files = [pyx]

            #The random module wraps the bundled randomkit C source
            if module == 'prme.myrandom.random':
                ext_files.append(os.path.join(pkg_folder, 'randomkit.c'))

            extension = Extension(module, ext_files,
                    include_dirs=include_dirs,
                    extra_compile_args=['-msse', '-msse2', '-mfpmath=sse',
                        '-fopenmp', '-Wno-unused-function'], #cython warnings supress
                    extra_link_args=['-fopenmp'])

            extensions.append(extension)

    return extensions

if __name__ == "__main__":
    packages = get_packages()
    extensions = get_extensions()
    setup(cmdclass = {'build_ext': build_ext},
          name = 'prme',
          packages = packages,
          ext_modules = extensions)

--------------------------------------------------------------------------------