Normalize embedding vectors to unit length (#974)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong 2023-08-23 19:10:11 +08:00 committed by GitHub
parent 916d8be0ae
commit 1fc57d7358
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,6 +1,7 @@
import logging import logging
from typing import List from typing import List
import numpy as np
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex: except Exception as ex:
raise self._embeddings.handle_exceptions(ex) raise self._embeddings.handle_exceptions(ex)
i = 0 i = 0
normalized_embedding_results = []
for text in embedding_queue_texts: for text in embedding_queue_texts:
hash = helper.generate_text_hash(text) hash = helper.generate_text_hash(text)
try: try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash) embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i]) vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
normalized_embedding_results.append(normalized_embedding)
embedding.set_embedding(normalized_embedding)
db.session.add(embedding) db.session.add(embedding)
db.session.commit() db.session.commit()
except IntegrityError: except IntegrityError:
@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
finally: finally:
i += 1 i += 1
text_embeddings.extend(embedding_results) text_embeddings.extend(normalized_embedding_results)
return text_embeddings return text_embeddings
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
try: try:
embedding_results = self._embeddings.client.embed_query(text) embedding_results = self._embeddings.client.embed_query(text)
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex: except Exception as ex:
raise self._embeddings.handle_exceptions(ex) raise self._embeddings.handle_exceptions(ex)
@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
return embedding_results return embedding_results