Skip to content

Commit

Permalink
Expose CLI flag to disable using GPU for offline chat model
Browse files Browse the repository at this point in the history
- Offline chat models outputting gibberish when loaded onto some GPUs.
  GPU support with Vulkan in GPT4All seems a bit buggy

- This change mitigates the upstream issue by allowing user to
  manually disable using GPU for offline chat

Closes #516
  • Loading branch information
debanjum committed Oct 26, 2023
1 parent 5bb14a0 commit 9677eae
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/khoj/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def set_state(args):
state.port = args.port
state.demo = args.demo
state.khoj_version = version("khoj-assistant")
state.chat_on_gpu = args.chat_on_gpu


def start_server(app, host=None, port=None, socket=None):
Expand Down
7 changes: 5 additions & 2 deletions src/khoj/processor/conversation/gpt4all/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from khoj.utils import state


logger = logging.getLogger(__name__)

Expand All @@ -13,8 +15,9 @@ def download_model(model_name: str):

# Use GPU for Chat Model, if available
try:
model = GPT4All(model_name=model_name, device="gpu")
logger.debug(f"Loaded {model_name} chat model to GPU.")
device = "gpu" if state.chat_on_gpu else "cpu"
model = GPT4All(model_name=model_name, device=device)
logger.debug(f"Loaded {model_name} chat model to {device.upper()}")
except ValueError:
model = GPT4All(model_name=model_name)
logger.debug(f"Loaded {model_name} chat model to CPU.")
Expand Down
6 changes: 6 additions & 0 deletions src/khoj/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@ def cli(args=None):
help="Path to UNIX socket for server. Use to run server behind reverse proxy. Default: /tmp/uvicorn.sock",
)
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
parser.add_argument(
"--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
)
parser.add_argument("--demo", action="store_true", default=False, help="Run Khoj in demo mode")

args = parser.parse_args(args)

# Set default values for arguments
args.chat_on_gpu = not args.disable_chat_on_gpu

args.version_no = version("khoj-assistant")
if args.version:
# Show version of khoj installed and exit
Expand Down
2 changes: 2 additions & 0 deletions src/khoj/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
previous_query: str = None
demo: bool = False
khoj_version: str = None
chat_on_gpu: bool = True


if torch.cuda.is_available():
# Use CUDA GPU
Expand Down

0 comments on commit 9677eae

Please sign in to comment.