# Public API of this module: the names exported by a star-import.
__all__ = ['RagContext', 'RagEngine']
module-attribute
RagContext
Bases: SerializableMixin
Attributes
Name | Type | Description |
---|---|---|
query | str | Query provided by the user. |
module_configs | dict[str, dict] | Dictionary of module configs. Each key should be a module name and each value a dictionary of config parameters for that module. |
before_query | list[str] | An optional list of strings to add before the query in response modules. |
after_query | list[str] | An optional list of strings to add after the query in response modules. |
text_chunks | list[TextArtifact] | A list of text chunks to pass around from the retrieval stage to the response stage. |
outputs | list[BaseArtifact] | List of outputs from the response stage. |
Source Code in griptape/engines/rag/rag_context.py
@define(kw_only=True)
class RagContext(SerializableMixin):
    """Shared state passed between RagEngine stages and their modules.

    Individual modules are expected to update this context in their `run` method.

    Attributes:
        query: The query supplied by the user.
        module_configs: Module configs keyed by module name; each value is a
            dictionary of config parameters for that module.
        before_query: Optional strings that response modules add before the query.
        after_query: Optional strings that response modules add after the query.
        text_chunks: Text chunks carried from the retrieval stage to the response stage.
        outputs: Outputs produced by the response stage.
    """

    query: str = field(metadata={"serializable": True})
    module_configs: dict[str, dict] = field(factory=dict, metadata={"serializable": True})
    before_query: list[str] = field(factory=list, metadata={"serializable": True})
    after_query: list[str] = field(factory=list, metadata={"serializable": True})
    text_chunks: list[TextArtifact] = field(factory=list, metadata={"serializable": True})
    outputs: list[BaseArtifact] = field(factory=list, metadata={"serializable": True})

    def get_references(self) -> list[Reference]:
        """Build the reference list from `text_chunks` via `utils.references_from_artifacts`."""
        chunks = self.text_chunks
        return utils.references_from_artifacts(chunks)
after_query = field(factory=list, metadata={'serializable': True})
class-attribute instance-attribute
before_query = field(factory=list, metadata={'serializable': True})
class-attribute instance-attribute
module_configs = field(factory=dict, metadata={'serializable': True})
class-attribute instance-attribute
outputs = field(factory=list, metadata={'serializable': True})
class-attribute instance-attribute
query = field(metadata={'serializable': True})
class-attribute instance-attribute
text_chunks = field(factory=list, metadata={'serializable': True})
class-attribute instance-attribute
get_references()
Source Code in griptape/engines/rag/rag_context.py
def get_references(self) -> list[Reference]:
    """Build the reference list from `text_chunks` via `utils.references_from_artifacts`."""
    chunks = self.text_chunks
    return utils.references_from_artifacts(chunks)
RagEngine
Source Code in griptape/engines/rag/rag_engine.py
@define(kw_only=True)
class RagEngine:
    """Runs a RagContext through up to three optional stages: query, retrieval, response.

    Attributes:
        query_stage: Optional stage that `process` runs first.
        retrieval_stage: Optional stage that `process` runs second.
        response_stage: Optional stage that `process` runs last.
    """

    query_stage: Optional[QueryRagStage] = field(default=None)
    retrieval_stage: Optional[RetrievalRagStage] = field(default=None)
    response_stage: Optional[ResponseRagStage] = field(default=None)

    def __attrs_post_init__(self) -> None:
        """Check that module names are unique across all configured stages.

        Raises:
            ValueError: If any two modules, in the same or in different stages,
                share a name.
        """
        module_names = []

        for stage in (self.query_stage, self.retrieval_stage, self.response_stage):
            if stage is None:
                continue
            module_names.extend(module.name for module in stage.modules)

        # A set drops duplicates, so a shorter set means at least one repeated name.
        if len(module_names) > len(set(module_names)):
            raise ValueError("module names have to be unique")

    def process_query(self, query: str) -> RagContext:
        """Wrap `query` in a new RagContext and run it through `process`.

        Args:
            query: Query text used to build the context.

        Returns:
            The context returned by `process`.
        """
        return self.process(RagContext(query=query))

    def process(self, context: RagContext) -> RagContext:
        """Run `context` through each configured stage in order: query, retrieval, response.

        A stage is skipped only when it is `None`. This is the same test
        `__attrs_post_init__` uses, so every stage that was validated is also run.

        Args:
            context: Context handed to the first configured stage.

        Returns:
            The context returned by the last stage that ran, or `context` itself
            when no stage is configured.
        """
        if self.query_stage is not None:
            context = self.query_stage.run(context)

        if self.retrieval_stage is not None:
            context = self.retrieval_stage.run(context)

        if self.response_stage is not None:
            context = self.response_stage.run(context)

        return context
query_stage = field(default=None)
class-attribute instance-attribute
response_stage = field(default=None)
class-attribute instance-attribute
retrieval_stage = field(default=None)
class-attribute instance-attribute
__attrs_post_init__()
Source Code in griptape/engines/rag/rag_engine.py
def __attrs_post_init__(self) -> None:
    """Check that module names are unique across all configured stages.

    Raises:
        ValueError: If any two modules, in the same or in different stages,
            share a name.
    """
    names = []

    for stage in (self.query_stage, self.retrieval_stage, self.response_stage):
        if stage is None:
            continue
        names.extend(module.name for module in stage.modules)

    # A set drops duplicates, so a shorter set means at least one repeated name.
    if len(set(names)) < len(names):
        raise ValueError("module names have to be unique")
process(context)
Source Code in griptape/engines/rag/rag_engine.py
def process(self, context: RagContext) -> RagContext:
    """Run `context` through each configured stage in order: query, retrieval, response.

    A stage is skipped only when it is `None`. An explicit `is not None` test is
    used rather than truthiness, so a configured stage object that happens to
    evaluate as falsy is still run.

    Args:
        context: Context handed to the first configured stage.

    Returns:
        The context returned by the last stage that ran, or `context` itself
        when no stage is configured.
    """
    if self.query_stage is not None:
        context = self.query_stage.run(context)

    if self.retrieval_stage is not None:
        context = self.retrieval_stage.run(context)

    if self.response_stage is not None:
        context = self.response_stage.run(context)

    return context
process_query(query)
Source Code in griptape/engines/rag/rag_engine.py
def process_query(self, query: str) -> RagContext:
    """Wrap `query` in a new RagContext and run it through `process`.

    Args:
        query: Query text used to build the context.

    Returns:
        The context returned by `process`.
    """
    context = RagContext(query=query)
    return self.process(context)
- On this page
- RagEngine
Could this page be better? Report a problem or suggest an addition!