Skip to content
Snippets Groups Projects
Commit bfffdacf authored by Jim Hoekstra's avatar Jim Hoekstra :wave_tone1:
Browse files

Merge branch 'issue/MSX-51' into 'develop'

implement metric for graph expansion

See merge request !16
parents 208ed66b 3c117b84
Branches
No related tags found
2 merge requests!17Develop,!16implement metric for graph expansion
import dash_cytoscape as cyto
import pandas as pd
from math import factorial
class Graph:
......@@ -6,8 +8,10 @@ class Graph:
def __init__(self):
self.nodes = []
self.edges = []
self.COUNT_THRESHOLD = 2
self.MAX_NUM_WORDS = 10
self.N_INTERSECTING_WORDS = 5
def get_all_words(self):
all_words = [node_dict['data']['label'] for node_dict in self.nodes]
......@@ -103,21 +107,58 @@ class Graph:
]
)
def extend_graph(self, word2vec_model, base_node):
def extend_graph(self, word2vec_model, base_node, words_to_exclude=None, weight_threshold=0.3):
current_words = [node['data']['id'] for node in self.nodes]
all_associated_words = []
for current_word in current_words:
associated_words = word2vec_model.get_associated_words(current_word, top_n=100)
all_associated_words.extend(associated_words)
associated_words_filtered = [word for word in all_associated_words if word not in current_words]
associated_words_count = {word: associated_words_filtered.count(word) for word in list(set(associated_words_filtered))}
if words_to_exclude is None:
words_to_exclude = current_words
count_threshold = self.COUNT_THRESHOLD
common_associated_words = [word for word, count in associated_words_count.items() if count >= count_threshold]
while len(common_associated_words) > self.MAX_NUM_WORDS:
count_threshold += 1
common_associated_words = [word for word, count in associated_words_count.items() if count >= count_threshold]
self.add_nodes(common_associated_words)
self.add_edges(base_node, common_associated_words)
all_associated_words = {}
for current_word in current_words:
associated_words = word2vec_model.get_associated_words(current_word, top_n=100)
all_associated_words[current_word] = associated_words
intersections_df = self.construct_intersection_df(current_words, all_associated_words)
weights = intersections_df['intersection'].value_counts()
weights = weights.rename_axis(['word'], axis='index')
weights = weights.reset_index(name='weight')
weights['weight'] = self.normalize_weights(weights['weight'], len(current_words))
if len(weights['word'].values) == 0:
return
exclude_words_filter = [True if word not in words_to_exclude else False for word in weights['word'].values]
weights = weights[exclude_words_filter]
weights_after_threshold = weights.loc[weights['weight'] > weight_threshold]
words_to_add = weights_after_threshold['word'].values
self.add_nodes(words_to_add)
self.add_edges(base_node, words_to_add)
def construct_intersection_df(self, current_words, all_associated_words):
word1s = []
word2s = []
intersections = []
for i in range(len(current_words) - 1):
for j in range(i + 1, len(current_words)):
similar_words_1 = all_associated_words[current_words[i]]
similar_words_2 = all_associated_words[current_words[j]]
intersections_for_words = [word for word in similar_words_1 if word in similar_words_2][
:self.N_INTERSECTING_WORDS]
if len(intersections_for_words) > 0:
for intersection_for_words in intersections_for_words:
word1s.append(current_words[i])
word2s.append(current_words[j])
intersections.append(intersection_for_words)
return pd.DataFrame.from_dict({'word1': word1s, 'word2': word2s, 'intersection': intersections})
@staticmethod
def normalize_weights(weights, number_of_words):
if number_of_words >= 2:
number_of_combinations = factorial(number_of_words) / (2 * factorial(number_of_words - 2))
return weights / number_of_combinations
else:
return weights
......@@ -12,7 +12,17 @@ class AssociatedWords:
def get_associated_words(self, word, top_n=10):
lowercase_word = word.lower()
gensim_result = self.model.most_similar(lowercase_word, topn=top_n)
# gensim_result = [('apple', 1.0), ('banana', 1.0), ('strawberry', 1.0)]
# if word == 'fruit':
# gensim_result = [('apple', 1.0), ('banana', 1.0), ('strawberry', 1.0)]
# elif word == 'apple':
# gensim_result = [('fruit', 1.0), ('juice', 1.0), ('tree', 1.0)]
# elif word == 'banana':
# gensim_result = [('fruit', 1.0), ('smoothie', 1.0), ('tree', 1.0)]
# elif word == 'strawberry':
# gensim_result = [('fruit', 1.0), ('smoothie', 1.0), ('berry', 1.0)]
# else:
# gensim_result = []
words = self.filter_results(gensim_result, lowercase_word)
return words
......
......@@ -2,4 +2,4 @@ from dash_app.index import app
if __name__ == '__main__':
app.run_server(debug=False)
app.run_server(debug=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment