openai_embedding_driver
Bases:
BaseEmbeddingDriver
Attributes
| Name | Type | Description |
|---|---|---|
| model | str | OpenAI embedding model name. Defaults to `text-embedding-3-small`. |
| base_url | Optional[str] | API URL. Defaults to OpenAI's v1 API URL. |
| api_key | Optional[str] | API key to pass directly. Defaults to the `OPENAI_API_KEY` environment variable. |
| organization | Optional[str] | OpenAI organization. Defaults to the `OPENAI_ORGANIZATION` environment variable. |
| tokenizer | OpenAiTokenizer | Optionally provide a custom `OpenAiTokenizer`. |
| client | OpenAI | Optionally provide a custom `openai.OpenAI` client. |
| azure_deployment | str | An Azure OpenAI deployment ID. |
| azure_endpoint | str | An Azure OpenAI endpoint. |
| azure_ad_token | Optional[str] | An optional Azure Active Directory token. |
| azure_ad_token_provider | Optional[Callable[[], str]] | An optional Azure Active Directory token provider. |
| api_version | str | An Azure OpenAI API version. |
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
```python
@define
class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
    """OpenAI Embedding Driver.

    Attributes:
        model: OpenAI embedding model name. Defaults to `text-embedding-3-small`.
        base_url: API URL. Defaults to OpenAI's v1 API URL.
        api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
        organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
        tokenizer: Optionally provide custom `OpenAiTokenizer`.
        client: Optionally provide custom `openai.OpenAI` client.
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
    """

    DEFAULT_MODEL = "text-embedding-3-small"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        # Address a performance issue in older ada models
        # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        if self.model.endswith("001"):
            chunk = chunk.replace("\n", " ")

        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

    def _params(self, chunk: str) -> dict:
        return {"input": chunk, "model": self.model}
```
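For orientation, here is a minimal usage sketch based on the class above. The `try_embed_chunk` call and attribute names come from the source; the top-level import path, model name, and API key handling are assumptions for illustration and may need adjusting to your Griptape version:

```python
import os

from griptape.drivers import OpenAiEmbeddingDriver  # assumed re-export; module path is shown in the source header above

# If api_key is omitted, the underlying openai.OpenAI client falls back to the
# OPENAI_API_KEY environment variable, matching the attribute defaults above.
driver = OpenAiEmbeddingDriver(
    model="text-embedding-3-small",             # same as DEFAULT_MODEL
    api_key=os.environ.get("OPENAI_API_KEY"),   # optional
)

vector = driver.try_embed_chunk("Griptape is a framework for building AI agents.")
print(len(vector))  # e.g. 1536 dimensions for text-embedding-3-small
```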
```python
DEFAULT_MODEL = 'text-embedding-3-small'

_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True})
organization = field(default=None, kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True)
```
_params(chunk)
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
```python
def _params(self, chunk: str) -> dict:
    return {"input": chunk, "model": self.model}
```
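Because `try_embed_chunk` splats this dict directly into `client.embeddings.create(...)`, overriding `_params` in a subclass is one way to pass extra request options. This is a hedged sketch rather than a documented extension point; the `dimensions` argument is an OpenAI embeddings API option assumed here for illustration:

```python
from attrs import define

from griptape.drivers import OpenAiEmbeddingDriver  # assumed import path, as above


@define
class ShortVectorOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """Illustrative subclass that requests truncated embeddings via the `dimensions` option."""

    def _params(self, chunk: str) -> dict:
        # Reuse the base payload ({"input": ..., "model": ...}) and add an extra option.
        return {**super()._params(chunk), "dimensions": 256}
```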
client()
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
```python
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
```
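The client is built lazily from `api_key`, `base_url`, and `organization` on first access, and the `_client` field's `client` alias lets you supply a pre-configured `openai.OpenAI` instance instead. A hedged sketch of pointing the driver at an OpenAI-compatible endpoint (the localhost URL and dummy key are assumptions):

```python
import openai

from griptape.drivers import OpenAiEmbeddingDriver  # assumed import path, as above

driver = OpenAiEmbeddingDriver(
    model="text-embedding-3-small",
    # Any OpenAI-compatible base_url works the same way; this one is illustrative.
    client=openai.OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed"),
)
```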
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
```python
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    # Address a performance issue in older ada models
    # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
    if self.model.endswith("001"):
        chunk = chunk.replace("\n", " ")

    return self.client.embeddings.create(**self._params(chunk)).data[0].embedding
```
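Note the guard for model names ending in `001`: newlines are replaced with spaces before the request, which works around a performance issue with the legacy ada embedding models (see the linked issue). Reusing the `driver` constructed in the earlier sketch, a call looks like this (model names mentioned in the comment are illustrative):

```python
text = "Griptape drivers\nwrap provider SDKs."

# With a legacy model whose name ends in "001" (e.g. "text-similarity-ada-001"),
# the newline would be replaced with a space before the request is sent;
# with text-embedding-3-small the chunk is passed through unchanged.
embedding = driver.try_embed_chunk(text)
assert isinstance(embedding, list) and isinstance(embedding[0], float)
```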