"""Doc2Vec helpers: cosine comparison of vectors and text-to-vector inference."""
import gensim
|
||
|
|
import numpy
|
||
|
|
from numpy import dot
|
||
|
|
from numpy.linalg import norm
|
||
|
|
|
||
|
|
# Default location of the trained Doc2Vec model loaded by ``Metric``.
# The path is relative, so it is resolved against the process's current
# working directory at load time, not against this module's directory.
MODEL_INPUT = "../data/models/doc2vec.model"


def cos_dist(v0, v1):
    """Return the cosine similarity of *v0* and *v1*.

    Despite its name, this returns the cosine *similarity*: 1.0 for vectors
    pointing the same way, 0.0 for orthogonal ones and -1.0 for opposite
    ones. It does not return the cosine distance ``1 - similarity``. The name
    is kept so that existing callers keep working.

    Args:
        v0: First vector (a list or numpy array).
        v1: Second vector, of the same length as *v0*.

    Returns:
        ``dot(v0, v1) / (|v0| * |v1|)``. If either vector has zero norm, the
        angle is undefined and 0.0 is returned. The raw division would
        otherwise give ``nan`` and emit a ``RuntimeWarning``.
    """
    denom = norm(v0) * norm(v1)
    if denom == 0:
        # A zero vector has no direction; treat it as unrelated to everything.
        return 0.0
    return dot(v0, v1) / denom


class Metric:
    """Turn free text into vectors with a pre-trained gensim Doc2Vec model.

    The model is loaded once in ``__init__`` and reused by every call to
    :meth:`vector`.
    """

    def __init__(self, model_input=MODEL_INPUT):
        """Load the Doc2Vec model from disk.

        Args:
            model_input: Path to a model written by ``Doc2Vec.save``.
                Defaults to ``MODEL_INPUT``.
        """
        self.model = gensim.models.doc2vec.Doc2Vec.load(model_input)

    def vector(self, text: str, *, max_len: int = 25) -> numpy.ndarray:
        """Infer a document vector for *text*.

        The text is tokenised with ``gensim.utils.simple_preprocess``, and
        the resulting tokens are passed to the model's ``infer_vector``.

        Args:
            text: Raw input text.
            max_len: Longest token, in characters, that ``simple_preprocess``
                keeps. Longer tokens are dropped. Defaults to 25.

        Returns:
            The vector returned by ``self.model.infer_vector``.
        """
        # NOTE(review): gensim's inference is stochastic, so repeated calls on
        # the same text may return slightly different vectors. Confirm whether
        # callers need determinism.
        tokens = gensim.utils.simple_preprocess(text, max_len=max_len)
        return self.model.infer_vector(tokens)