From 0fa6b399340bee3b4633fbca475dacee7d549372 Mon Sep 17 00:00:00 2001 From: "Lim, Choon Keat" Date: Wed, 11 Feb 2026 11:06:25 +0800 Subject: [PATCH] fix: Update the validator criteria to allow wifi config is empty when DHCP is false --- internal/controller/httpapi/v1/profiles.go | 1 + .../controller/httpapi/v1/profiles_test.go | 275 +++++++++++++++++- internal/entity/dto/v1/profile.go | 14 +- 3 files changed, 283 insertions(+), 7 deletions(-) diff --git a/internal/controller/httpapi/v1/profiles.go b/internal/controller/httpapi/v1/profiles.go index 958c99284..442815e49 100644 --- a/internal/controller/httpapi/v1/profiles.go +++ b/internal/controller/httpapi/v1/profiles.go @@ -27,6 +27,7 @@ func NewProfileRoutes(handler *gin.RouterGroup, t profiles.Feature, l logger.Int if v, ok := binding.Validator.Engine().(*validator.Validate); ok { _ = v.RegisterValidation("genpasswordwone", dto.ValidateAMTPassOrGenRan) _ = v.RegisterValidation("ciraortls", dto.ValidateCIRAOrTLS) + _ = v.RegisterValidation("wifidhcp", dto.ValidateWiFiDHCP) } } diff --git a/internal/controller/httpapi/v1/profiles_test.go b/internal/controller/httpapi/v1/profiles_test.go index 62be82e11..26dfd1040 100644 --- a/internal/controller/httpapi/v1/profiles_test.go +++ b/internal/controller/httpapi/v1/profiles_test.go @@ -6,9 +6,13 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "reflect" + "sync" "testing" "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" + "github.com/go-playground/validator/v10" "github.com/stretchr/testify/require" gomock "go.uber.org/mock/gomock" @@ -18,9 +22,84 @@ import ( "github.com/device-management-toolkit/console/pkg/logger" ) +// defaultValidator implements the gin binding.StructValidator interface +type defaultValidator struct { + once sync.Once + validate *validator.Validate +} + +func (v *defaultValidator) ValidateStruct(obj any) error { + if obj == nil { + return nil + } + + value := reflect.ValueOf(obj) + switch value.Kind() { + case reflect.Ptr: + if value.IsNil() { + return nil + } + + return v.ValidateStruct(value.Elem().Interface()) + case reflect.Struct: + return v.validateStruct(obj) + case reflect.Slice, reflect.Array: + count := value.Len() + validateRet := make(binding.SliceValidationError, 0) + + for i := 0; i < count; i++ { + if err := v.ValidateStruct(value.Index(i).Interface()); err != nil { + validateRet = append(validateRet, err) + } + } + + if len(validateRet) == 0 { + return nil + } + + return validateRet + case reflect.Invalid, reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.String, reflect.UnsafePointer: + return nil + default: + return nil + } +} + +func (v *defaultValidator) validateStruct(obj any) error { + v.lazyinit() + + return v.validate.Struct(obj) +} + +func (v *defaultValidator) Engine() any { + v.lazyinit() + + return v.validate +} + +func (v *defaultValidator) lazyinit() { + v.once.Do(func() { + v.validate = validator.New() + v.validate.SetTagName("binding") + + // Register custom validators + _ = v.validate.RegisterValidation("genpasswordwone", dto.ValidateAMTPassOrGenRan) + _ = v.validate.RegisterValidation("ciraortls", dto.ValidateCIRAOrTLS) + _ = v.validate.RegisterValidation("wifidhcp", dto.ValidateWiFiDHCP) + }) +} + func profilesTest(t *testing.T) (*mocks.MockProfilesFeature, *gin.Engine) { t.Helper() + // Enable validation for tests + if binding.Validator == nil { + binding.Validator = &defaultValidator{} + } + mockCtl := gomock.NewController(t) defer mockCtl.Finish() @@ -73,9 +152,7 @@ var profileTest = dto.Profile{ UEFIWiFiSyncEnabled: false, } -func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test function - t.Parallel() - +func TestProfileRoutes(t *testing.T) { //nolint:gocognit,paralleltest // this is a test function tests := []testProfiles{ { name: "get all profiles", @@ -260,12 +337,10 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct }, } - for _, tc := range tests { + for _, tc := range tests { //nolint:paralleltest // tests run sequentially for simplicity tc := tc t.Run(tc.name, func(t *testing.T) { - t.Parallel() - profileFeature, engine := profilesTest(t) tc.mock(profileFeature) @@ -277,6 +352,7 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct if tc.requestBody.ProfileName != "" { reqBody, _ := json.Marshal(tc.requestBody) req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, bytes.NewBuffer(reqBody)) + req.Header.Set("Content-Type", "application/json") } else { req, err = http.NewRequestWithContext(context.Background(), tc.method, tc.url, http.NoBody) } @@ -309,3 +385,190 @@ func TestProfileRoutes(t *testing.T) { //nolint:gocognit // this is a test funct }) } } + +func TestProfileValidation(t *testing.T) { //nolint:paralleltest // tests run sequentially for simplicity + tests := []struct { + name string + profile dto.Profile + expectedCode int + }{ + { + name: "valid profile - CCM with CIRA", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "ccmactivate", + GenerateRandomPassword: true, + GenerateRandomMEBxPassword: true, + CIRAConfigName: stringPtr("cira-config"), + DHCPEnabled: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusCreated, + }, + { + name: "valid profile - ACM with TLS", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "acmactivate", + GenerateRandomPassword: true, + MEBXPassword: "P@ssw0rd123", + GenerateRandomMEBxPassword: false, + TLSMode: 1, + TLSSigningAuthority: "SelfSigned", + DHCPEnabled: true, + UserConsent: "KVM", + TenantID: "tenant1", + }, + expectedCode: http.StatusCreated, + }, + { + name: "invalid - both CIRA and TLS", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "ccmactivate", + GenerateRandomPassword: true, + GenerateRandomMEBxPassword: true, + CIRAConfigName: stringPtr("cira-config"), + TLSMode: 1, + DHCPEnabled: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "invalid - wifi configs without DHCP", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "ccmactivate", + GenerateRandomPassword: true, + GenerateRandomMEBxPassword: true, + DHCPEnabled: false, + WiFiConfigs: []dto.ProfileWiFiConfigs{ + {ProfileName: "wifi1", Priority: 1}, + }, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "invalid - password set with genRandom true", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "ccmactivate", + AMTPassword: "P@ssw0rd123", + GenerateRandomPassword: true, + GenerateRandomMEBxPassword: true, + DHCPEnabled: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "invalid - invalid activation", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "invalidactivation", + GenerateRandomPassword: true, + GenerateRandomMEBxPassword: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "invalid - invalid TLS signing authority", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "acmactivate", + GenerateRandomPassword: true, + GenerateRandomMEBxPassword: true, + TLSMode: 1, + TLSSigningAuthority: "InvalidAuthority", + DHCPEnabled: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "invalid - TLS mode out of range", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "acmactivate", + GenerateRandomPassword: true, + GenerateRandomMEBxPassword: true, + TLSMode: 5, + TLSSigningAuthority: "SelfSigned", + DHCPEnabled: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "invalid - password too short", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "acmactivate", + AMTPassword: "short", + GenerateRandomPassword: false, + MEBXPassword: "P@ssw0rd123", + GenerateRandomMEBxPassword: false, + DHCPEnabled: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "invalid - password missing special character", + profile: dto.Profile{ + ProfileName: "test-profile", + Activation: "acmactivate", + AMTPassword: "Password123", + GenerateRandomPassword: false, + MEBXPassword: "P@ssw0rd123", + GenerateRandomMEBxPassword: false, + DHCPEnabled: true, + UserConsent: "All", + TenantID: "tenant1", + }, + expectedCode: http.StatusBadRequest, + }, + } + + for _, tc := range tests { //nolint:paralleltest // tests run sequentially for simplicity + tc := tc + + t.Run(tc.name, func(t *testing.T) { + profileFeature, engine := profilesTest(t) + + if tc.expectedCode == http.StatusCreated { + profileFeature.EXPECT().Insert(context.Background(), &tc.profile).Return(&tc.profile, nil) + } + + reqBody, _ := json.Marshal(tc.profile) + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/api/v1/admin/profiles", + bytes.NewBuffer(reqBody), + ) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + engine.ServeHTTP(w, req) + + require.Equal(t, tc.expectedCode, w.Code) + }) + } +} + +func stringPtr(s string) *string { + return &s +} diff --git a/internal/entity/dto/v1/profile.go b/internal/entity/dto/v1/profile.go index 2b8376414..8c371f3b0 100644 --- a/internal/entity/dto/v1/profile.go +++ b/internal/entity/dto/v1/profile.go @@ -21,7 +21,7 @@ type Profile struct { DHCPEnabled bool `json:"dhcpEnabled" example:"true"` IPSyncEnabled bool `json:"ipSyncEnabled" example:"true"` LocalWiFiSyncEnabled bool `json:"localWifiSyncEnabled" example:"true"` - WiFiConfigs []ProfileWiFiConfigs `json:"wifiConfigs,omitempty" binding:"excluded_if=DHCPEnabled false,dive"` + WiFiConfigs []ProfileWiFiConfigs `json:"wifiConfigs,omitempty" binding:"wifidhcp,dive"` TenantID string `json:"tenantId" example:"abc123"` TLSMode int `json:"tlsMode,omitempty" binding:"omitempty,min=1,max=4,ciraortls" example:"1"` TLSCerts *TLSCerts `json:"tlsCerts,omitempty"` @@ -66,6 +66,18 @@ var ValidateUserConsent validator.Func = func(fl validator.FieldLevel) bool { return userConsent == "none" || userConsent == "kvm" || userConsent == "all" } +var ValidateWiFiDHCP validator.Func = func(fl validator.FieldLevel) bool { + dhcpEnabled := fl.Parent().FieldByName("DHCPEnabled").Bool() + wifiConfigs := fl.Field() + + // If WiFiConfigs has items and DHCP is disabled, fail validation + if wifiConfigs.Len() > 0 && !dhcpEnabled { + return false + } + + return true +} + type ProfileCountResponse struct { Count int `json:"totalCount"` Data []Profile `json:"data"`