Skip to content

Commit

Permalink
allow for user passed requests.Session (#390)
Browse files Browse the repository at this point in the history
  • Loading branch information
kristapratico authored Apr 27, 2023
1 parent 96e7642 commit c556584
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
7 changes: 6 additions & 1 deletion openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
import sys
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union, Callable

from contextvars import ContextVar

Expand Down Expand Up @@ -36,6 +36,7 @@
from openai.version import VERSION

if TYPE_CHECKING:
import requests
from aiohttp import ClientSession

api_key = os.environ.get("OPENAI_API_KEY")
Expand All @@ -58,6 +59,10 @@
debug = False
log = None # Set to either 'debug' or 'info', controls console logging

requestssession: Optional[
Union["requests.Session", Callable[[], "requests.Session"]]
] = None # Provide a requests.Session or Session factory.

aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
"aiohttp-session", default=None
) # Acts as a global aiohttp ClientSession that reuses connections.
Expand Down
4 changes: 4 additions & 0 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def _aiohttp_proxies_arg(proxy) -> Optional[str]:


def _make_session() -> requests.Session:
if openai.requestssession:
if isinstance(openai.requestssession, requests.Session):
return openai.requestssession
return openai.requestssession()
if not openai.verify_ssl_certs:
warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
s = requests.Session()
Expand Down
30 changes: 30 additions & 0 deletions openai/tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json

import pytest
import requests

import openai
from openai import error
Expand Down Expand Up @@ -86,3 +87,32 @@ def test_timeout_does_not_error():
model="ada",
request_timeout=10,
)


def test_user_session():
with requests.Session() as session:
openai.requestssession = session

completion = openai.Completion.create(
prompt="hello world",
model="ada",
)
assert completion


def test_user_session_factory():
def factory():
session = requests.Session()
session.mount(
"https://",
requests.adapters.HTTPAdapter(max_retries=4),
)
return session

openai.requestssession = factory

completion = openai.Completion.create(
prompt="hello world",
model="ada",
)
assert completion

1 comment on commit c556584

@tchoua08
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice

Please sign in to comment.