openai_embedding_driver

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
|------|------|-------------|
| model | str | OpenAI embedding model name. Defaults to `text-embedding-3-small`. |
| base_url | Optional[str] | API URL. Defaults to OpenAI's v1 API URL. |
| api_key | Optional[str] | API key to pass directly. Defaults to the `OPENAI_API_KEY` environment variable. |
| organization | Optional[str] | OpenAI organization. Defaults to the `OPENAI_ORGANIZATION` environment variable. |
| tokenizer | OpenAiTokenizer | Optionally provide a custom `OpenAiTokenizer`. |
| client | OpenAI | Optionally provide a custom `openai.OpenAI` client. |
| azure_deployment | — | An Azure OpenAI deployment ID. |
| azure_endpoint | — | An Azure OpenAI endpoint. |
| azure_ad_token | — | An optional Azure Active Directory token. |
| azure_ad_token_provider | — | An optional Azure Active Directory token provider. |
| api_version | — | An Azure OpenAI API version. |
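
A minimal configuration sketch (not part of the generated reference). The import path below follows the module path shown in this page and may differ between griptape versions; it assumes `OPENAI_API_KEY` is set in the environment or that `api_key` is passed explicitly:

```python
from griptape.drivers.embedding.openai_embedding_driver import OpenAiEmbeddingDriver

# Defaults: model="text-embedding-3-small"; api_key falls back to OPENAI_API_KEY.
driver = OpenAiEmbeddingDriver()

# All attributes are keyword-only; override them explicitly as needed.
custom_driver = OpenAiEmbeddingDriver(
    model="text-embedding-3-small",
    api_key="sk-...",   # placeholder; otherwise read from OPENAI_API_KEY
    base_url=None,      # defaults to OpenAI's v1 API URL
    organization=None,  # or taken from OPENAI_ORGANIZATION
)
```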
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
@define
class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
    """OpenAI Embedding Driver.

    Attributes:
        model: OpenAI embedding model name. Defaults to `text-embedding-3-small`.
        base_url: API URL. Defaults to OpenAI's v1 API URL.
        api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
        organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
        tokenizer: Optionally provide custom `OpenAiTokenizer`.
        client: Optionally provide custom `openai.OpenAI` client.
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
    """

    DEFAULT_MODEL = "text-embedding-3-small"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        # Address a performance issue in older ada models
        # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        if self.model.endswith("001"):
            chunk = chunk.replace("\n", " ")
        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

    def _params(self, chunk: str) -> dict:
        return {"input": chunk, "model": self.model}
  • DEFAULT_MODEL = 'text-embedding-3-small' class-attribute instance-attribute

  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • base_url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • organization = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

_params(chunk)

Source Code in griptape/drivers/embedding/openai_embedding_driver.py
def _params(self, chunk: str) -> dict:
    return {"input": chunk, "model": self.model}
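
For illustration only, the request payload this private helper builds with the default model (hypothetical input value):

```python
from griptape.drivers.embedding.openai_embedding_driver import OpenAiEmbeddingDriver

driver = OpenAiEmbeddingDriver()
print(driver._params("Hello, world"))
# {'input': 'Hello, world', 'model': 'text-embedding-3-small'}
```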

client()

Source Code in griptape/drivers/embedding/openai_embedding_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
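
The client is constructed lazily from `api_key`, `base_url`, and `organization` the first time it is accessed. You can also inject a pre-configured client through the `client` alias on the `_client` field; this sketch assumes griptape's `lazy_property` returns the injected client instead of building a new one, which is the usual pattern for these driver fields:

```python
import openai

from griptape.drivers.embedding.openai_embedding_driver import OpenAiEmbeddingDriver

# A client with custom retry behavior (max_retries is a real openai.OpenAI parameter).
client = openai.OpenAI(api_key="sk-...", max_retries=5)

driver = OpenAiEmbeddingDriver(client=client)
```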

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/openai_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    # Address a performance issue in older ada models
    # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
    if self.model.endswith("001"):
        chunk = chunk.replace("\n", " ")
    return self.client.embeddings.create(**self._params(chunk)).data[0].embedding
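
A usage sketch calling `try_embed_chunk` directly; it assumes a valid API key is available via `OPENAI_API_KEY`. The method returns the embedding vector for a single chunk as a `list[float]` (1536 dimensions for `text-embedding-3-small`), and for older `-001` models it first strips newlines as shown in the source above:

```python
from griptape.drivers.embedding.openai_embedding_driver import OpenAiEmbeddingDriver

driver = OpenAiEmbeddingDriver()  # api_key read from OPENAI_API_KEY

vector = driver.try_embed_chunk("Griptape is a framework for building AI agents.")
print(len(vector), vector[:3])
```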
