Skip to content

Authentication

Auth

Bases: ABC

Base class for Salesforce authentication.

Source code in src/aiosalesforce/auth/base.py
class Auth(ABC):
    """
    Base class for Salesforce authentication.

    """

    def __init__(self) -> None:
        self.__access_token: str | None = None
        self.__lock = asyncio.Lock()

    @final
    async def get_access_token(self, client: "Salesforce") -> str:
        """
        Get access token.

        If this is the first time this method is called, it will acquire a new
        access token from Salesforce.

        Parameters
        ----------
        client : Salesforce
            Salesforce client.

        Returns
        -------
        str
            Access token

        """
        async with self.__lock:
            if self.__access_token is None:
                logger.debug(
                    "Acquiring new access token using %s for %s",
                    self.__class__.__name__,
                    client.base_url,
                )
                self.__access_token = await self._acquire_new_access_token(client)
            elif self.expired:
                logger.debug(
                    "Token expired, refreshing access token using %s for %s",
                    self.__class__.__name__,
                    client.base_url,
                )
                self.__access_token = await self._refresh_access_token(client)
            return self.__access_token

    @final
    async def refresh_access_token(self, client: "Salesforce") -> str:
        """
        Refresh the access token.

        Parameters
        ----------
        client : Salesforce
            Salesforce client.

        Returns
        -------
        str
            Access token

        """
        if self.__access_token is None:
            raise RuntimeError("No access token to refresh")
        token_before_refresh = self.__access_token
        async with self.__lock:
            if self.__access_token == token_before_refresh:
                logger.debug(
                    "Refreshing access token using %s for %s",
                    self.__class__.__name__,
                    client.base_url,
                )
                self.__access_token = await self._refresh_access_token(client)
            return self.__access_token

    @abstractmethod
    async def _acquire_new_access_token(self, client: "Salesforce") -> str:
        """
        Acquire a new access token from Salesforce.

        Implementation is responsible for emitting RequestEvent and ResponseEvent.

        Parameters
        ----------
        client : Salesforce
            Salesforce client.

        Returns
        -------
        str
            Access token

        """

    async def _refresh_access_token(self, client: "Salesforce") -> str:
        """
        Refresh the access token.

        Implementation is responsible for emitting RequestEvent and ResponseEvent.

        Parameters
        ----------
        client : Salesforce
            Salesforce client.

        Returns
        -------
        str
            Access token

        """
        return await self._acquire_new_access_token(client)

    @property
    def expired(self) -> bool:
        """True if the access token is expired."""
        if self.__access_token is None:  # pragma: no cover
            raise RuntimeError("Cannot check expiration of a non-existent access token")
        # By default, assumes the access token never expires
        # Salesforce client automatically refreshes the token after 401 response
        return False

`expired` (`bool`, property)

True if the access token is expired.

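`get_access_token(client)` (async)

Get access token.

If this is the first time this method is called, it will acquire a new access token from Salesforce. On subsequent calls the cached token is returned, or refreshed first if it is reported as expired.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `client` | `Salesforce` | Salesforce client. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `str` | Access token |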

Source code in src/aiosalesforce/auth/base.py
@final
async def get_access_token(self, client: "Salesforce") -> str:
    """
    Return a usable access token, obtaining one when necessary.

    The first call acquires a token from Salesforce. Subsequent calls
    return the cached token, replacing it via ``_refresh_access_token``
    first whenever ``expired`` reports it as expired.

    Parameters
    ----------
    client : Salesforce
        Salesforce client.

    Returns
    -------
    str
        Access token

    """
    async with self.__lock:
        if self.__access_token is None:
            message = "Acquiring new access token using %s for %s"
            obtain_token = self._acquire_new_access_token
        elif self.expired:
            message = "Token expired, refreshing access token using %s for %s"
            obtain_token = self._refresh_access_token
        else:
            # Cached token is still valid - nothing to fetch.
            return self.__access_token
        logger.debug(message, type(self).__name__, client.base_url)
        self.__access_token = await obtain_token(client)
        return self.__access_token

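`refresh_access_token(client)` (async)

Refresh the access token.

Raises `RuntimeError` if no access token has been acquired yet. If another caller already refreshed the token while this call was waiting, the already-refreshed token is returned instead of refreshing again.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `client` | `Salesforce` | Salesforce client. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `str` | Access token |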

Source code in src/aiosalesforce/auth/base.py
@final
async def refresh_access_token(self, client: "Salesforce") -> str:
    """
    Refresh the access token.

    Parameters
    ----------
    client : Salesforce
        Salesforce client.

    Returns
    -------
    str
        Access token

    """
    if self.__access_token is None:
        raise RuntimeError("No access token to refresh")
    token_before_refresh = self.__access_token
    async with self.__lock:
        if self.__access_token == token_before_refresh:
            logger.debug(
                "Refreshing access token using %s for %s",
                self.__class__.__name__,
                client.base_url,
            )
            self.__access_token = await self._refresh_access_token(client)
        return self.__access_token

SoapLogin

Bases: Auth

Authenticate using the SOAP login method.

https://developer.salesforce.com/docs/atlas.en-us.api.meta/api/sforce_api_calls_login.htm

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `username` | `str` | Username. | *required* |
| `password` | `str` | Password. | *required* |
| `security_token` | `str` | Security token. | *required* |

Source code in src/aiosalesforce/auth/soap.py
class SoapLogin(Auth):
    """
    Authenticate using the SOAP login method.

    https://developer.salesforce.com/docs/atlas.en-us.api.meta/api/sforce_api_calls_login.htm

    Parameters
    ----------
    username : str
        Username.
    password : str
        Password.
    security_token : str
        Security token.

    Notes
    -----
    Credentials are XML-escaped before being placed in the SOAP envelope, so
    values containing ``&``, ``<`` or ``>`` are transmitted verbatim instead of
    producing a malformed request.

    """

    def __init__(
        self,
        username: str,
        password: str,
        security_token: str,
    ):
        super().__init__()
        self.username = username
        self.password = password
        self.security_token = security_token

        # Wall-clock deadline (``time.time()`` based) for the current session;
        # set from ``sessionSecondsValid`` when the login response provides it.
        self._expiration_time: float | None = None

    def _build_login_payload(self) -> str:
        """
        Build the SOAP ``login`` request body.

        Returns
        -------
        str
            XML envelope carrying the escaped username and the escaped
            password concatenated with the security token.

        """
        from xml.sax.saxutils import escape

        # Credentials are arbitrary user text embedded as XML element content.
        # Unescaped ``&``/``<`` would make the envelope malformed (login fails)
        # and would let crafted values inject additional XML elements.
        username = escape(self.username)
        password = escape(f"{self.password}{self.security_token}")
        soap_xml_payload = f"""
        <?xml version="1.0" encoding="utf-8" ?>
        <env:Envelope
            xmlns:xsd="http://www.w3.org/2001/XMLSchema"
            xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
            xmlns:env="http://schemas.xmlsoap.org/soap/envelope/">
            <env:Body>
                <n1:login xmlns:n1="urn:partner.soap.sforce.com">
                    <n1:username>{username}</n1:username>
                    <n1:password>{password}</n1:password>
                </n1:login>
            </env:Body>
        </env:Envelope>
        """
        # The XML declaration must be the very first thing in the document.
        return textwrap.dedent(soap_xml_payload).strip()

    async def _acquire_new_access_token(self, client: "Salesforce") -> str:
        """
        Log in via the SOAP API and return the session ID as the access token.

        Publishes a RequestEvent before sending and a ResponseEvent after a
        successful response. Also records the session expiration time when
        the response contains ``sessionSecondsValid``.

        Parameters
        ----------
        client : Salesforce
            Salesforce client.

        Returns
        -------
        str
            Session ID used as the access token.

        Raises
        ------
        AuthenticationError
            If the login request fails or no ``sessionId`` can be parsed.

        """
        request = client.httpx_client.build_request(
            "POST",
            f"{client.base_url}/services/Soap/u/{client.version}",
            content=self._build_login_payload(),
            headers={
                "Content-Type": "text/xml; charset=UTF-8",
                "SOAPAction": "login",
                "Accept": "text/xml",
            },
        )
        await client.event_bus.publish_event(
            RequestEvent(
                type="request",
                request=request,
            )
        )
        retry_context = client.retry_policy.create_context()
        response = await retry_context.send_request_with_retries(
            httpx_client=client.httpx_client,
            event_bus=client.event_bus,
            semaphore=client._semaphore,
            request=request,
        )
        response_text = response.text
        if not response.is_success:
            # Extract error details from the SOAP fault when present; otherwise
            # fall back to reporting the raw response body.
            code_match = re.search(
                r"<sf:exceptionCode>(.+)<\/sf:exceptionCode>",
                response_text,
            )
            exception_code = code_match.group(1) if code_match else None
            message_match = re.search(
                r"<sf:exceptionMessage>(.+)<\/sf:exceptionMessage>",
                response_text,
            )
            exception_message = (
                message_match.group(1) if message_match else response_text
            )
            raise AuthenticationError(
                message=(
                    f"[{exception_code}] {exception_message}"
                    if exception_code
                    else exception_message
                ),
                response=response,
                error_code=exception_code,
                error_message=exception_message,
            )
        match_ = re.search(r"<sessionId>(.+)<\/sessionId>", response_text)
        if match_ is None:  # pragma: no cover
            raise AuthenticationError(
                f"Failed to parse sessionId from the SOAP response: {response_text}",
                response,
            )
        session_id = match_.group(1)

        # Parse expiration time; leave it unset if the value is missing/invalid.
        match_ = re.search(
            r"<sessionSecondsValid>(.+)<\/sessionSecondsValid>",
            response_text,
        )
        self._expiration_time = None
        if match_ is not None:
            try:
                self._expiration_time = time.time() + int(match_.group(1))
            except ValueError:  # pragma: no cover
                pass

        await client.event_bus.publish_event(
            ResponseEvent(
                type="response",
                response=response,
            )
        )
        return session_id

    @property
    def expired(self) -> bool:
        """True once the session's ``sessionSecondsValid`` window has elapsed."""
        # The base property raises RuntimeError when no token exists yet.
        super().expired
        if self._expiration_time is None:  # pragma: no cover
            return False
        return self._expiration_time <= time.time()

ClientCredentialsFlow

Bases: Auth

Authenticate using the OAuth 2.0 Client Credentials Flow.

https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_client_credentials_flow.htm&type=5

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `client_id` | `str` | Client ID. | *required* |
| `client_secret` | `str` | Client secret. | *required* |
| `timeout` | `float` | Timeout for the access token in seconds. By default assumed to never expire. | `None` |
Source code in src/aiosalesforce/auth/client_credentials_flow.py
class ClientCredentialsFlow(Auth):
    """
    Authenticate using the OAuth 2.0 Client Credentials Flow.

    https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_client_credentials_flow.htm&type=5

    Parameters
    ----------
    client_id : str
        Client ID.
    client_secret : str
        Client secret.
    timeout : float, optional
        Timeout for the access token in seconds.
        By default assumed to never expire.

    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        timeout: float | None = None,
    ) -> None:
        super().__init__()
        self.client_id = client_id
        self.client_secret = client_secret
        self.timeout = timeout

        # Wall-clock deadline (``time.time()`` based) for the current token;
        # only ever set when ``timeout`` is provided.
        self._expiration_time: float | None = None

    async def _acquire_new_access_token(self, client: "Salesforce") -> str:
        """
        Request a token from the OAuth token endpoint.

        Publishes a RequestEvent before sending and a ResponseEvent after a
        successful response. Raises AuthenticationError on a non-success
        response.
        """
        # Form-encoded body posted to ``/services/oauth2/token``.
        form_fields = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }
        token_request = client.httpx_client.build_request(
            "POST",
            f"{client.base_url}/services/oauth2/token",
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
            data=form_fields,
        )
        await client.event_bus.publish_event(
            RequestEvent(type="request", request=token_request)
        )
        retries = client.retry_policy.create_context()
        response = await retries.send_request_with_retries(
            httpx_client=client.httpx_client,
            event_bus=client.event_bus,
            semaphore=client._semaphore,
            request=token_request,
        )
        if not response.is_success:
            # Prefer the OAuth ``error``/``error_description`` pair; fall back
            # to the raw body when the payload does not have that shape.
            try:
                body = json_loads(response.content)
                code, description = body["error"], body["error_description"]
            except Exception:  # pragma: no cover
                code, description = None, response.text
            raise AuthenticationError(
                f"[{code}] {description}" if code else description,
                response=response,
                error_code=code,
                error_message=description,
            )
        await client.event_bus.publish_event(
            ResponseEvent(type="response", response=response)
        )
        if self.timeout is not None:
            self._expiration_time = time.time() + self.timeout
        return json_loads(response.content)["access_token"]

    @property
    def expired(self) -> bool:
        """True once ``timeout`` seconds have elapsed since the token response."""
        # The base property raises RuntimeError when no token exists yet.
        super().expired
        deadline = self._expiration_time
        return deadline is not None and deadline <= time.time()