diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 03cfcf0b0..5c978040b 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -57,6 +57,7 @@ class Base(ABC): class DefaultEmbedding(Base): + os.environ['CUDA_VISIBLE_DEVICES'] = '0' _model = None _model_name = "" _model_lock = threading.Lock()