diff --git a/apidef/oas/oas_test.go b/apidef/oas/oas_test.go index 78e6e29cf54..f0ca0ee4f4b 100644 --- a/apidef/oas/oas_test.go +++ b/apidef/oas/oas_test.go @@ -240,7 +240,6 @@ func TestOAS_ExtractTo_ResetAPIDefinition(t *testing.T) { "APIDefinition.UptimeTests.Config.ExpireUptimeAnalyticsAfter", "APIDefinition.UptimeTests.Config.ServiceDiscovery.CacheDisabled", "APIDefinition.UptimeTests.Config.RecheckWait", - "APIDefinition.Proxy.PreserveHostHeader", "APIDefinition.Proxy.DisableStripSlash", "APIDefinition.Proxy.CheckHostAgainstUptimeTests", "APIDefinition.Proxy.Transport.SSLInsecureSkipVerify", diff --git a/apidef/oas/schema/x-tyk-api-gateway.json b/apidef/oas/schema/x-tyk-api-gateway.json index e85a8de8262..0450a310da3 100644 --- a/apidef/oas/schema/x-tyk-api-gateway.json +++ b/apidef/oas/schema/x-tyk-api-gateway.json @@ -1346,6 +1346,9 @@ }, "loadBalancing": { "$ref": "#/definitions/X-Tyk-LoadBalancing" + }, + "preserveHostHeader": { + "$ref": "#/definitions/X-Tyk-PreserveHostHeader" } }, "required": [ @@ -2283,6 +2286,17 @@ "enabled" ] }, + "X-Tyk-PreserveHostHeader": { + "type": "object", + "properties": { + "enabled": { + "type": "boolean" + } + }, + "required": [ + "enabled" + ] + }, "X-Tyk-LoadBalancingTarget": { "type": "object", "properties": { diff --git a/apidef/oas/upstream.go b/apidef/oas/upstream.go index fa9758c5633..2b7ecfb0c51 100644 --- a/apidef/oas/upstream.go +++ b/apidef/oas/upstream.go @@ -35,6 +35,9 @@ type Upstream struct { // LoadBalancing contains configuration for load balancing between multiple upstream targets. LoadBalancing *LoadBalancing `bson:"loadBalancing,omitempty" json:"loadBalancing,omitempty"` + + // PreserveHostHeader contains the configuration for preserving the host header. + PreserveHostHeader *PreserveHostHeader `bson:"preserveHostHeader,omitempty" json:"preserveHostHeader,omitempty"` } // Fill fills *Upstream from apidef.APIDefinition. @@ -96,6 +99,20 @@ func (u *Upstream) Fill(api apidef.APIDefinition) { } u.fillLoadBalancing(api) + + u.fillPreserveHostHeader(api) +} + +func (u *Upstream) fillPreserveHostHeader(api apidef.APIDefinition) { + if u.PreserveHostHeader == nil { + u.PreserveHostHeader = &PreserveHostHeader{} + } + + u.PreserveHostHeader.Fill(api) + + if ShouldOmit(u.PreserveHostHeader) { + u.PreserveHostHeader = nil + } } // ExtractTo extracts *Upstream into *apidef.APIDefinition. @@ -157,6 +174,19 @@ func (u *Upstream) ExtractTo(api *apidef.APIDefinition) { u.Authentication.ExtractTo(&api.UpstreamAuth) u.loadBalancingExtractTo(api) + + u.preserveHostHeaderExtractTo(api) +} + +func (u *Upstream) preserveHostHeaderExtractTo(api *apidef.APIDefinition) { + if u.PreserveHostHeader == nil { + u.PreserveHostHeader = &PreserveHostHeader{} + defer func() { + u.PreserveHostHeader = nil + }() + } + + u.PreserveHostHeader.ExtractTo(api) } func (u *Upstream) fillLoadBalancing(api apidef.APIDefinition) { @@ -905,3 +935,19 @@ func (l *LoadBalancing) ExtractTo(api *apidef.APIDefinition) { api.Proxy.Targets = proxyConfTargets } + +// PreserveHostHeader holds the configuration for preserving the host header. +type PreserveHostHeader struct { + // Enabled activates preserving the host header. + Enabled bool `json:"enabled" bson:"enabled"` +} + +// Fill fills *PreserveHostHeader from apidef.APIDefinition. +func (p *PreserveHostHeader) Fill(api apidef.APIDefinition) { + p.Enabled = api.Proxy.PreserveHostHeader +} + +// ExtractTo extracts *PreserveHostHeader into *apidef.APIDefinition. +func (p *PreserveHostHeader) ExtractTo(api *apidef.APIDefinition) { + api.Proxy.PreserveHostHeader = p.Enabled +} diff --git a/apidef/oas/upstream_test.go b/apidef/oas/upstream_test.go index 8150c39b085..46fc59be7c0 100644 --- a/apidef/oas/upstream_test.go +++ b/apidef/oas/upstream_test.go @@ -668,3 +668,84 @@ func TestLoadBalancing(t *testing.T) { } }) } + +func TestPreserveHostHeader(t *testing.T) { + t.Run("fill", func(t *testing.T) { + type testCase struct { + title string + input apidef.APIDefinition + expected *PreserveHostHeader + } + testCases := []testCase{ + { + title: "preserve host header disabled", + input: apidef.APIDefinition{ + Proxy: apidef.ProxyConfig{ + PreserveHostHeader: false, + }, + }, + expected: nil, + }, + { + title: "preserve host header enabled", + input: apidef.APIDefinition{ + Proxy: apidef.ProxyConfig{ + PreserveHostHeader: true, + }, + }, + expected: &PreserveHostHeader{ + Enabled: true, + }, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.title, func(t *testing.T) { + t.Parallel() + + g := new(Upstream) + g.Fill(tc.input) + + assert.Equal(t, tc.expected, g.PreserveHostHeader) + }) + } + }) + + t.Run("extractTo", func(t *testing.T) { + type testCase struct { + title string + input *PreserveHostHeader + expectedEnabled bool + } + testcases := []testCase{ + { + title: "preserve host header disabled", + input: &PreserveHostHeader{ + Enabled: false, + }, + expectedEnabled: false, + }, + { + title: "preserve host header enabled", + input: &PreserveHostHeader{ + Enabled: true, + }, + expectedEnabled: true, + }, + } + + for _, tc := range testcases { + tc := tc // Creating a new 'tc' scoped to the loop + t.Run(tc.title, func(t *testing.T) { + g := new(Upstream) + g.PreserveHostHeader = tc.input + + var apiDef apidef.APIDefinition + apiDef.Proxy.PreserveHostHeader = true + g.ExtractTo(&apiDef) + + assert.Equal(t, tc.expectedEnabled, apiDef.Proxy.PreserveHostHeader) + }) + } + }) +}