modules
__all__ = [
    "BaseAfterResponseRagModule",
    "BaseBeforeResponseRagModule",
    "BaseQueryRagModule",
    "BaseRagModule",
    "BaseRerankRagModule",
    "BaseResponseRagModule",
    "BaseRetrievalRagModule",
    "FootnotePromptResponseRagModule",
    "PromptResponseRagModule",
    "TextChunksRerankRagModule",
    "TextChunksResponseRagModule",
    "TextLoaderRetrievalRagModule",
    "TranslateQueryRagModule",
    "VectorStoreRetrievalRagModule",
]
module-attribute
BaseAfterResponseRagModule
Bases:
BaseRagModule
, ABC
Source Code in griptape/engines/rag/modules/response/base_after_response_rag_module.py
@define(kw_only=True)
class BaseAfterResponseRagModule(BaseRagModule, ABC):
    """Abstract base for "after response" RAG modules.

    Concrete subclasses must implement ``run``, which receives a ``RagContext``
    and returns a ``RagContext``.
    """

    @abstractmethod
    def run(self, context: RagContext) -> RagContext:
        """Process ``context`` and return the resulting ``RagContext``."""
run(context) abstractmethod
Source Code in griptape/engines/rag/modules/response/base_after_response_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> RagContext:
    """Process ``context`` and return the resulting ``RagContext``.

    Abstract: concrete subclasses must override this.
    """
BaseBeforeResponseRagModule
Bases:
BaseRagModule
, ABC
Source Code in griptape/engines/rag/modules/response/base_before_response_rag_module.py
@define(kw_only=True)
class BaseBeforeResponseRagModule(BaseRagModule, ABC):
    """Abstract base for "before response" RAG modules.

    Concrete subclasses must implement ``run``, which receives a ``RagContext``
    and returns a ``RagContext``.
    """

    @abstractmethod
    def run(self, context: RagContext) -> RagContext:
        """Process ``context`` and return the resulting ``RagContext``."""
run(context) abstractmethod
Source Code in griptape/engines/rag/modules/response/base_before_response_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> RagContext:
    """Process ``context`` and return the resulting ``RagContext``.

    Abstract: concrete subclasses must override this.
    """
BaseQueryRagModule
Bases:
BaseRagModule
, ABC
Source Code in griptape/engines/rag/modules/query/base_query_rag_module.py
@define(kw_only=True)
class BaseQueryRagModule(BaseRagModule, ABC):
    """Abstract base for query-stage RAG modules.

    Concrete subclasses must implement ``run``, which receives a ``RagContext``
    and returns a ``RagContext``.
    """

    @abstractmethod
    def run(self, context: RagContext) -> RagContext:
        """Process ``context`` and return the resulting ``RagContext``."""
run(context) abstractmethod
Source Code in griptape/engines/rag/modules/query/base_query_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> RagContext:
    """Process ``context`` and return the resulting ``RagContext``.

    Abstract: concrete subclasses must override this.
    """
BaseRagModule
Bases:
FuturesExecutorMixin
, ABC
Source Code in griptape/engines/rag/modules/base_rag_module.py
@define(kw_only=True)
class BaseRagModule(FuturesExecutorMixin, ABC):
    """Common base for RAG engine modules.

    Provides a per-instance ``name`` that keys this module's entry in
    ``RagContext.module_configs``, a helper to build a ``PromptStack``, and
    helpers to read and write per-module parameters in ``module_configs``.
    """

    # Defaults to "<ClassName>-<random uuid4 hex>", so separate instances get distinct config keys.
    name: str = field(
        default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True
    )

    def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> PromptStack:
        """Build a ``PromptStack`` with an optional system message followed by a user message.

        Args:
            system_prompt: System message text; omitted from the stack when ``None``.
            query: User message text.
        """
        messages = []
        if system_prompt is not None:
            messages.append(Message(system_prompt, role=Message.SYSTEM_ROLE))
        messages.append(Message(query, role=Message.USER_ROLE))
        return PromptStack(messages=messages)

    def get_context_param(self, context: RagContext, key: str) -> Optional[Any]:
        """Return ``key`` from this module's entry in ``context.module_configs``.

        Returns ``None`` when the module has no entry, when the entry is not a
        dict, or when the key is absent.
        """
        module_config = context.module_configs.get(self.name)
        # Mirror set_context_param's isinstance(dict) guard: a missing or non-dict entry
        # means "no params" instead of an AttributeError from calling `.get` on it.
        if not isinstance(module_config, dict):
            return None
        return module_config.get(key)

    def set_context_param(self, context: RagContext, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in this module's entry of ``context.module_configs``.

        A missing or non-dict entry for this module is replaced by an empty dict first.
        """
        if not isinstance(context.module_configs.get(self.name), dict):
            context.module_configs[self.name] = {}
        context.module_configs[self.name][key] = value
name = field(default=Factory(lambda self: f'{self.__class__.__name__}-{uuid.uuid4().hex}', takes_self=True), kw_only=True)
class-attribute instance-attribute
generate_prompt_stack(system_prompt, query)
Source Code in griptape/engines/rag/modules/base_rag_module.py
def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> PromptStack:
    """Return a ``PromptStack``: an optional system message, then the user ``query``.

    Args:
        system_prompt: System message text, or ``None`` to omit the system message.
        query: Text of the user message, which is always last.
    """
    leading = [] if system_prompt is None else [Message(system_prompt, role=Message.SYSTEM_ROLE)]
    return PromptStack(messages=[*leading, Message(query, role=Message.USER_ROLE)])
get_context_param(context, key)
Source Code in griptape/engines/rag/modules/base_rag_module.py
def get_context_param(self, context: RagContext, key: str) -> Optional[Any]:
    """Return ``key`` from this module's entry in ``context.module_configs``.

    The entry is looked up by ``self.name``. Returns ``None`` when there is no
    entry, when the entry is not a dict, or when the key is absent.
    """
    module_config = context.module_configs.get(self.name)
    # Same isinstance(dict) guard that set_context_param uses; previously a None or
    # other non-dict entry raised AttributeError on `.get`.
    if not isinstance(module_config, dict):
        return None
    return module_config.get(key)
set_context_param(context, key, value)
Source Code in griptape/engines/rag/modules/base_rag_module.py
def set_context_param(self, context: RagContext, key: str, value: Any) -> None:
    """Write ``value`` under ``key`` in this module's ``module_configs`` entry.

    The entry is keyed by ``self.name``. If it is missing or is not a dict, it is
    replaced by a new empty dict before the write.
    """
    configs = context.module_configs
    if not isinstance(configs.get(self.name), dict):
        configs[self.name] = {}
    configs[self.name][key] = value
BaseRerankRagModule
Bases:
BaseRagModule
, ABC
Source Code in griptape/engines/rag/modules/retrieval/base_rerank_rag_module.py
@define(kw_only=True)
class BaseRerankRagModule(BaseRagModule, ABC):
    """Abstract base for rerank RAG modules.

    Holds the ``rerank_driver`` that subclasses use. Concrete subclasses must
    implement ``run``, which returns a sequence of artifacts.
    """

    rerank_driver: BaseRerankDriver = field()

    @abstractmethod
    def run(self, context: RagContext) -> Sequence[BaseArtifact]:
        """Return a sequence of artifacts for ``context``."""
rerank_driver = field()
class-attribute instance-attribute
run(context) abstractmethod
Source Code in griptape/engines/rag/modules/retrieval/base_rerank_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> Sequence[BaseArtifact]:
    """Return a sequence of artifacts for ``context``.

    Abstract: concrete subclasses must override this.
    """
BaseResponseRagModule
Bases:
BaseRagModule
, ABC
Source Code in griptape/engines/rag/modules/response/base_response_rag_module.py
@define(kw_only=True)
class BaseResponseRagModule(BaseRagModule, ABC):
    """Abstract base for response-stage RAG modules.

    Concrete subclasses must implement ``run``, which produces a single
    ``BaseArtifact`` from a ``RagContext``.
    """

    @abstractmethod
    def run(self, context: RagContext) -> BaseArtifact:
        """Produce a response artifact for ``context``."""
run(context) abstractmethod
Source Code in griptape/engines/rag/modules/response/base_response_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> BaseArtifact:
    """Produce a response artifact for ``context``.

    Abstract: concrete subclasses must override this.
    """
BaseRetrievalRagModule
Bases:
BaseRagModule
, ABC
Source Code in griptape/engines/rag/modules/retrieval/base_retrieval_rag_module.py
@define(kw_only=True)
class BaseRetrievalRagModule(BaseRagModule, ABC):
    """Abstract base for retrieval RAG modules.

    Concrete subclasses must implement ``run``, which returns a sequence of
    artifacts for a ``RagContext``.
    """

    @abstractmethod
    def run(self, context: RagContext) -> Sequence[BaseArtifact]:
        """Return a sequence of artifacts for ``context``."""
run(context) abstractmethod
Source Code in griptape/engines/rag/modules/retrieval/base_retrieval_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> Sequence[BaseArtifact]:
    """Return a sequence of artifacts for ``context``.

    Abstract: concrete subclasses must override this.
    """
FootnotePromptResponseRagModule
Bases:
PromptResponseRagModule
Source Code in griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py
@define(kw_only=True)
class FootnotePromptResponseRagModule(PromptResponseRagModule):
    """``PromptResponseRagModule`` variant that renders the footnote_prompt system template.

    Only the default system-template generator is overridden; the inherited
    ``run`` logic is unchanged.
    """

    def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
        """Render the footnote system prompt for ``artifacts``.

        The template receives the artifacts, references derived from them, and
        ``context.before_query`` / ``context.after_query`` each joined with blank lines.
        """
        separator = "\n\n"
        template = J2("engines/rag/modules/response/footnote_prompt/system.j2")
        return template.render(
            text_chunk_artifacts=artifacts,
            references=utils.references_from_artifacts(artifacts),
            before_system_prompt=separator.join(context.before_query),
            after_system_prompt=separator.join(context.after_query),
        )
default_generate_system_template(context, artifacts)
Source Code in griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py
def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
    """Render the footnote_prompt system template for ``artifacts``.

    The template receives:
        text_chunk_artifacts: ``artifacts`` as given.
        references: ``utils.references_from_artifacts(artifacts)``.
        before_system_prompt / after_system_prompt: ``context.before_query`` and
            ``context.after_query``, each joined with a blank line between entries.
    """
    separator = "\n\n"
    template = J2("engines/rag/modules/response/footnote_prompt/system.j2")
    return template.render(
        text_chunk_artifacts=artifacts,
        references=utils.references_from_artifacts(artifacts),
        before_system_prompt=separator.join(context.before_query),
        after_system_prompt=separator.join(context.after_query),
    )
PromptResponseRagModule
Bases:
BaseResponseRagModule
, RuleMixin
Source Code in griptape/engines/rag/modules/response/prompt_response_rag_module.py
@define(kw_only=True)
class PromptResponseRagModule(BaseResponseRagModule, RuleMixin):
    """Response module that answers ``context.query`` with a prompt driver.

    The system prompt is built from as many of ``context.text_chunks`` as fit
    within the tokenizer's input budget, plus optional rulesets and metadata.
    """

    # Defaults to the prompt driver from the global Defaults drivers config.
    prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
    # Token headroom kept below tokenizer.max_input_tokens when deciding how many chunks to include.
    answer_token_offset: int = field(default=400)
    # Optional metadata rendered into the system prompt via the metadata template.
    metadata: Optional[str] = field(default=None)
    # Callable (context, included_chunks) -> system prompt; defaults to default_generate_system_template.
    generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
        default=Factory(lambda self: self.default_generate_system_template, takes_self=True),
    )

    def run(self, context: RagContext) -> BaseArtifact:
        """Answer ``context.query`` using the text chunks that fit the token budget.

        Raises:
            ValueError: If the prompt driver's output artifact is not a ``TextArtifact``.
        """
        query = context.query
        tokenizer = self.prompt_driver.tokenizer
        # NOTE: the same list object is passed to every generate_system_template call and
        # mutated afterwards; a custom template that keeps a reference will see later changes.
        included_chunks: list[TextArtifact] = []
        # Rendered up front so system_prompt is defined even when there are no text chunks.
        system_prompt = self.generate_system_template(context, included_chunks)
        for artifact in context.text_chunks:
            included_chunks.append(artifact)
            # The full prompt is re-rendered and re-counted on every iteration.
            system_prompt = self.generate_system_template(context, included_chunks)
            message_token_count = self.prompt_driver.tokenizer.count_tokens(
                self.prompt_driver.prompt_stack_to_string(self.generate_prompt_stack(system_prompt, query)),
            )
            # Over budget (including answer_token_offset headroom): drop the chunk just added, re-render, stop.
            if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
                included_chunks.pop()
                system_prompt = self.generate_system_template(context, included_chunks)
                break
        output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()
        if isinstance(output, TextArtifact):
            return output
        raise ValueError("Prompt driver did not return a TextArtifact")

    def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
        """Render the prompt system template from the artifacts' text, rulesets and metadata.

        Note: ``context`` is not used by this implementation.
        """
        params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}
        # Rulesets and metadata are only passed to the template when present.
        if len(self.rulesets) > 0:
            params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)
        if self.metadata is not None:
            params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)
        return J2("engines/rag/modules/response/prompt/system.j2").render(**params)
answer_token_offset = field(default=400)
class-attribute instance-attribute
generate_system_template = field(default=Factory(lambda self: self.default_generate_system_template, takes_self=True))
class-attribute instance-attribute
metadata = field(default=None)
class-attribute instance-attribute
prompt_driver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
class-attribute instance-attribute
default_generate_system_template(context, artifacts)
Source Code in griptape/engines/rag/modules/response/prompt_response_rag_module.py
def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
    """Render the response system prompt for ``artifacts``.

    The prompt template always receives ``text_chunks`` (each artifact's text).
    It also receives rendered ``rulesets`` when any are configured, and rendered
    ``metadata`` when ``self.metadata`` is set. ``context`` is not used here.
    """
    template_vars: dict[str, Any] = {"text_chunks": [artifact.to_text() for artifact in artifacts]}

    if self.rulesets:
        template_vars["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)

    if self.metadata is not None:
        metadata_template = J2("engines/rag/modules/response/metadata/system.j2")
        template_vars["metadata"] = metadata_template.render(metadata=self.metadata)

    return J2("engines/rag/modules/response/prompt/system.j2").render(**template_vars)
run(context)
Source Code in griptape/engines/rag/modules/response/prompt_response_rag_module.py
def run(self, context: RagContext) -> BaseArtifact:
    """Answer ``context.query`` using as many ``context.text_chunks`` as fit the token budget.

    Chunks are added in order. After each addition the system template is
    re-rendered and the whole prompt stack is token-counted. Once that count plus
    ``answer_token_offset`` reaches ``tokenizer.max_input_tokens``, the last chunk
    is dropped, the template is re-rendered, and the loop stops.

    Raises:
        ValueError: If the prompt driver's output artifact is not a ``TextArtifact``.
    """
    query = context.query
    tokenizer = self.prompt_driver.tokenizer
    # NOTE: the same list object is passed to every generate_system_template call and
    # mutated afterwards; a custom template that keeps a reference will see later changes.
    included_chunks: list[TextArtifact] = []
    # Rendered up front so system_prompt is defined even when there are no text chunks.
    system_prompt = self.generate_system_template(context, included_chunks)
    for artifact in context.text_chunks:
        included_chunks.append(artifact)
        # The full prompt is re-rendered and re-counted on every iteration.
        system_prompt = self.generate_system_template(context, included_chunks)
        message_token_count = self.prompt_driver.tokenizer.count_tokens(
            self.prompt_driver.prompt_stack_to_string(self.generate_prompt_stack(system_prompt, query)),
        )
        # Over budget (including answer_token_offset headroom): drop the chunk just added, re-render, stop.
        if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
            included_chunks.pop()
            system_prompt = self.generate_system_template(context, included_chunks)
            break
    output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()
    if isinstance(output, TextArtifact):
        return output
    raise ValueError("Prompt driver did not return a TextArtifact")
TextChunksRerankRagModule
Bases:
BaseRerankRagModule
Source Code in griptape/engines/rag/modules/retrieval/text_chunks_rerank_rag_module.py
@define(kw_only=True)
class TextChunksRerankRagModule(BaseRerankRagModule):
    """Rerank module that passes ``context.query`` and ``context.text_chunks`` to ``rerank_driver``."""

    def run(self, context: RagContext) -> Sequence[BaseArtifact]:
        """Return ``rerank_driver.run(query, text_chunks)`` for ``context``."""
        query, chunks = context.query, context.text_chunks
        return self.rerank_driver.run(query, chunks)
run(context)
Source Code in griptape/engines/rag/modules/retrieval/text_chunks_rerank_rag_module.py
def run(self, context: RagContext) -> Sequence[BaseArtifact]:
    """Rerank ``context.text_chunks`` against ``context.query`` using ``self.rerank_driver``.

    Returns whatever the driver's ``run(query, text_chunks)`` returns.
    """
    query, chunks = context.query, context.text_chunks
    return self.rerank_driver.run(query, chunks)
TextChunksResponseRagModule
Bases:
BaseResponseRagModule
Source Code in griptape/engines/rag/modules/response/text_chunks_response_rag_module.py
@define(kw_only=True)
class TextChunksResponseRagModule(BaseResponseRagModule):
    """Response module that returns the context's text chunks wrapped in a ``ListArtifact``.

    No prompt driver is involved; the chunks are passed through as-is.
    """

    def run(self, context: RagContext) -> BaseArtifact:
        """Return ``ListArtifact(context.text_chunks)``."""
        chunks = context.text_chunks
        return ListArtifact(chunks)
run(context)
Source Code in griptape/engines/rag/modules/response/text_chunks_response_rag_module.py
def run(self, context: RagContext) -> BaseArtifact:
    """Wrap ``context.text_chunks`` (the same list object) in a ``ListArtifact`` and return it."""
    chunks = context.text_chunks
    return ListArtifact(chunks)
TextLoaderRetrievalRagModule
Bases:
BaseRetrievalRagModule
Source Code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
@define(kw_only=True)
class TextLoaderRetrievalRagModule(BaseRetrievalRagModule):
    """Retrieval module that loads a source, chunks it, indexes it and queries it.

    On each ``run`` it:
        1. loads the source with ``loader``;
        2. splits the result with ``chunker``;
        3. upserts the chunks into ``vector_store_driver`` under a fresh random namespace;
        4. queries the store with ``context.query``, restricted to that namespace.
    """

    loader: TextLoader = field()
    chunker: TextChunker = field(default=Factory(lambda: TextChunker()))
    vector_store_driver: BaseVectorStoreDriver = field()
    # Default source to load; a "source" context param for this module overrides it when not None.
    source: Any = field()
    # Base query kwargs; merged with this module's "query_params" context param on each run.
    query_params: dict[str, Any] = field(factory=dict)
    # Converts the vector store's query entries into the artifacts returned by `run`.
    process_query_output: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda es: [e.to_artifact() for e in es]),
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        """Load, chunk, upsert and query; return the processed query results."""
        # New random namespace per call; presumably this scopes the query to this call's chunks.
        namespace = uuid.uuid4().hex
        context_source = self.get_context_param(context, "source")
        source = self.source if context_source is None else context_source
        query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))
        # NOTE(review): assumes dict_merge returns a copy, so this does not mutate self.query_params — confirm.
        query_params["namespace"] = namespace
        loader_output = self.loader.load(source)
        chunks = self.chunker.chunk(loader_output)
        # NOTE(review): nothing here removes the upserted chunks, so each run appears to leave a
        # new namespace of data in the vector store — confirm whether cleanup is expected elsewhere.
        self.vector_store_driver.upsert_collection({namespace: chunks})
        return self.process_query_output(self.vector_store_driver.query(context.query, **query_params))
chunker = field(default=Factory(lambda: TextChunker()))
class-attribute instance-attribute
loader = field()
class-attribute instance-attribute
process_query_output = field(default=Factory(lambda: lambda es: [e.to_artifact() for e in es]))
class-attribute instance-attribute
query_params = field(factory=dict)
class-attribute instance-attribute
source = field()
class-attribute instance-attribute
vector_store_driver = field()
class-attribute instance-attribute
run(context)
Source Code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
def run(self, context: RagContext) -> Sequence[TextArtifact]:
    """Load the source, chunk it, upsert it under a fresh namespace, then query it.

    The "source" context param, when not None, overrides ``self.source``. The
    "query_params" context param is merged over ``self.query_params``, and
    ``namespace`` is then forced to the fresh namespace.
    """
    namespace = uuid.uuid4().hex

    source_override = self.get_context_param(context, "source")
    effective_source = source_override if source_override is not None else self.source

    merged_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))
    merged_params["namespace"] = namespace

    loaded = self.loader.load(effective_source)
    chunks = self.chunker.chunk(loaded)
    # NOTE(review): the upserted namespace is never removed here; confirm whether
    # repeated runs are expected to accumulate data in the vector store.
    self.vector_store_driver.upsert_collection({namespace: chunks})

    entries = self.vector_store_driver.query(context.query, **merged_params)
    return self.process_query_output(entries)
TranslateQueryRagModule
Bases:
BaseQueryRagModule
Source Code in griptape/engines/rag/modules/query/translate_query_rag_module.py
@define(kw_only=True)
class TranslateQueryRagModule(BaseQueryRagModule):
    """Query module that rewrites ``context.query`` via ``prompt_driver`` for ``language``.

    The user prompt comes from ``generate_user_template(query, language)``, which
    defaults to rendering the translate user template.
    """

    prompt_driver: BasePromptDriver = field()
    language: str = field()
    generate_user_template: Callable[[str, str], str] = field(
        default=Factory(lambda self: self.default_generate_user_template, takes_self=True),
    )

    def default_generate_user_template(self, query: str, language: str) -> str:
        """Render the translate user template with ``query`` and ``language``."""
        template = J2("engines/rag/modules/query/translate/user.j2")
        return template.render(query=query, language=language)

    def run(self, context: RagContext) -> RagContext:
        """Replace ``context.query`` with the driver's text output and return ``context``."""
        user_prompt = self.generate_user_template(context.query, self.language)
        stack = self.generate_prompt_stack(None, user_prompt)
        context.query = self.prompt_driver.run(stack).to_artifact().to_text()
        return context
generate_user_template = field(default=Factory(lambda self: self.default_generate_user_template, takes_self=True))
class-attribute instance-attribute
language = field()
class-attribute instance-attribute
prompt_driver = field()
class-attribute instance-attribute
default_generate_user_template(query, language)
Source Code in griptape/engines/rag/modules/query/translate_query_rag_module.py
def default_generate_user_template(self, query: str, language: str) -> str:
    """Render the translate-query user prompt.

    Args:
        query: Query text to place in the template.
        language: Target language to place in the template.
    """
    template = J2("engines/rag/modules/query/translate/user.j2")
    return template.render(query=query, language=language)
run(context)
Source Code in griptape/engines/rag/modules/query/translate_query_rag_module.py
def run(self, context: RagContext) -> RagContext:
    """Rewrite ``context.query`` in place using the prompt driver and return ``context``.

    The user prompt comes from ``generate_user_template(context.query, self.language)``
    and is sent without a system prompt. The driver output artifact's text becomes
    the new query.
    """
    prompt_stack = self.generate_prompt_stack(None, self.generate_user_template(context.query, self.language))
    translated = self.prompt_driver.run(prompt_stack).to_artifact().to_text()
    context.query = translated
    return context
VectorStoreRetrievalRagModule
Bases:
BaseRetrievalRagModule
Source Code in griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py
@define(kw_only=True)
class VectorStoreRetrievalRagModule(BaseRetrievalRagModule):
    """Retrieval module that queries ``vector_store_driver`` with ``context.query``.

    Query kwargs are ``query_params`` merged with this module's "query_params"
    context param. Raw entries are converted with ``process_query_output``.
    """

    vector_store_driver: BaseVectorStoreDriver = field(
        default=Factory(lambda: Defaults.drivers_config.vector_store_driver)
    )
    query_params: dict[str, Any] = field(factory=dict)
    process_query_output: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda entries: [entry.to_artifact() for entry in entries]),
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        """Query the vector store and return the processed results."""
        overrides = self.get_context_param(context, "query_params")
        params = utils.dict_merge(self.query_params, overrides)
        entries = self.vector_store_driver.query(context.query, **params)
        return self.process_query_output(entries)
process_query_output = field(default=Factory(lambda: lambda es: [e.to_artifact() for e in es]))
class-attribute instance-attribute
query_params = field(factory=dict)
class-attribute instance-attribute
vector_store_driver = field(default=Factory(lambda: Defaults.drivers_config.vector_store_driver))
class-attribute instance-attribute
run(context)
Source Code in griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py
def run(self, context: RagContext) -> Sequence[TextArtifact]:
    """Query ``vector_store_driver`` with ``context.query`` and return processed results.

    This module's "query_params" context param is merged over ``self.query_params``
    via ``utils.dict_merge`` before querying.
    """
    overrides = self.get_context_param(context, "query_params")
    params = utils.dict_merge(self.query_params, overrides)
    entries = self.vector_store_driver.query(context.query, **params)
    return self.process_query_output(entries)
On this page
- BaseAfterResponseRagModule
- BaseBeforeResponseRagModule
- BaseQueryRagModule
- BaseRagModule
- BaseRerankRagModule
- BaseResponseRagModule
- BaseRetrievalRagModule
- FootnotePromptResponseRagModule
- PromptResponseRagModule
- TextChunksRerankRagModule
- TextChunksResponseRagModule
- TextLoaderRetrievalRagModule
- TranslateQueryRagModule
- VectorStoreRetrievalRagModule
Could this page be better? Report a problem or suggest an addition!