diff --git a/pyproject.toml b/pyproject.toml index f63a710..bd89ab1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ragchecker" -version = "0.1.2" +version = "0.1.3" description = "RAGChecker: A Fine-grained Framework For Diagnosing Retrieval-Augmented Generation (RAG) systems." authors = [ "Xiangkun Hu ", @@ -15,7 +15,7 @@ license = "Apache-2.0" [tool.poetry.dependencies] python = "^3.9" -refchecker = "^0.2.3" +refchecker = "^0.2.4" loguru = "^0.7" dataclasses-json = "^0.6" diff --git a/ragchecker/cli.py b/ragchecker/cli.py index ab21b76..6a1e5f6 100644 --- a/ragchecker/cli.py +++ b/ragchecker/cli.py @@ -59,6 +59,10 @@ def get_args(): help="Disable joint checking of the claims." ) parser.set_defaults(joint_check=True) + parser.add_argument( + "--joint_check_num", type=int, default=5 + ) + return parser.parse_args() @@ -75,6 +79,7 @@ def main(): batch_size_checker=args.batch_size_checker, openai_api_key=args.openai_api_key, joint_check=args.joint_check, + joint_check_num=args.joint_check_num ) with open(args.input_path, "r") as f: rag_results = RAGResults.from_json(f.read()) diff --git a/ragchecker/evaluator.py b/ragchecker/evaluator.py index df2d3a9..a4e231b 100644 --- a/ragchecker/evaluator.py +++ b/ragchecker/evaluator.py @@ -50,13 +50,15 @@ def __init__( batch_size_checker=32, openai_api_key=None, joint_check=True, - joint_check_num=5 + joint_check_num=5, + **kwargs ): if openai_api_key: os.environ['OPENAI_API_KEY'] = openai_api_key self.extractor_max_new_tokens = extractor_max_new_tokens self.joint_check = joint_check self.joint_check_num = joint_check_num + self.kwargs = kwargs self.extractor = LLMExtractor( model=extractor_name, @@ -102,7 +104,8 @@ def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"): extraction_results = self.extractor.extract( batch_responses=texts, batch_questions=questions, - max_new_tokens=self.extractor_max_new_tokens + max_new_tokens=self.extractor_max_new_tokens, + **self.kwargs ) claims = [[c.content for c in res.claims] for res in extraction_results] for i, result in enumerate(results): @@ -161,7 +164,8 @@ def check_claims(self, results: RAGResults, check_type="answer2response"): max_reference_segment_length=0, merge_psg=merge_psg, is_joint=self.joint_check, - joint_check_num=self.joint_check_num + joint_check_num=self.joint_check_num, + **self.kwargs ) for i, result in enumerate(results): if check_type == "answer2response":