diff --git a/graph/component/__init__.py b/graph/component/__init__.py
index 9e723a26f..3f1d32fea 100644
--- a/graph/component/__init__.py
+++ b/graph/component/__init__.py
@@ -13,6 +13,7 @@ from .baidu import Baidu, BaiduParam
 from .duckduckgo import DuckDuckGo, DuckDuckGoParam
 from .wikipedia import Wikipedia, WikipediaParam
 from .pubmed import PubMed, PubMedParam
+from .arxiv import ArXiv, ArXivParam
 
 
 def component_class(class_name):
diff --git a/graph/component/arxiv.py b/graph/component/arxiv.py
new file mode 100644
index 000000000..ad9ee3885
--- /dev/null
+++ b/graph/component/arxiv.py
@@ -0,0 +1,88 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import random
+from abc import ABC
+from functools import partial
+import arxiv
+import pandas as pd
+from graph.settings import DEBUG
+from graph.component.base import ComponentBase, ComponentParamBase
+
+
+class ArXivParam(ComponentParamBase):
+    """
+    Define the ArXiv component parameters.
+
+    top_n is the maximum number of results requested from arXiv (default 6).
+    sort_by selects the result ordering: 'submittedDate' (default),
+    'lastUpdatedDate' or 'relevance'.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.top_n = 6
+        self.sort_by = 'submittedDate'
+
+    def check(self):
+        """Validate top_n and sort_by with the inherited check helpers."""
+        self.check_positive_integer(self.top_n, "Top N")
+        self.check_valid_value(self.sort_by, "ArXiv Search Sort_by",
+                               ['submittedDate', 'lastUpdatedDate', 'relevance'])
+
+
+class ArXiv(ComponentBase, ABC):
+    """Component that searches arXiv using the text of its input."""
+
+    component_name = "ArXiv"
+
+    def _run(self, history, **kwargs):
+        """
+        Query arXiv and return one row per matching paper.
+
+        The "content" entries of the component input are joined with " - "
+        to form the query. Returns ArXiv.be_output("") when there is no
+        query text or no result; otherwise a DataFrame whose "content"
+        column holds the title, PDF URL and summary of each paper.
+        """
+        ans = self.get_input()
+        ans = " - ".join(ans["content"]) if "content" in ans else ""
+        if not ans:
+            return ArXiv.be_output("")
+
+        sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
+                        "lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
+                        'submittedDate': arxiv.SortCriterion.SubmittedDate}
+        arxiv_client = arxiv.Client()
+        search = arxiv.Search(
+            query=ans,
+            max_results=self._param.top_n,
+            sort_by=sort_choices[self._param.sort_by]
+        )
+        # Each result's PDF link belongs after "Pdf_Url:"; the arxiv package
+        # documents pdf_url as optional, so fall back to an empty string.
+        arxiv_res = [
+            {"content": f"Title: {paper.title}\n"
+                        f"Pdf_Url: {paper.pdf_url or ''}\n"
+                        f"Summary: {paper.summary}"}
+            for paper in arxiv_client.results(search)]
+
+        if not arxiv_res:
+            return ArXiv.be_output("")
+
+        df = pd.DataFrame(arxiv_res)
+        if DEBUG:
+            print(df, ":::::::::::::::::::::::::::::::::")
+        return df
diff --git a/requirements.txt b/requirements.txt
index 8f9fa545a..9aa14c94c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+arxiv==2.1.3
 Aspose.Slides==24.2.0
 BCEmbedding==0.1.3
 Bio==1.7.1
diff --git a/requirements_arm.txt b/requirements_arm.txt
index 23895815f..20ee53e65 100644
--- a/requirements_arm.txt
+++ b/requirements_arm.txt
@@ -152,3 +152,4 @@ google-generativeai==0.7.2
 groq==0.9.0
 wikipedia==1.4.0
 Bio==1.7.1
+arxiv==2.1.3
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 77b8bd619..9c86c702c 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -137,3 +137,4 @@ google-generativeai==0.7.2
 groq==0.9.0
 wikipedia==1.4.0
 Bio==1.7.1
+arxiv==2.1.3