"""
smc
~~~
Provides an implementation of an inference algorithm for the Hierarchical
Dirichlet-Hawkes process, based on sequential Monte Carlo (SMC) with particles.
:copyright: 2016 Charalampos Mavroforakis, <cmav@bu.edu> and contributors.
:license: ISC
"""
from __future__ import division, print_function
import tempfile
from collections import Counter, defaultdict
from copy import copy
from time import time
from numpy import log as ln, array, exp
from numpy.random import RandomState
from scipy.special import gammaln
try:
    from scipy.special import logsumexp
except ImportError:
    # logsumexp lived in scipy.misc before SciPy 0.19
    from scipy.misc import logsumexp
from utils import copy_dict, weighted_choice
from hdhp import HDHProcess
maxint = 10000000
def memoize(f):
class memodict(dict):
__slots__ = ()
def __missing__(self, key):
self[key] = ret = f(key)
return ret
return memodict().__getitem__
@memoize
def _gammaln(x):
return gammaln(x)
@memoize
def _ln(x):
return ln(x)
class InferenceParameters:
"""This class collects all the parameters required for model inference.
Its main use is to help keep the code clean.
"""
def __init__(self, alpha_0, mu_0, omega, beta, theta_0,
threads, num_particles,
particle_weight_threshold, resample_every,
update_kernels, mu_rate,
keep_alpha_history, progress_file, seed,
vocabulary, users):
self.alpha_0 = alpha_0
self.mu_0 = mu_0
self.omega = omega
self.beta = beta
self.theta_0 = theta_0
self.threads = threads
self.num_particles = num_particles
self.particle_weight_threshold = particle_weight_threshold
self.resample_every = resample_every
self.update_kernels = update_kernels
self.mu_rate = mu_rate
self.keep_alpha_history = keep_alpha_history
self.progress_file = progress_file
self.seed = seed
self.vocabulary = vocabulary
self.users = users
class Particle(object):
def __init__(self, vocabulary_length, num_users, time_kernels=None,
alpha_0=(2, 2), mu_0 = 1, theta_0=None,
seed=None, logweight=0, update_kernels=False, uid=0,
omega=1, beta=1, keep_alpha_history=False, mu_rate=0.6):
self.vocabulary_length = vocabulary_length
self.vocabulary = None
self.seed = seed
self.prng = RandomState(self.seed)
self.first_observed_time = {}
self.first_observed_user_time = {}
self.per_topic_word_counts = {}
self.per_topic_word_count_total = {}
self.time_kernels = {}
self.alpha_0 = alpha_0
self.mu_0 = mu_0
self.theta_0 = array(theta_0)
self._lntheta = _ln(theta_0[0])
self.logweight = logweight
self.update_kernels = update_kernels
self.uid = uid
self.num_events = 0
self.topic_previous_event = None
# The following are for speed optimization purposes
        # A structure to save the total intensity of a topic
# up to the most recent event t_i of that topic.
# It will be used to measure the total intensity at
# any time after t_i
self._Qn = None
self.omega = omega
self.beta = beta
self.num_users = num_users
self.keep_alpha_history = keep_alpha_history
self.user_table_cache = {}
self.dish_on_table_per_user = {}
self.dish_on_table_todelete = {}
self.dish_counters = {}
self._max_dish = -1
self.total_tables = 0
self.table_history_with_user = []
self.time_previous_user_event = []
self.total_tables_per_user = []
self.dish_cache = {}
self.time_kernel_prior = {}
self.time_history_per_user = {}
self.doc_history_per_user = {}
self.question_history_per_user = {}
self.table_history_per_user = {}
self.alpha_history = {}
self.alpha_distribution_history = {}
self.mu_rate = mu_rate
self.mu_per_user = {}
self.time_elapsed = 0
self.active_tables_per_user = {}
def reseed(self, seed=None, uid=None):
self.seed = seed
self.prng = RandomState(self.seed)
if uid is None:
self.uid = self.prng.randint(maxint)
else:
self.uid = uid
def reset_weight(self):
self.logweight = 0
def copy(self):
new_p = Particle(num_users=self.num_users,
vocabulary_length=self.vocabulary_length,
seed=self.seed, mu_rate=self.mu_rate,
theta_0=self.theta_0,
omega=self.omega,
beta=self.beta,
mu_0=self.mu_0,
uid=self.uid,
logweight=self.logweight,
update_kernels=self.update_kernels,
keep_alpha_history=self.keep_alpha_history)
new_p.alpha_0 = copy(self.alpha_0)
new_p.num_events = self.num_events
new_p.topic_previous_event = self.topic_previous_event
new_p.total_tables = self.total_tables
new_p._max_dish = self._max_dish
new_p.time_previous_user_event = copy(self.time_previous_user_event)
new_p.total_tables_per_user = copy(self.total_tables_per_user)
new_p.first_observed_time = copy(self.first_observed_time)
new_p.first_observed_user_time = copy(self.first_observed_user_time)
new_p.table_history_with_user = copy(self.table_history_with_user)
new_p.dish_cache = copy_dict(self.dish_cache)
new_p.dish_counters = copy_dict(self.dish_counters)
        new_p.dish_on_table_per_user = {}
new_p.dish_on_table_todelete = {}
for u in self.dish_on_table_per_user:
new_p.dish_on_table_per_user[u] = {}
new_p.dish_on_table_todelete[u] = {}
self.dish_on_table_todelete[u] = {}
for t in self.dish_on_table_per_user[u]:
if t in self.active_tables_per_user[u]:
new_p.dish_on_table_per_user[u][t] = \
self.dish_on_table_per_user[u][t]
else:
dish = self.dish_on_table_per_user[u][t]
self.dish_on_table_todelete[u][t] = dish
new_p.dish_on_table_todelete[u][t] = dish
if t in self.user_table_cache[u]:
del self.user_table_cache[u][t]
new_p.per_topic_word_counts = copy_dict(self.per_topic_word_counts)
new_p.per_topic_word_count_total = copy_dict(self.per_topic_word_count_total)
new_p.time_kernels = copy_dict(self.time_kernels)
new_p.time_kernel_prior = copy_dict(self.time_kernel_prior)
new_p.user_table_cache = copy_dict(self.user_table_cache)
if self.keep_alpha_history:
new_p.alpha_history = copy_dict(self.alpha_history)
new_p.alpha_distribution_history = \
copy_dict(self.alpha_distribution_history)
new_p.mu_per_user = copy_dict(self.mu_per_user)
new_p.active_tables_per_user = copy_dict(self.active_tables_per_user)
return new_p
def update(self, event):
"""Parses an event and updates the particle
Parameters
----------
        event : tuple
            The event is a 4-tuple of the form (time, content, user, metadata),
            e.g. ``(12.3, 'some text', 0, 'question_id')``.

        Returns
        -------
        b_n : int
            The index of the table (task) that the event was assigned to.
        z_n : int
            The index of the dish (pattern) that the event was assigned to.
        """
# u_n : user of the n-th event
# t_n : time of the n-th event
# d_n : text of the n-th event
# q_n : any metadata for the n-th event, e.g. the question id
t_n, d_n, u_n, q_n = event
d_n = d_n.split()
if self.num_events == 0:
self.time_previous_user_event = [0 for i in range(self.num_users)]
self.total_tables_per_user = [0 for i in range(self.num_users)]
self.mu_per_user = {i: self.sample_mu()
for i in range(self.num_users)}
self.active_tables_per_user = {i: set()
for i in range(self.num_users)}
if self.num_events >= 1 and u_n in self.time_previous_user_event and \
self.time_previous_user_event[u_n] > 0:
log_likelihood_tn = self.time_event_log_likelihood(t_n, u_n)
else:
log_likelihood_tn = 0
tables_before = self.total_tables_per_user[u_n]
b_n, z_n, opened_table, log_likelihood_dn = \
self.sample_table(t_n, d_n, u_n)
if self.total_tables_per_user[u_n] > tables_before and tables_before > 0:
# opened a new table
old_mu = self.mu_per_user[u_n]
tables_num = tables_before + 1
user_alive_time = t_n - self.first_observed_user_time[u_n]
new_mu = (self.mu_rate * old_mu +
(1 - self.mu_rate) * tables_num / user_alive_time)
self.mu_per_user[u_n] = new_mu
if z_n not in self.time_kernels:
self.time_kernels[z_n] = self.sample_time_kernel()
self.first_observed_time[z_n] = t_n
self.dish_cache[z_n] = (t_n, 0, 1, 1, 1)
self._max_dish = z_n
else:
if self.update_kernels:
self.update_time_kernel(t_n, z_n)
if self.update_kernels and self.keep_alpha_history:
if z_n not in self.alpha_history:
self.alpha_history[z_n] = []
self.alpha_distribution_history[z_n] = []
self.alpha_history[z_n].append(self.time_kernels[z_n])
self.alpha_distribution_history[z_n].append(self.time_kernel_prior[z_n])
if self.num_events >= 1:
self.logweight += log_likelihood_tn
self.logweight += self._Qn
self.num_events += 1
self._update_word_counters(d_n, z_n)
self.time_previous_user_event[u_n] = t_n
self.topic_previous_event = z_n
self.user_previous_event = u_n
self.table_previous_event = b_n
self.active_tables_per_user[u_n].add(b_n)
if z_n not in self.dish_counters:
self.dish_counters[z_n] = 1
elif opened_table:
self.dish_counters[z_n] += 1
if u_n not in self.first_observed_user_time:
self.first_observed_user_time[u_n] = t_n
return b_n, z_n
def sample_table(self, t_n, d_n, u_n):
"""Samples table b_n and topic z_n together for the event n.
Parameters
----------
t_n : float
The time of the event.
d_n : list
The document for the event.
u_n : int
The user id.
        Returns
        -------
        table : int
            The index of the sampled table.
        dish : int
            The index of the sampled dish (pattern).
        opened_table : bool
            True if the event opened a new table.
        dish_log_likelihood : float
            The log likelihood of the document under the sampled dish.
        """
if self.total_tables_per_user[u_n] == 0:
# This is going to be the user's first table
self.dish_on_table_per_user[u_n] = {}
self.user_table_cache[u_n] = {}
self.time_previous_user_event[u_n] = 0
tables = range(self.total_tables_per_user[u_n])
num_dishes = len(self.dish_counters)
intensities = []
dn_word_counts = Counter(d_n)
count_dn = len(d_n)
# Precompute the doc_log_likelihood for each of the dishes
dish_log_likelihood = []
for dish in self.dish_counters:
dll = self.document_log_likelihood(dn_word_counts, count_dn,
dish)
dish_log_likelihood.append(dll)
table_intensity_threshold = 1e-8 # below this, the table is inactive
# Provide one option for each of the already open tables
mu = self.mu_per_user[u_n]
total_table_int = mu
dish_log_likelihood_array = []
for table in tables:
if table in self.active_tables_per_user[u_n]:
dish = self.dish_on_table_per_user[u_n][table]
alpha = self.time_kernels[dish]
t_last, sum_kernels = self.user_table_cache[u_n][table]
update_value = self.kernel(t_n, t_last)
table_intensity = alpha * sum_kernels * update_value
table_intensity += alpha * update_value
total_table_int += table_intensity
if table_intensity < table_intensity_threshold:
self.active_tables_per_user[u_n].remove(table)
dish_log_likelihood_array.append(dish_log_likelihood[dish])
intensities.append(table_intensity)
else:
dish_log_likelihood_array.append(0)
intensities.append(0)
log_intensities = [ln(inten_i / total_table_int) + dish_log_likelihood_array[i]
if inten_i > 0 else -float('inf')
for i, inten_i in enumerate(intensities)]
# Provide one option for new table with already existing dish
for dish in self.dish_counters:
dish_intensity = (mu / total_table_int) *\
self.dish_counters[dish] / (self.total_tables + self.beta)
dish_intensity = ln(dish_intensity)
dish_intensity += dish_log_likelihood[dish]
log_intensities.append(dish_intensity)
# Provide a last option for new table with new dish
new_dish_intensity = mu * self.beta /\
(total_table_int * (self.total_tables + self.beta))
new_dish_intensity = ln(new_dish_intensity)
new_dish_log_likelihood = self.document_log_likelihood(dn_word_counts,
count_dn,
num_dishes)
new_dish_intensity += new_dish_log_likelihood
log_intensities.append(new_dish_intensity)
normalizing_log_intensity = logsumexp(log_intensities)
intensities = [exp(log_intensity - normalizing_log_intensity)
for log_intensity in log_intensities]
self._Qn = normalizing_log_intensity
k = weighted_choice(intensities, self.prng)
opened_table = False
if k in tables:
# Assign to one of the already existing tables
table = k
dish = self.dish_on_table_per_user[u_n][table]
# update cache for that table
t_last, sum_kernels = self.user_table_cache[u_n][table]
update_value = self.kernel(t_n, t_last)
sum_kernels += 1
sum_kernels *= update_value
self.user_table_cache[u_n][table] = (t_n, sum_kernels)
else:
k = k - len(tables)
table = len(tables)
self.total_tables += 1
self.total_tables_per_user[u_n] += 1
dish = k
# Since this is a new table, initialize the cache accordingly
self.user_table_cache[u_n][table] = (t_n, 0)
self.dish_on_table_per_user[u_n][table] = dish
opened_table = True
if dish not in self.time_kernel_prior:
self.time_kernel_prior[dish] = self.alpha_0
dll = self.document_log_likelihood(dn_word_counts, count_dn,
dish)
dish_log_likelihood.append(dll)
self.table_history_with_user.append((u_n, table))
self.time_previous_user_event[u_n] = t_n
return table, dish, opened_table, dish_log_likelihood[dish]
def kernel(self, t_i, t_j):
"""Returns the kernel function for t_i and t_j.
Parameters
----------
t_i : float
The later timestamp
t_j : float
The earlier timestamp
Returns
-------
float
"""
return exp(-self.omega * (t_i - t_j))
def update_time_kernel(self, t_n, z_n):
"""Updates the parameter of the time kernel of the chosen pattern
"""
v_1, v_2 = self.time_kernel_prior[z_n]
t_last, sum_kernels, event_count, intensity, prod = self.dish_cache[z_n]
update_value = self.kernel(t_n, t_last)
sum_kernels += 1
sum_kernels *= update_value
prod = sum_kernels
sum_integrals = event_count - sum_kernels
sum_integrals /= self.omega
self.time_kernel_prior[z_n] = self.alpha_0[0] + event_count - self.dish_counters[z_n], \
self.alpha_0[1] + (sum_integrals)
prior = self.time_kernel_prior[z_n]
self.time_kernels[z_n] = self.sample_time_kernel(prior)
self.dish_cache[z_n] = t_n, sum_kernels, event_count + 1, intensity, prod
def sample_time_kernel(self, alpha_0=None):
if alpha_0 is None:
alpha_0 = self.alpha_0
return self.prng.gamma(alpha_0[0], 1. / alpha_0[1])
def sample_mu(self):
"""Samples a value from the prior of the base intensity mu.
Returns
-------
mu_u : float
The base intensity of a user, sampled from the prior.
"""
return self.prng.gamma(self.mu_0[0], self.mu_0[1])
def document_log_likelihood(self, dn_word_counts, count_dn, z_n):
"""Returns the log likelihood of document d_n to belong to cluster z_n.
Note: Assumes a Gamma prior on the word distribution.
"""
theta = self.theta_0[0]
V = self.vocabulary_length
if z_n not in self.per_topic_word_count_total:
count_zn_no_dn = 0
else:
count_zn_no_dn = self.per_topic_word_count_total[z_n]
# TODO: The code below works only for uniform theta_0. We should
# put the theta that corresponds to `word`. Here we assume that
# all the elements of theta_0 are equal
gamma_numerator = _gammaln(count_zn_no_dn + V * theta)
gamma_denominator = _gammaln(count_zn_no_dn + count_dn + V * theta)
is_old_topic = z_n <= self._max_dish
unique_words = len(dn_word_counts) == count_dn
topic_words = None
if is_old_topic:
topic_words = self.per_topic_word_counts[z_n]
if unique_words:
rest = [_ln(topic_words[word] + theta)
if is_old_topic and word in topic_words
else self._lntheta
for word in dn_word_counts]
else:
rest = [_gammaln(topic_words[word] + dn_word_counts[word] + theta) - _gammaln(topic_words[word] + theta)
if is_old_topic and word in topic_words
else _gammaln(dn_word_counts[word] + theta) - _gammaln(theta)
for word in dn_word_counts]
return gamma_numerator - gamma_denominator + sum(rest)
def document_history_log_likelihood(self):
"""Computes the log likelihood for the whole history of documents,
using the inferred parameters.
"""
doc_log_likelihood = 0
for user in self.doc_history_per_user:
for doc, table in zip(self.doc_history_per_user[user],
self.table_history_per_user[user]):
dish = self.dish_on_table_per_user[user][table]
doc_word_counts = Counter(doc.split())
count_doc = len(doc.split())
doc_log_likelihood += self.document_log_likelihood(doc_word_counts,
count_doc,
dish)
return doc_log_likelihood
def time_event_log_likelihood(self, t_n, u_n):
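        """Returns the log likelihood that user u_n generates an event at t_n.

        This is the usual Hawkes-process term ln(intensity) - integral: the
        intensity sums the user's base rate and the excitation of each of the
        user's tables at t_n, and the integral term accounts for the intensity
        accumulated since the user's previous event.
        """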
mu = self.mu_per_user[u_n]
integral = (t_n - self.time_previous_user_event[u_n]) * mu
intensity = mu
for table in self.user_table_cache[u_n]:
t_last, sum_timedeltas = self.user_table_cache[u_n][table]
update_value = self.kernel(t_n, t_last)
topic_sum = (sum_timedeltas + 1) - \
(sum_timedeltas + 1) * update_value
dish = self.dish_on_table_per_user[u_n][table]
topic_sum *= self.time_kernels[dish]
integral += topic_sum
intensity += (sum_timedeltas + 1) \
* self.time_kernels[dish] * update_value
return ln(intensity) - integral
def _update_word_counters(self, d_n, z_n):
if z_n not in self.per_topic_word_counts:
self.per_topic_word_counts[z_n] = {}
if z_n not in self.per_topic_word_count_total:
self.per_topic_word_count_total[z_n] = 0
for word in d_n:
if word not in self.per_topic_word_counts[z_n]:
self.per_topic_word_counts[z_n][word] = 0
self.per_topic_word_counts[z_n][word] += 1
self.per_topic_word_count_total[z_n] += 1
return
def to_process(self):
"""Exports the particle as a HDHProcess object.
Use the exported object to plot the user timelines.
Returns
-------
HDHProcess
"""
process = HDHProcess(num_patterns=len(self.time_kernels),
mu_0=self.mu_0,
alpha_0=self.alpha_0,
vocabulary=self.vocabulary)
process.mu_per_user = self.mu_per_user
process.table_history_per_user = self.table_history_per_user
process.time_history_per_user = self.time_history_per_user
process.dish_on_table_per_user = self.dish_on_table_per_user
process.time_kernels = self.time_kernels
process.first_observed_time = self.first_observed_time
process.omega = self.omega
process.num_users = self.num_users
process.document_history_per_user = self.doc_history_per_user
return process
def get_intensity(self, t_n, u_n, z_n):
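        """Returns the intensity of pattern z_n for user u_n at time t_n.

        Combines the user's base intensity, weighted by the empirical
        popularity of the pattern, with the excitation of the user's tables
        that serve dish z_n.
        """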
pi_z = self.dish_counters[z_n] / self.total_tables
mu = self.mu_per_user[u_n]
alpha = self.time_kernels[z_n]
intensity = pi_z * mu
for table in self.user_table_cache[u_n]:
dish = self.dish_on_table_per_user[u_n][table]
if dish == z_n:
t_last, sum_timedeltas = self.user_table_cache[u_n][table]
update_value = self.kernel(t_n, t_last)
table_intensity = alpha * sum_timedeltas * update_value
table_intensity += alpha * update_value
intensity += table_intensity
return intensity
def _extract_words_users(history):
"""Returns the set of words and the set of users in the dataset
"""
vocabulary = set()
users = set()
for t, doc, u, q in history:
for word in doc.split():
vocabulary.add(word)
users.add(u)
return vocabulary, users
def resample_indices(weights, prng):
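    """Same resampling-wheel scheme as `pick_new_particles` below, but returns
    the picked indices in sorted order.
    """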
N = len(weights)
index = prng.randint(N)
beta = 0.0
mw = max(weights)
picked_indices = []
for i in range(N):
beta += prng.rand() * 2.0 * mw
while beta > weights[index]:
beta -= weights[index]
index = (index + 1) % N
picked_indices.append(index)
return sorted(picked_indices)
def pick_new_particles(old_particles, weights, prng):
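    """Resamples particle indices proportionally to the given weights.

    Uses the resampling-wheel scheme: a pointer advances by a random step of
    at most twice the maximum weight and the index it lands on is picked; this
    is repeated once per particle, so indices may be picked multiple times.
    """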
N = len(old_particles)
index = prng.randint(N)
beta = 0.0
mw = max(weights)
picked_indices = []
for i in range(N):
beta += prng.rand() * 2.0 * mw
while beta > weights[index]:
beta -= weights[index]
index = (index + 1) % N
picked_indices.append(index)
return picked_indices
def _infer_single_thread(history, params):
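    """Runs the SMC inference over the event history on a single thread.

    Parameters
    ----------
    history : list
        The list of (time, content, user, metadata) event tuples.
    params : InferenceParameters
        The bundle of hyperparameters and settings used during inference.

    Returns
    -------
    Particle
        The final particle, sampled according to the particle weights.
    list
        The squared norms of the normalized particle weights, collected at
        each resampling check.
    """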
prng = RandomState(seed=params.seed)
time_history_per_user = defaultdict(list)
doc_history_per_user = defaultdict(list)
question_history_per_user = defaultdict(list)
table_history_with_user = []
dish_on_table_per_user = []
    # Bookkeeping for the resampling diagnostics
count_resamples = 0
square_norms = []
with open(params.progress_file, 'a') as out:
        out.write('Starting %d particles on %d thread(s).\n'
                  % (params.num_particles, params.threads))
start_tic = time()
    # Accuracy threshold used when normalizing the particle weights
    epsilon = 1e-10
    # Initialize the particles
particles = [Particle(theta_0=params.theta_0, alpha_0=params.alpha_0,
mu_0=params.mu_0,
uid=prng.randint(maxint), seed=prng.randint(maxint),
vocabulary_length=len(params.vocabulary),
update_kernels=params.update_kernels,
omega=params.omega, beta=params.beta,
num_users=len(params.users),
keep_alpha_history=params.keep_alpha_history,
mu_rate=params.mu_rate)
for i in range(params.num_particles)]
inferred_tables = {} # for each particle, save the topic history
for p in particles:
inferred_tables[p.uid] = []
# Fit each particle to the history
square_norms = []
table_history_with_user = []
dish_on_table_per_user = []
for i, h_i in enumerate(history):
max_logweight = None
weights = []
total = 0
t_i, d_i, u_i, q_i = h_i
if u_i not in time_history_per_user:
time_history_per_user[u_i] = []
doc_history_per_user[u_i] = []
question_history_per_user[u_i] = []
time_history_per_user[u_i].append(t_i)
doc_history_per_user[u_i].append(d_i)
question_history_per_user[u_i].append(q_i)
for p_i in particles:
# Fit each particle to the next event
b_i, z_i = p_i.update(h_i)
inferred_tables[p_i.uid].append((b_i, z_i))
if i > 0 and i % params.resample_every == 0:
            # Normalize the particle weights and check whether to resample
for p_i in particles:
if max_logweight is None or max_logweight < p_i.logweight:
max_logweight = p_i.logweight
for p_i in particles:
# Normalize the weights of the particles
if p_i.logweight - max_logweight >= \
ln(epsilon) - ln(params.num_particles):
weights.append(exp(p_i.logweight - max_logweight))
else:
weights.append(exp(p_i.logweight - max_logweight))
total += weights[-1]
normalized = [w / sum(weights) for w in weights]
# Check if resampling is needed
norm2 = sum([w ** 2 for w in normalized])
square_norms.append(norm2)
if params.num_particles > 1 \
and norm2 > params.particle_weight_threshold / params.num_particles\
and i < len(history) - 1:
# Resample particles (though never for the last event)
count_resamples += 1
new_particle_indices = pick_new_particles(particles,
normalized, prng)
new_particles = []
new_table_history_with_user = []
new_dish_on_table_per_user = []
for index in new_particle_indices:
# copy table_history for that particle
if len(table_history_with_user):
old_history = copy(table_history_with_user[index])
else:
old_history = []
new_history = copy(particles[index].table_history_with_user)
old_history.extend(new_history)
new_table_history_with_user.append(old_history)
if len(dish_on_table_per_user):
dish_table_user = copy_dict(dish_on_table_per_user[index])
else:
dish_table_user = {}
dishes_toadd = copy_dict(particles[index].dish_on_table_todelete)
for user in dishes_toadd:
if user not in dish_table_user:
dish_table_user[user] = {}
for t in dishes_toadd[user]:
assert t not in dish_table_user[user]
dish_table_user[user][t] = dishes_toadd[user][t]
new_dish_on_table_per_user.append(dish_table_user)
# delete history from new particles
for index in new_particle_indices:
particles[index].table_history_with_user = []
for user in particles[index].dish_on_table_todelete:
particles[index].dish_on_table_todelete[user] = {}
for index in new_particle_indices:
particles[index].table_history_with_user = []
new_particle = particles[index].copy()
new_particle.reseed(prng.randint(maxint))
new_particle.reset_weight()
new_particles.append(new_particle)
inferred_tables[new_particle.uid] = \
copy(inferred_tables[particles[index].uid])
particles = new_particles
table_history_with_user = new_table_history_with_user
dish_on_table_per_user = new_dish_on_table_per_user
# If inferred tables dictionary grows too big, prune it
if len(inferred_tables) > 50 * params.num_particles:
new_inferred_tables = {}
for p in particles:
new_inferred_tables[p.uid] = copy(inferred_tables[p.uid])
del inferred_tables
inferred_tables = new_inferred_tables
with open(params.progress_file, mode='a') as temp:
temp.write("Time: %.2f (%d)\n" % (time() - start_tic, i))
    # Finally, sample a single particle according to its weight.
    max_logweight = None
    weights = []
    total = 0
for p_i in particles:
if max_logweight is None or max_logweight < p_i.logweight:
max_logweight = p_i.logweight
for p_i in particles:
# Normalize the weights of the particles
if p_i.logweight - max_logweight >= \
ln(epsilon) - ln(params.num_particles):
weights.append(exp(p_i.logweight - max_logweight))
else:
weights.append(exp(p_i.logweight - max_logweight))
total += weights[-1]
normalized = [w / sum(weights) for w in weights]
final_particle_id = pick_new_particles(particles, normalized, prng)[0]
final_particle = particles[final_particle_id]
    if len(table_history_with_user):
        table_history_with_user = table_history_with_user[final_particle_id]
    else:
        table_history_with_user = []
new_history = copy(final_particle.table_history_with_user)
table_history_with_user.extend(new_history)
final_particle.table_history_with_user = table_history_with_user
    if len(dish_on_table_per_user):
        dish_on_table_per_user = dish_on_table_per_user[final_particle_id]
    else:
        dish_on_table_per_user = {}
dishes_toadd = copy_dict(final_particle.dish_on_table_per_user)
for user in dishes_toadd:
if user not in dish_on_table_per_user:
dish_on_table_per_user[user] = {}
for t in dishes_toadd[user]:
assert t not in dish_on_table_per_user[user]
dish_on_table_per_user[user][t] = dishes_toadd[user][t]
for user in final_particle.dish_on_table_todelete:
if user not in dish_on_table_per_user:
dish_on_table_per_user[user] = {}
for t in final_particle.dish_on_table_todelete[user]:
assert t not in dish_on_table_per_user[user]
dish_on_table_per_user[user][t] = \
final_particle.dish_on_table_todelete[user][t]
final_particle.dish_on_table_per_user = dish_on_table_per_user
final_particle.time_history_per_user = copy(time_history_per_user)
final_particle.doc_history_per_user = copy(doc_history_per_user)
final_particle.question_history_per_user = copy(question_history_per_user)
final_particle.table_history_per_user = {}
for (u_i, table) in final_particle.table_history_with_user:
if u_i not in final_particle.table_history_per_user:
final_particle.table_history_per_user[u_i] = []
final_particle.table_history_per_user[u_i].append(table)
final_particle.vocabulary = params.vocabulary
# pool.close()
with open(params.progress_file, mode='a') as temp:
temp.write("Resampled %d times\n" % (count_resamples))
temp.write("Finished in time: %.2f\n" %
(time() - start_tic))
return final_particle, square_norms
def infer(history, alpha_0, mu_0, omega=1, beta=1, theta_0=None,
          threads=1, num_particles=1,
          particle_weight_threshold=1, resample_every=10,
          update_kernels=True, mu_rate=0.6,
          # enable_log=False, logfile='particles.log',
          keep_alpha_history=False, progress_file=None, seed=None):
"""Runs the inference algorithm and returns a particle.
Parameters
----------
    history : list
        A list of 4-tuples of the form (time, content, user, metadata) that
        represents the event history that we want to fit the model to.
alpha_0 : tuple
The Gamma prior parameter for a pattern's time kernel.
mu_0 : tuple
The Gamma prior parameter for the user activity rate.
omega : float
The time decay parameter.
    beta : float
        A parameter that controls the probability that a new task (table) will
        adopt a new pattern rather than an existing one.
theta_0 : list, default is None
If not None, theta_0 corresponds to the Dirichlet prior used for the
word distribution. It should have as many dimensions as the number of
words. By default, this is the vector :math:`[1 / |V|, \ldots, 1 / |V|]`, where
:math:`|V|` is the size of the vocabulary.
threads : int, default is 1
The number of CPU threads that will be used during inference.
num_particles : int, default is 1
The number of particles that the SMC algorithm will use.
    particle_weight_threshold : float, default is 1
        A parameter that controls when the particles need to be resampled.
    resample_every : int, default is 10
        How often we check whether the particles need to be resampled,
        measured in inference steps (number of events).
    update_kernels : bool, default is True
        Controls whether the time kernel parameter of each pattern will be
        updated from the posterior or not.
    mu_rate : float, default is 0.6
        The learning rate with which we update the activity rate of a user.
    keep_alpha_history : bool, default is False
        For debugging purposes, we may want to keep the complete history of
        each pattern's time kernel parameter as we see more events in that
        pattern.
    progress_file : str, default is None
        Since the computation might be slow, we want to save progress
        information to a file instead of printing it. If None, a temporary,
        randomly-named file is created for this purpose.
    seed : int, default is None
        The seed for the random number generator used during inference.

    Returns
    -------
    Particle
        The particle sampled at the end of the inference, holding the inferred
        tables, patterns and parameters.
    list
        The squared norms of the normalized particle weights, one per
        resampling check (useful as a convergence diagnostic).
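
    Examples
    --------
    A minimal sketch; the hyperparameter values are placeholders and we assume
    this function is exposed at the package level as ``hdhp.infer``:

    >>> import hdhp
    >>> events = [(1.0, 'tensorflow gpu', 0, 'q1'),
    ...           (2.5, 'pandas dataframe', 1, 'q2')]
    >>> particle, norms = hdhp.infer(events, alpha_0=(4.0, 0.5),
    ...                              mu_0=(2.0, 0.5), num_particles=5, seed=42)
    >>> process = particle.to_process()  # export to plot the user timelines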
"""
vocabulary, users = _extract_words_users(history)
vocabulary = list(vocabulary)
if theta_0 is None:
theta_0 = [1 / len(vocabulary)] * len(vocabulary)
if progress_file is None:
with tempfile.NamedTemporaryFile(mode='a', suffix='.log', dir='.',
delete=False) as temp:
progress_file = temp.name
print('Created temporary log file %s' % (progress_file))
params = InferenceParameters(alpha_0, mu_0, omega, beta, theta_0, threads,
num_particles, particle_weight_threshold,
resample_every, update_kernels, mu_rate,
keep_alpha_history, progress_file, seed,
vocabulary, users)
if threads == 1:
return _infer_single_thread(history, params)
else:
        raise NotImplementedError("Multi-threaded version is not yet implemented")