
Commit

ensure all ollama configuration values are converted to strings before assignment.
iamarunbrahma committed Jan 5, 2025
1 parent 30d9f55 commit c0f8059
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions src/vision_parse/llm.py
@@ -122,8 +122,8 @@ def _init_llm(self) -> None:
             )
 
         try:
-            os.environ["OLLAMA_KEEP_ALIVE"] = self.ollama_config.get(
-                "OLLAMA_KEEP_ALIVE", "-1"
+            os.environ["OLLAMA_KEEP_ALIVE"] = str(
+                self.ollama_config.get("OLLAMA_KEEP_ALIVE", -1)
             )
             if self.enable_concurrency:
                 self.aclient = ollama.AsyncClient(
@@ -133,31 +133,37 @@ def _init_llm(self) -> None:
                     timeout=self.ollama_config.get("OLLAMA_REQUEST_TIMEOUT", 240.0),
                 )
                 if self.device == "cuda":
-                    os.environ["OLLAMA_NUM_GPU"] = self.ollama_config.get(
-                        "OLLAMA_NUM_GPU", str(self.num_workers // 2)
+                    os.environ["OLLAMA_NUM_GPU"] = str(
+                        self.ollama_config.get(
+                            "OLLAMA_NUM_GPU", self.num_workers // 2
+                        )
                     )
-                    os.environ["OLLAMA_NUM_PARALLEL"] = self.ollama_config.get(
-                        "OLLAMA_NUM_PARALLEL", str(self.num_workers * 8)
+                    os.environ["OLLAMA_NUM_PARALLEL"] = str(
+                        self.ollama_config.get(
+                            "OLLAMA_NUM_PARALLEL", self.num_workers * 8
+                        )
                     )
-                    os.environ["OLLAMA_GPU_LAYERS"] = self.ollama_config.get(
-                        "OLLAMA_GPU_LAYERS", "all"
+                    os.environ["OLLAMA_GPU_LAYERS"] = str(
+                        self.ollama_config.get("OLLAMA_GPU_LAYERS", "all")
                    )
                 elif self.device == "mps":
-                    os.environ["OLLAMA_NUM_GPU"] = self.ollama_config.get(
-                        "OLLAMA_NUM_GPU", "1"
+                    os.environ["OLLAMA_NUM_GPU"] = str(
+                        self.ollama_config.get("OLLAMA_NUM_GPU", 1)
                     )
-                    os.environ["OLLAMA_NUM_THREAD"] = self.ollama_config.get(
-                        "OLLAMA_NUM_THREAD", str(self.num_workers)
+                    os.environ["OLLAMA_NUM_THREAD"] = str(
+                        self.ollama_config.get("OLLAMA_NUM_THREAD", self.num_workers)
                     )
-                    os.environ["OLLAMA_NUM_PARALLEL"] = self.ollama_config.get(
-                        "OLLAMA_NUM_PARALLEL", str(self.num_workers * 8)
+                    os.environ["OLLAMA_NUM_PARALLEL"] = str(
+                        self.ollama_config.get("OLLAMA_NUM_PARALLEL", self.num_workers * 8)
                     )
                 else:
-                    os.environ["OLLAMA_NUM_THREAD"] = self.ollama_config.get(
-                        "OLLAMA_NUM_THREAD", str(self.num_workers)
+                    os.environ["OLLAMA_NUM_THREAD"] = str(
+                        self.ollama_config.get("OLLAMA_NUM_THREAD", self.num_workers)
                     )
-                    os.environ["OLLAMA_NUM_PARALLEL"] = self.ollama_config.get(
-                        "OLLAMA_NUM_PARALLEL", str(self.num_workers * 10)
+                    os.environ["OLLAMA_NUM_PARALLEL"] = str(
+                        self.ollama_config.get(
+                            "OLLAMA_NUM_PARALLEL", self.num_workers * 10
+                        )
                     )
             else:
                 self.client = ollama.Client(
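The reason every lookup now goes through str(): os.environ only accepts string values, so a caller who puts a number into ollama_config (or a numeric fallback default) would previously hit TypeError: str expected, not int at assignment time. Below is a minimal standalone sketch of the before/after behaviour; the ollama_config dict literal and its value are made up for illustration and are not part of the repository.

import os

# Hypothetical config for illustration; in llm.py this is self.ollama_config,
# supplied by the caller and not guaranteed to contain strings.
ollama_config = {"OLLAMA_KEEP_ALIVE": 300}

# Old pattern: .get() returns the caller's int unchanged and os.environ rejects it.
try:
    os.environ["OLLAMA_KEEP_ALIVE"] = ollama_config.get("OLLAMA_KEEP_ALIVE", "-1")
except TypeError as exc:
    print(f"old pattern fails: {exc}")  # str expected, not int

# New pattern from this commit: str() wraps the whole lookup, so both the
# configured value and the fallback default reach os.environ as strings.
os.environ["OLLAMA_KEEP_ALIVE"] = str(ollama_config.get("OLLAMA_KEEP_ALIVE", -1))
print(os.environ["OLLAMA_KEEP_ALIVE"])  # prints: 300

The same failure mode applies to the OLLAMA_NUM_GPU, OLLAMA_NUM_PARALLEL, OLLAMA_NUM_THREAD, and OLLAMA_GPU_LAYERS assignments above, which is why each of them gained the same str() wrapper in this commit.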
