# Source code for autorag.nodes.passagecompressor.longllmlingua
from typing import List, Optional
import pandas as pd
from autorag.nodes.passagecompressor.base import BasePassageCompressor
from autorag.utils.util import pop_params, result_to_dataframe, empty_cuda_cache
# TODO: Parallel Processing Refactoring at #460
[docs]
class LongLLMLingua(BasePassageCompressor):
    """
    Passage compressor module that compresses retrieved passages with LongLLMLingua.

    A single ``llmlingua.PromptCompressor`` is created in ``__init__`` and reused
    for every query that is compressed through this instance.
    For more information, visit https://github.com/microsoft/LLMLingua.
    """

    def __init__(
        self, project_dir: str, model_name: str = "NousResearch/Llama-2-7b-hf", **kwargs
    ):
        """
        Create the underlying ``PromptCompressor``.

        :param project_dir: The project directory, passed to the base class initializer.
        :param model_name: The model name to use for compression.
            The default is "NousResearch/Llama-2-7b-hf".
        :param kwargs: Additional keyword arguments.
            The ones selected by ``pop_params`` against ``PromptCompressor.__init__``
            are forwarded to the ``PromptCompressor`` constructor.
        :raises ImportError: When the ``llmlingua`` package cannot be imported.
        """
        # llmlingua is an optional dependency, so it is imported lazily here.
        try:
            from llmlingua import PromptCompressor
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "LongLLMLingua is not installed. Please install it by running `pip install llmlingua`."
            ) from err

        super().__init__(project_dir)
        model_init_params = pop_params(PromptCompressor.__init__, kwargs)
        self.llm_lingua = PromptCompressor(model_name=model_name, **model_init_params)

    def __del__(self):
        """Release the compressor model and clear the CUDA cache."""
        # __del__ is also invoked for instances whose __init__ raised before
        # `llm_lingua` was assigned (for example when llmlingua is missing).
        # Guard the attribute so the finalizer does not raise AttributeError
        # and the remaining cleanup still runs.
        if hasattr(self, "llm_lingua"):
            del self.llm_lingua
        empty_cuda_cache()
        super().__del__()

    @result_to_dataframe(["retrieved_contents"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        """
        Compress the retrieved contents found in the previous node result.

        :param previous_result: The result dataframe of the previous node.
            It is handed to ``cast_to_run`` to extract queries and retrieved contents.
        :param kwargs: Keyword arguments forwarded to ``_pure``.
        :return: One single-element list per query, holding its compressed text.
        """
        queries, retrieved_contents = self.cast_to_run(previous_result)
        results = self._pure(queries, retrieved_contents, **kwargs)
        # Each query ends up with exactly one compressed passage, so wrap it in a list.
        return [[compressed] for compressed in results]

    def _pure(
        self,
        queries: List[str],
        contents: List[List[str]],
        instructions: Optional[str] = None,
        target_token: int = 300,
        **kwargs,
    ) -> List[str]:
        """
        Compresses the retrieved texts using LongLLMLingua.
        For more information, visit https://github.com/microsoft/LLMLingua.

        :param queries: The queries for retrieved passages.
        :param contents: The contents of retrieved passages.
        :param instructions: The instructions for compression.
            Default is None. When it is None, it will use default instructions.
        :param target_token: The target token for compression.
            Default is 300.
        :param kwargs: Additional keyword arguments, forwarded to ``llmlingua_pure``.
        :return: The list of compressed texts.
        """
        if instructions is None:
            instructions = "Given the context, please answer the final question"
        # Queries are compressed one by one with the shared compressor instance.
        results = [
            llmlingua_pure(
                query, contents_, self.llm_lingua, instructions, target_token, **kwargs
            )
            for query, contents_ in zip(queries, contents)
        ]
        return results
[docs]
def llmlingua_pure(
    query: str,
    contents: List[str],
    llm_lingua,
    instructions: str,
    target_token: int = 300,
    **kwargs,
) -> str:
    """
    Compress the passages retrieved for a single query and return the text.

    :param query: The query for retrieved passages.
    :param contents: The contents of retrieved passages.
    :param llm_lingua: The compressor instance whose ``compress_prompt`` is called.
    :param instructions: The instructions for compression.
    :param target_token: The target token for compression.
        Default is 300.
    :param kwargs: Additional keyword arguments.
        The ones selected by ``pop_params`` against
        ``PromptCompressor.compress_prompt`` are forwarded to ``compress_prompt``.
    :return: The compressed text.
    :raises ImportError: When the ``llmlingua`` package cannot be imported.
    """
    try:
        from llmlingua import PromptCompressor
    except ImportError:
        raise ImportError(
            "LongLLMLingua is not installed. Please install it by running `pip install llmlingua`."
        )

    paragraph_sep = "\n\n"

    # Break every passage on "\n\n" (recommended by LongLLMLingua authors).
    segments: List[str] = []
    for passage in contents:
        segments.extend(passage.split(paragraph_sep))

    extra_params = pop_params(PromptCompressor.compress_prompt, kwargs)
    compression_output = llm_lingua.compress_prompt(
        segments,
        question=query,
        instruction=instructions,
        rank_method="longllmlingua",
        target_token=target_token,
        **extra_params,
    )

    # Drop the first and last "\n\n"-separated parts to separate out the
    # question and instruction from the compressed passages.
    compressed_parts = compression_output["compressed_prompt"].split(paragraph_sep)
    inner_parts = compressed_parts[1:-1]
    return paragraph_sep.join(inner_parts)