from __future__ import annotations

import re
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
from langchain_core.callbacks import (
    CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from pydantic import Field

from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
    PROMPT,
    QUESTION_GENERATOR_PROMPT,
    FinishedOutputParser,
)
from langchain.chains.llm import LLMChain


def _extract_tokens_and_log_probs(
    response: AIMessage,
) -> Tuple[List[str], List[float]]:
    """Extract tokens and log probabilities from a chat model response."""
    tokens = []
    log_probs = []
    for token in response.response_metadata["logprobs"]["content"]:
        tokens.append(token["token"])
        log_probs.append(token["logprob"])
    return tokens, log_probs
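

# `_extract_tokens_and_log_probs` assumes an OpenAI-style logprobs payload,
# where response_metadata["logprobs"]["content"] is a list of per-token dicts
# shaped roughly like:
#
#     {"token": " Paris", "logprob": -0.012, "bytes": [...], "top_logprobs": [...]}
#
# Providers that report log probabilities under a different metadata layout
# would need an adapted extractor.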


class QuestionGeneratorChain(LLMChain):
    """Chain that generates questions from uncertain spans."""

    prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
    """Prompt template for the chain."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input", "context", "response"]
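

# QuestionGeneratorChain is the legacy, LLMChain-based question generator.
# _do_retrieval below special-cases LLMChain instances, so both this class and
# the newer Runnable pipeline built by FlareChain.from_llm
# (QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()) are supported.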


def _low_confidence_spans(
    tokens: Sequence[str],
    log_probs: Sequence[float],
    min_prob: float,
    min_token_gap: int,
    num_pad_tokens: int,
) -> List[str]:
    """Identify spans of consecutive low-confidence tokens.

    A token is low confidence if its probability (the exponential of its log
    probability) falls below ``min_prob``. Tokens with no word characters are
    ignored. Each low-confidence token is padded with ``num_pad_tokens``
    following tokens, and spans whose low-confidence tokens are fewer than
    ``min_token_gap`` tokens apart are merged.
    """
    _low_idx = np.where(np.exp(log_probs) < min_prob)[0]
    # Keep only tokens that contain at least one word character.
    low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
    if len(low_idx) == 0:
        return []
    spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
    for i, idx in enumerate(low_idx[1:]):
        end = idx + num_pad_tokens + 1
        if idx - low_idx[i] < min_token_gap:
            # Close to the previous low-confidence token: extend that span.
            spans[-1][1] = end
        else:
            spans.append([idx, end])
    return ["".join(tokens[start:end]) for start, end in spans]
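

# A worked example (illustrative values, not from the paper): with
#     tokens    = ["The", " capital", " of", " France", " is", " Paris", "."]
#     log_probs = [-0.1, -0.1, -0.1, -2.5, -0.1, -3.0, -0.1]
# and min_prob=0.2, min_token_gap=3, num_pad_tokens=1, the low-confidence
# indices are [3, 5] (exp(-2.5) ~= 0.08 and exp(-3.0) ~= 0.05 are both below
# 0.2). Index 5 is only 2 tokens after index 3 (< min_token_gap), so the two
# padded spans merge into one: tokens[3:7] -> " France is Paris.".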


class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator.

    See the [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983) paper.
    """

    question_generator_chain: Runnable
    """Chain that generates questions from uncertain spans."""
    response_chain: Runnable
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
    retriever: BaseRetriever
    """Retriever that retrieves relevant documents from a user input."""
    min_prob: float = 0.2
    """Minimum probability for a token to be considered low confidence."""
    min_token_gap: int = 5
    """Minimum number of tokens between two low confidence spans."""
    num_pad_tokens: int = 2
    """Number of tokens to pad around a low confidence span."""
    max_iter: int = 10
    """Maximum number of iterations."""
    start_with_retrieval: bool = True
    """Whether to start with retrieval."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for the chain."""
        return ["response"]

    def _do_generation(
        self,
        questions: List[str],
        user_input: str,
        response: str,
        _run_manager: CallbackManagerForChainRun,
    ) -> Tuple[str, bool]:
        callbacks = _run_manager.get_child()
        # Retrieve documents for every generated question and pool them into a
        # single context string for the response chain.
        docs = []
        for question in questions:
            docs.extend(self.retriever.invoke(question))
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.invoke(
            {
                "user_input": user_input,
                "context": context,
                "response": response,
            },
            {"callbacks": callbacks},
        )
        if isinstance(result, AIMessage):
            result = result.content
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

    def _do_retrieval(
        self,
        low_confidence_spans: List[str],
        _run_manager: CallbackManagerForChainRun,
        user_input: str,
        response: str,
        initial_response: str,
    ) -> Tuple[str, bool]:
        # Turn each uncertain span into a question-generation input.
        question_gen_inputs = [
            {
                "user_input": user_input,
                "current_response": initial_response,
                "uncertain_span": span,
            }
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        if isinstance(self.question_generator_chain, LLMChain):
            # Legacy LLMChain path: `apply` returns a list of output dicts.
            question_gen_outputs = self.question_generator_chain.apply(
                question_gen_inputs, callbacks=callbacks
            )
            questions = [
                output[self.question_generator_chain.output_keys[0]]
                for output in question_gen_outputs
            ]
        else:
            questions = self.question_generator_chain.batch(
                question_gen_inputs, config={"callbacks": callbacks}
            )
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
        return self._do_generation(questions, user_input, response, _run_manager)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        user_input = inputs[self.input_keys[0]]

        response = ""

        for i in range(self.max_iter):
            _run_manager.on_text(
                f"Current Response: {response}", color="blue", end="\n"
            )
            # Tentatively generate the next chunk of the answer, capturing
            # per-token log probabilities.
            _input = {"user_input": user_input, "context": "", "response": response}
            tokens, log_probs = _extract_tokens_and_log_probs(
                self.response_chain.invoke(
                    _input, {"callbacks": _run_manager.get_child()}
                )
            )
            low_confidence_spans = _low_confidence_spans(
                tokens,
                log_probs,
                self.min_prob,
                self.min_token_gap,
                self.num_pad_tokens,
            )
            initial_response = response.strip() + " " + "".join(tokens)
            if not low_confidence_spans:
                # The model is confident: accept the chunk and keep going.
                response = initial_response
                final_response, finished = self.output_parser.parse(response)
                if finished:
                    return {self.output_keys[0]: final_response}
                continue

            # Otherwise, turn the uncertain spans into questions, retrieve
            # supporting documents, and regenerate the chunk with that context.
            marginal, finished = self._do_retrieval(
                low_confidence_spans,
                _run_manager,
                user_input,
                response,
                initial_response,
            )
            response = response.strip() + " " + marginal
            if finished:
                break
        return {self.output_keys[0]: response}
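
    # Note on termination: the loop in _call stops either when the output
    # parser reports the response as finished (FinishedOutputParser looks for
    # a sentinel, "FINISHED" by default; see langchain.chains.flare.prompts)
    # or after max_iter rounds, whichever comes first.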

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
    ) -> FlareChain:
        """Create a FlareChain from a language model.

        Args:
            llm: Language model to use.
            max_generation_len: Maximum length of the generated response.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            FlareChain class with the given language model.
        """
        try:
            from langchain_openai import ChatOpenAI
        except ImportError as e:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai: "
                "`pip install langchain-openai`."
            ) from e
        # Token log probabilities are required, so the provided llm is
        # replaced with a logprobs-enabled ChatOpenAI.
        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
        response_chain = PROMPT | llm
        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
            **kwargs,
        )
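

# Example usage (a minimal sketch, not part of this module): assumes an
# OPENAI_API_KEY in the environment and a `retriever` built elsewhere, e.g.
# from a vector store. Note that from_llm currently constructs its own
# logprobs-enabled ChatOpenAI regardless of the llm argument.
#
#     from langchain_openai import ChatOpenAI
#
#     flare = FlareChain.from_llm(
#         ChatOpenAI(),
#         retriever=retriever,
#         max_generation_len=164,
#         max_iter=4,
#     )
#     result = flare.invoke({"user_input": "Explain FLARE in one paragraph."})
#     print(result["response"])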