Skip to content

Commit

Permalink
fix async client
Browse files Browse the repository at this point in the history
  • Loading branch information
pdevine committed Jan 13, 2025
1 parent a0388b2 commit f5c8ee0
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 161 deletions.
89 changes: 36 additions & 53 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,24 +546,6 @@ def create(
stream=stream,
)

def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base

out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() not in ['FROM', 'ADAPTER']:
print(line, end='', file=out)
continue

path = Path(args.strip()).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{self.create_blob(path)}\n'
print(command, args, end='', file=out)

return out.getvalue()

def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
Expand Down Expand Up @@ -996,76 +978,77 @@ async def push(
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
stream: Literal[False] = False,
from_: Optional[str] = None,
files: Optional[dict[str, str]] = None,
adapters: Optional[dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, list[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: Literal[True] = True,
) -> ProgressResponse: ...

@overload
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[dict[str, str]] = None,
adapters: Optional[dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, list[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: Literal[True] = True,
) -> AsyncIterator[ProgressResponse]: ...

async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
*,
quantize: Optional[str] = None,
from_: Optional[str] = None,
files: Optional[dict[str, str]] = None,
adapters: Optional[dict[str, str]] = None,
template: Optional[str] = None,
license: Optional[Union[str, list[str]]] = None,
system: Optional[str] = None,
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: bool = False,
) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]:
"""
Raises `ResponseError` if the request could not be fulfilled.
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
if (realpath := _as_path(path)) and realpath.exists():
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = await self._parse_modelfile(modelfile)
else:
raise RequestError('must provide either path or modelfile')

return await self._request(
ProgressResponse,
'POST',
'/api/create',
json=CreateRequest(
model=model,
modelfile=modelfile,
stream=stream,
quantize=quantize,
from_=from_,
files=files,
adapters=adapters,
license=license,
template=template,
system=system,
parameters=parameters,
messages=messages,
).model_dump(exclude_none=True),
stream=stream,
)

async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base

out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() not in ['FROM', 'ADAPTER']:
print(line, end='', file=out)
continue

path = Path(args.strip()).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{await self.create_blob(path)}\n'
print(command, args, end='', file=out)

return out.getvalue()

async def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
Expand Down
131 changes: 23 additions & 108 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,117 +933,39 @@ def generate():


@pytest.mark.asyncio
async def test_async_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile() as blob:
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
modelfile.flush()

response = await client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'


@pytest.mark.asyncio
async def test_async_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
async def test_async_client_create_with_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()

response = await client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'


@pytest.mark.asyncio
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({'status': 'success'})

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()

response = await client.create('dummy', path=modelfile.name)
assert response['status'] == 'success'


@pytest.mark.asyncio
async def test_async_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'stream': False,
},
).respond_with_json({'status': 'success'})

client = AsyncClient(httpserver.url_for('/'))

with tempfile.NamedTemporaryFile() as blob:

Check failure on line 949 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.9)

Ruff (F841)

tests/test_client.py:949:41: F841 Local variable `blob` is assigned to but never used

Check failure on line 949 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.10)

Ruff (F841)

tests/test_client.py:949:41: F841 Local variable `blob` is assigned to but never used

Check failure on line 949 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.11)

Ruff (F841)

tests/test_client.py:949:41: F841 Local variable `blob` is assigned to but never used

Check failure on line 949 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.12)

Ruff (F841)

tests/test_client.py:949:41: F841 Local variable `blob` is assigned to but never used
response = await client.create('dummy', modelfile=f'FROM {blob.name}')
response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success'


@pytest.mark.asyncio
async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>
{{.Prompt}} [/INST]"""
SYSTEM """
Use
multiline
strings.
"""
PARAMETER stop [INST]
PARAMETER stop [/INST]
PARAMETER stop <<SYS>>
PARAMETER stop <</SYS>>''',
'quantize': 'q4_k_m',
'from': 'mymodel',
'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'template': '[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
'license': 'this is my license',
'system': '\nUse\nmultiline\nstrings.\n',
'parameters': {'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
'stream': False,
},
).respond_with_json({'status': 'success'})
Expand All @@ -1053,22 +975,15 @@ async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
with tempfile.NamedTemporaryFile() as blob:

Check failure on line 975 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.9)

Ruff (F841)

tests/test_client.py:975:41: F841 Local variable `blob` is assigned to but never used

Check failure on line 975 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.10)

Ruff (F841)

tests/test_client.py:975:41: F841 Local variable `blob` is assigned to but never used

Check failure on line 975 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.11)

Ruff (F841)

tests/test_client.py:975:41: F841 Local variable `blob` is assigned to but never used

Check failure on line 975 in tests/test_client.py

View workflow job for this annotation

GitHub Actions / test (3.12)

Ruff (F841)

tests/test_client.py:975:41: F841 Local variable `blob` is assigned to but never used
response = await client.create(
'dummy',
modelfile='\n'.join(
[
f'FROM {blob.name}',
'TEMPLATE """[INST] <<SYS>>{{.System}}<</SYS>>',
'{{.Prompt}} [/INST]"""',
'SYSTEM """',
'Use',
'multiline',
'strings.',
'"""',
'PARAMETER stop [INST]',
'PARAMETER stop [/INST]',
'PARAMETER stop <<SYS>>',
'PARAMETER stop <</SYS>>',
]
),
quantize='q4_k_m',
from_='mymodel',
adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
template='[INST] <<SYS>>{{.System}}<</SYS>>\n{{.Prompt}} [/INST]',
license='this is my license',
system='\nUse\nmultiline\nstrings.\n',
parameters={'stop': ['[INST]', '[/INST]', '<<SYS>>', '<</SYS>>'], 'pi': 3.14159},
messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
stream=False,
)
assert response['status'] == 'success'

Expand All @@ -1080,14 +995,14 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM llama2',
'from': 'llama2',
'stream': False,
},
).respond_with_json({'status': 'success'})

client = AsyncClient(httpserver.url_for('/'))

response = await client.create('dummy', modelfile='FROM llama2')
response = await client.create('dummy', from_='llama2')
assert response['status'] == 'success'


Expand Down

0 comments on commit f5c8ee0

Please sign in to comment.