Skip to content

Commit

Permalink
server: more support for mixed-case model names (ollama#8017)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmizerany authored Dec 11, 2024
1 parent 36d111e commit b1fd7fe
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 38 deletions.
2 changes: 1 addition & 1 deletion cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
var data [][]string

for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
}
}
Expand Down
4 changes: 4 additions & 0 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
switch command {
case "model", "adapter":
if name := model.ParseName(c.Args); name.IsValid() && command == "model" {
name, err := getExistingName(name)
if err != nil {
return err
}
baseLayers, err = parseFromModel(ctx, name, fn)
if err != nil {
return err
Expand Down
15 changes: 11 additions & 4 deletions server/modelpath.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package server
import (
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"

"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)

type ModelPath struct {
Expand Down Expand Up @@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string {

// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
return filepath.Join(envconfig.Models(), "manifests", p), nil
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
}

return "", errModelPathInvalid
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}

func (mp ModelPath) BaseURL() *url.URL {
Expand Down
8 changes: 0 additions & 8 deletions server/modelpath_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package server

import (
"errors"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) {
})
}
}

func TestInsecureModelpath(t *testing.T) {
mp := ParseModelPath("../../..:something")
if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
t.Errorf("expected error: %v", err)
}
}
112 changes: 90 additions & 22 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
Expand Down Expand Up @@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}

model, err := GetModel(req.Model)
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
// what the API currently returns until we can change it.
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}

// We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure.
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}

model, err := GetModel(name.String())
if err != nil {
switch {
case os.IsNotExist(err):
case errors.Is(err, fs.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
Expand Down Expand Up @@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert)
}

r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
Expand Down Expand Up @@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
}

r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
name, err := getExistingName(model.ParseName(req.Model))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}

r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
Expand Down Expand Up @@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}

r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}

r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
Expand Down Expand Up @@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
return
}

var model string
var mname string
if req.Model != "" {
model = req.Model
mname = req.Model
} else if req.Name != "" {
model = req.Name
mname = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
Expand All @@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()

if err := PushModel(ctx, model, regOpts, fn); err != nil {
name, err := getExistingName(model.ParseName(mname))
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}

if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
Expand All @@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
streamResponse(c, ch)
}

// getExistingName returns the original, on disk name if the input name is a
// case-insensitive match, otherwise it returns the input name.
// getExistingName searches the models directory for the longest prefix match of
// the input name and returns the input name with all existing parts replaced
// with each part found. If no parts are found, the input name is returned as
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := Manifests(true)
if err != nil {
return zero, err
}
var set model.Name // tracks parts already canonicalized
for e := range existing {
if n.EqualFold(e) {
return e, nil
if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
n.Host = e.Host
}
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
n.Namespace = e.Namespace
}
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
n.Model = e.Model
}
if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
n.Tag = e.Tag
}
}
return n, nil
Expand Down Expand Up @@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
}

if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"})
return
}

Expand Down Expand Up @@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}

n, err := getExistingName(n)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
}

m, err := ParseNamedManifest(n)
if err != nil {
switch {
Expand Down Expand Up @@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
}

func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m, err := GetModel(req.Model)
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, errModelPathInvalid
}
name, err := getExistingName(name)
if err != nil {
return nil, err
}

m, err := GetModel(name.String())
if err != nil {
return nil, err
}
Expand All @@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}

n := model.ParseName(req.Model)
if !n.IsValid() {
return nil, errors.New("invalid model name")
}

manifest, err := ParseNamedManifest(n)
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps = append(caps, CapabilityTools)
}

r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}

r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
Expand Down
2 changes: 1 addition & 1 deletion server/routes_generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 400, got %d", w.Code)
}

if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
Expand Down
14 changes: 14 additions & 0 deletions server/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) {

wantStableName := name()

t.Logf("stable name: %s", wantStableName)

// checkManifestList tests that there is strictly one manifest in the
// models directory, and that the manifest is for the model under test.
checkManifestList := func() {
Expand Down Expand Up @@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) {
Destination: name(),
}))
checkManifestList()

t.Logf("pushing")
rr := createRequest(t, s.PushHandler, api.PushRequest{
Model: name(),
Insecure: true,
Username: "alice",
Password: "x",
})
checkOK(rr)
if !strings.Contains(rr.Body.String(), `"status":"success"`) {
t.Errorf("got = %q, want success", rr.Body.String())
}
}

func TestShow(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions types/model/name.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,12 @@ func (n Name) String() string {
func (n Name) DisplayShortest() string {
var sb strings.Builder

if n.Host != defaultHost {
if !strings.EqualFold(n.Host, defaultHost) {
sb.WriteString(n.Host)
sb.WriteByte('/')
sb.WriteString(n.Namespace)
sb.WriteByte('/')
} else if n.Namespace != defaultNamespace {
} else if !strings.EqualFold(n.Namespace, defaultNamespace) {
sb.WriteString(n.Namespace)
sb.WriteByte('/')
}
Expand Down

0 comments on commit b1fd7fe

Please sign in to comment.