Skip to content

Commit

Permalink
fix: ensure that KV values of true are rendered to output, fixes #195
Browse files Browse the repository at this point in the history
  • Loading branch information
a-h committed Oct 3, 2023
1 parent 3fa0ebd commit 8543fac
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 21 deletions.
3 changes: 2 additions & 1 deletion generator/test-css-usage/expected.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
<button class="green_58d2" type="button">Green</button>
<div class="a c"></div>
<div class="a"></div>
<input type="email" id="email" name="email" class="a b" placeholder="your@email.com" autocomplete="off"/>
<style type="text/css">.e_739d{font-size:14pt;}</style>
<input type="email" id="email" name="email" class="a b e_739d" placeholder="your@email.com" autocomplete="off"/>
<button class="bg-violet-500 hover:bg-violet-600">Save changes</button>

10 changes: 9 additions & 1 deletion generator/test-css-usage/template.templ
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ css className() {
color: { red };
}

css d() {
font-size: 12pt;
}

css e() {
font-size: 14pt;
}

templ Button(text string) {
<button class={ className(), templ.Class("&&&unsafe"), "safe", templ.SafeClass("safe2") } type="button">{ text }</button>
}
Expand All @@ -23,7 +31,7 @@ templ MapCSSExample() {

templ KVExample() {
<div class={ "a", templ.KV("b", false) }></div>
<input type="email" id="email" name="email" class={ "a", "b", "c", templ.KV("c", false) } placeholder="your@email.com" autocomplete="off"/>
<input type="email" id="email" name="email" class={ "a", "b", "c", templ.KV("c", false), templ.KV(d(), false), templ.KV(e(), true) } placeholder="your@email.com" autocomplete="off"/>
}

templ PsuedoAttributes() {
Expand Down
22 changes: 21 additions & 1 deletion generator/test-css-usage/template_templ.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

62 changes: 45 additions & 17 deletions runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,23 +382,7 @@ func RenderCSSItems(ctx context.Context, w io.Writer, classes ...any) (err error
}
_, v := getContext(ctx)
sb := new(strings.Builder)
for _, c := range classes {
switch ccc := c.(type) {
case ComponentCSSClass:
if !v.hasClassBeenRendered(ccc.ID) {
sb.WriteString(string(ccc.Class))
v.addClass(ccc.ID)
}
case CSSClasses:
if err = RenderCSSItems(ctx, w, ccc...); err != nil {
return
}
case func() CSSClass:
if err = RenderCSSItems(ctx, w, ccc()); err != nil {
return
}
}
}
renderCSSItemsToBuilder(sb, v, classes...)
if sb.Len() > 0 {
if _, err = io.WriteString(w, `<style type="text/css">`); err != nil {
return err
Expand All @@ -413,6 +397,50 @@ func RenderCSSItems(ctx context.Context, w io.Writer, classes ...any) (err error
return nil
}

func renderCSSItemsToBuilder(sb *strings.Builder, v *contextValue, classes ...any) {
for _, c := range classes {
switch ccc := c.(type) {
case ComponentCSSClass:
if !v.hasClassBeenRendered(ccc.ID) {
sb.WriteString(string(ccc.Class))
v.addClass(ccc.ID)
}
case KeyValue[ComponentCSSClass, bool]:
if !ccc.Value {
continue
}
renderCSSItemsToBuilder(sb, v, ccc.Key)
case KeyValue[CSSClass, bool]:
if !ccc.Value {
continue
}
renderCSSItemsToBuilder(sb, v, ccc.Key)
case CSSClasses:
renderCSSItemsToBuilder(sb, v, ccc...)
case func() CSSClass:
renderCSSItemsToBuilder(sb, v, ccc())
case []string:
// Skip. These are class names, not CSS classes.
case string:
// Skip. This is a class name, not a CSS class.
case ConstantCSSClass:
// Skip. This is a class name, not a CSS class.
case CSSClass:
// Skip. This is a class name, not a CSS class.
case map[string]bool:
// Skip. These are class names, not CSS classes.
case KeyValue[string, bool]:
// Skip. These are class names, not CSS classes.
case []KeyValue[string, bool]:
// Skip. These are class names, not CSS classes.
case KeyValue[ConstantCSSClass, bool]:
// Skip. These are class names, not CSS classes.
case []KeyValue[ConstantCSSClass, bool]:
// Skip. These are class names, not CSS classes.
}
}
}

// SafeCSS is CSS that has been sanitized.
type SafeCSS string

Expand Down
54 changes: 53 additions & 1 deletion runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,47 @@ func TestCSSMiddleware(t *testing.T) {
}
}

var cssInputs = []any{
[]string{"a", "b"}, // []string
"c", // string
templ.ConstantCSSClass("d"), // ConstantCSSClass
templ.ComponentCSSClass{ID: "e", Class: ".e{color:red}"}, // ComponentCSSClass
map[string]bool{"f": true, "ff": false}, // map[string]bool
templ.KV[string, bool]("g", true), // KeyValue[string, bool]
templ.KV[string, bool]("gg", false), // KeyValue[string, bool]
[]templ.KeyValue[string, bool]{
templ.KV("h", true),
templ.KV("hh", false),
}, // []KeyValue[string, bool]
templ.KV[templ.CSSClass, bool](templ.ConstantCSSClass("i"), true), // KeyValue[CSSClass, bool]
templ.KV[templ.CSSClass, bool](templ.ConstantCSSClass("ii"), false), // KeyValue[CSSClass, bool]
templ.KV[templ.ComponentCSSClass, bool](templ.ComponentCSSClass{
ID: "j",
Class: ".j{color:red}",
}, true), // KeyValue[ComponentCSSClass, bool]
templ.KV[templ.ComponentCSSClass, bool](templ.ComponentCSSClass{
ID: "jj",
Class: ".jj{color:red}",
}, false), // KeyValue[ComponentCSSClass, bool]
templ.CSSClasses{templ.ConstantCSSClass("k")}, // CSSClasses
func() templ.CSSClass { return templ.ConstantCSSClass("l") }, // func() CSSClass
templ.CSSClass(templ.ConstantCSSClass("m")), // CSSClass
customClass{name: "n"}, // CSSClass
templ.KV[templ.ConstantCSSClass, bool](templ.ConstantCSSClass("o"), true), // KeyValue[ConstantCSSClass, bool]
[]templ.KeyValue[templ.ConstantCSSClass, bool]{
templ.KV(templ.ConstantCSSClass("p"), true),
templ.KV(templ.ConstantCSSClass("pp"), false),
}, // []KeyValue[ConstantCSSClass, bool]
}

type customClass struct {
name string
}

func (cc customClass) ClassName() string {
return cc.name
}

func TestRenderCSS(t *testing.T) {
c1 := templ.ComponentCSSClass{
ID: "c1",
Expand All @@ -130,11 +171,13 @@ func TestRenderCSS(t *testing.T) {
tests := []struct {
name string
toIgnore []any
toRender []any
expected string
}{
{
name: "if none are ignored, everything is rendered",
toIgnore: nil,
toRender: []any{c1, c2},
expected: `<style type="text/css">.c1{color:red}.c2{color:blue}</style>`,
},
{
Expand All @@ -145,11 +188,13 @@ func TestRenderCSS(t *testing.T) {
Class: templ.SafeCSS(".c3{color:yellow}"),
},
},
toRender: []any{c1, c2},
expected: `<style type="text/css">.c1{color:red}.c2{color:blue}</style>`,
},
{
name: "if one is ignored, it's not rendered",
toIgnore: []any{c1},
toRender: []any{c1, c2},
expected: `<style type="text/css">.c2{color:blue}</style>`,
},
{
Expand All @@ -162,8 +207,15 @@ func TestRenderCSS(t *testing.T) {
Class: templ.SafeCSS(".c3{color:yellow}"),
},
},
toRender: []any{c1, c2},
expected: ``,
},
{
name: "CSS classes are rendered",
toIgnore: nil,
toRender: cssInputs,
expected: `<style type="text/css">.e{color:red}.j{color:red}</style>`,
},
}
for _, tt := range tests {
tt := tt
Expand All @@ -180,7 +232,7 @@ func TestRenderCSS(t *testing.T) {

// Now render again to check that only the expected classes were rendered.
b.Reset()
err = templ.RenderCSSItems(ctx, b, []any{c1, c2}...)
err = templ.RenderCSSItems(ctx, b, tt.toRender...)
if err != nil {
t.Fatalf("failed to render CSS: %v", err)
}
Expand Down

0 comments on commit 8543fac

Please sign in to comment.