Skip to content

Commit

Permalink
Merge pull request #42 from ai-forever/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
Rai220 authored Nov 15, 2024
2 parents 09f77c5 + 8dabc3d commit b6b3952
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 33 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,4 @@ cython_debug/
.ruff_cache/

.vscode/
.aider*
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Python-библиотека для работы с GigaChat API

[![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/ai-forever/gigachat/gigachat.yml?style=flat-square)](https://github.com/ai-forever/gigachat/actions/workflows/gigachat.yml)
[![GitHub License](https://img.shields.io/github/license/ai-forever/gigachat?style=flat-square)](https://opensource.org/license/MIT)
[![GitHub Downloads (all assets, all releases)](https://img.shields.io/pypi/dm/gigachat?style=flat-square?style=flat-square)](https://pypistats.org/packages/gigachat)
[![GitHub Repo stars](https://img.shields.io/github/stars/ai-forever/gigachat?style=flat-square)](https://star-history.com/#ai-forever/gigachat)
[![GitHub Open Issues](https://img.shields.io/github/issues-raw/ai-forever/gigachat)](https://github.com/ai-forever/gigachat/issues)

Библиотека Python, позволяющая [GigaChain](https://github.com/ai-forever/gigachain) обращаться к GigaChat — нейросетевой модели, которая умеет вести диалог, писать код, создавать тексты и картинки по запросу.

Обмен данными с сервисом обеспечивается с помощью GigaChat API. О том как получить доступ к API читайте в [официальной документации](https://developers.sber.ru/docs/ru/gigachat/api/integration).
Expand Down Expand Up @@ -70,6 +76,17 @@ giga = GigaChat(
)
```

Предварительная авторизация (в случае, если необходимо получить временный токен и авторизоваться до отправки запросов; по умолчанию, библиотека автоматически получает временный токен при первом запросе к API):

```py
giga = GigaChat(
base_url="https://gigachat.devices.sberbank.ru/api/v1",
user=...,
password=...,
)
giga.get_token()
```

Взаимная аутентификация по протоколу TLS (mTLS):

```py
Expand Down
40 changes: 20 additions & 20 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gigachat"
version = "0.1.34"
version = "0.1.35"
description = "GigaChat. Python-library for GigaChain and LangChain"
authors = ["Konstantin Krestnikov <rai220@gmail.com>", "Sergey Malyshev <in1t@ya.ru>"]
license = "MIT"
Expand All @@ -15,7 +15,6 @@ httpx = "<1"

[tool.poetry.group.dev.dependencies]
black = "^23.12.1"
ruff = "^0.0.285"
mypy = "^1.8.0"
pytest = "^7.4.3"
pytest-httpx = [
Expand All @@ -25,6 +24,7 @@ pytest-httpx = [
pytest-asyncio = "^0.21.1"
coverage = "<=7.3.0"
pytest-mock = "^3.12.0"
ruff = "^0.0.291"

[build-system]
requires = ["poetry-core"]
Expand Down
21 changes: 15 additions & 6 deletions src/gigachat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Optional,
TypeVar,
Union,
cast,
)

import httpx
Expand Down Expand Up @@ -215,7 +216,7 @@ def _update_token(self) -> None:
credentials=self._settings.credentials,
scope=self._settings.scope,
)
_logger.info("OAUTH UPDATE TOKEN")
_logger.debug("OAUTH UPDATE TOKEN")
elif self._settings.user and self._settings.password:
self._access_token = _build_access_token(
post_token.sync(
Expand All @@ -226,13 +227,17 @@ def _update_token(self) -> None:
)
_logger.info("UPDATE TOKEN")

def get_token(self) -> AccessToken:
self._update_token()
return cast(AccessToken, self._access_token)

def _decorator(self, call: Callable[..., T]) -> T:
if self._use_auth:
if self._check_validity_token():
try:
return call()
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self._reset_token()
self._update_token()
return call()
Expand Down Expand Up @@ -289,7 +294,7 @@ def stream(self, payload: Union[Chat, Dict[str, Any], str]) -> Iterator[ChatComp
yield chunk
return
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self._reset_token()
self._update_token()

Expand Down Expand Up @@ -333,7 +338,7 @@ async def _aupdate_token(self) -> None:
credentials=self._settings.credentials,
scope=self._settings.scope,
)
_logger.info("OAUTH UPDATE TOKEN")
_logger.debug("OAUTH UPDATE TOKEN")
elif self._settings.user and self._settings.password:
self._access_token = _build_access_token(
await post_token.asyncio(
Expand All @@ -344,13 +349,17 @@ async def _aupdate_token(self) -> None:
)
_logger.info("UPDATE TOKEN")

async def aget_token(self) -> AccessToken:
await self._aupdate_token()
return cast(AccessToken, self._access_token)

async def _adecorator(self, acall: Callable[..., Awaitable[T]]) -> T:
if self._use_auth:
if self._check_validity_token():
try:
return await acall()
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self._reset_token()
await self._aupdate_token()
return await acall()
Expand Down Expand Up @@ -429,7 +438,7 @@ async def astream(self, payload: Union[Chat, Dict[str, Any], str]) -> AsyncItera
yield chunk
return
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self._reset_token()
await self._aupdate_token()

Expand Down
8 changes: 4 additions & 4 deletions src/gigachat/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def run_messages_stream(
yield chunk
return
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self.base_client._reset_token()
self.base_client._update_token()

Expand Down Expand Up @@ -271,7 +271,7 @@ def rerun_messages_stream(
yield chunk
return
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self.base_client._reset_token()
self.base_client._update_token()

Expand Down Expand Up @@ -479,7 +479,7 @@ async def run_messages_stream(
yield chunk
return
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self.base_client._reset_token()
await self.base_client._aupdate_token()

Expand Down Expand Up @@ -518,7 +518,7 @@ async def rerun_messages_stream(
yield chunk
return
except AuthenticationError:
_logger.warning("AUTHENTICATION ERROR")
_logger.debug("AUTHENTICATION ERROR")
self.base_client._reset_token()
await self.base_client._aupdate_token()

Expand Down
2 changes: 1 addition & 1 deletion tests/data/tokens_count.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
"tokens": 7,
"characters": 36
}
]
]
30 changes: 30 additions & 0 deletions tests/unit_tests/gigachat/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pytest_httpx import HTTPXMock
from pytest_mock import MockerFixture

from gigachat import GigaChat
from gigachat.client import (
GIGACHAT_MODEL,
GigaChatAsyncClient,
Expand Down Expand Up @@ -364,6 +365,20 @@ def test_stream_update_token_error(httpx_mock: HTTPXMock) -> None:
assert client.token != access_token


def test_get_token_credentials(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN)

model = GigaChat(
base_url=BASE_URL,
auth_url=AUTH_URL,
credentials=CREDENTIALS,
)
access_token = model.get_token()

assert model._access_token == ACCESS_TOKEN
assert access_token == ACCESS_TOKEN


@pytest.mark.asyncio()
async def test_aget_models(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODELS_URL, json=MODELS)
Expand Down Expand Up @@ -579,3 +594,18 @@ async def test_aupload_file(httpx_mock: HTTPXMock) -> None:
response = await client.aupload_file(file=FILE)

assert isinstance(response, UploadedFile)


@pytest.mark.asyncio()
async def test_aget_token_credentials(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN)

model = GigaChat(
base_url=BASE_URL,
auth_url=AUTH_URL,
credentials=CREDENTIALS,
)
access_token = await model.aget_token()

assert model._access_token == ACCESS_TOKEN
assert access_token == ACCESS_TOKEN

0 comments on commit b6b3952

Please sign in to comment.