Skip to content

Commit de76d14

Browse files
committed
fix: add token file name validation check
refine useWorkloadIdentity logic
1 parent d7231a6 commit de76d14

File tree

4 files changed

+137
-3
lines changed

4 files changed

+137
-3
lines changed

pkg/blob/blob.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,11 @@ func (d *Driver) GetAuthEnv(ctx context.Context, volumeID, protocol string, attr
603603
if err != nil {
604604
return rgName, accountName, accountKey, containerName, authEnv, err
605605
}
606-
azureOAuthTokenFile := filepath.Join(defaultAzureOAuthTokenDir, clientID+accountName)
606+
tokenFileName := clientID + accountName
607+
if !isValidTokenFileName(tokenFileName) {
608+
return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("the generated token file name %s is invalid", tokenFileName)
609+
}
610+
azureOAuthTokenFile := filepath.Join(defaultAzureOAuthTokenDir, tokenFileName)
607611
if err := os.WriteFile(azureOAuthTokenFile, []byte(workloadIdentityToken), 0600); err != nil {
608612
return rgName, accountName, accountKey, containerName, authEnv, fmt.Errorf("failed to write workload identity token file %s: %v", azureOAuthTokenFile, err)
609613
}
@@ -1251,3 +1255,20 @@ func parseServiceAccountToken(tokenStr string) (string, error) {
12511255
}
12521256
return token.APIAzureADTokenExchange.Token, nil
12531257
}
1258+
1259+
// isValidTokenFileName checks if the token file name is valid
1260+
// fileName should only contain alphanumeric characters, hyphens
1261+
func isValidTokenFileName(fileName string) bool {
1262+
if fileName == "" {
1263+
return false
1264+
}
1265+
for _, c := range fileName {
1266+
if !(('a' <= c && c <= 'z') ||
1267+
('A' <= c && c <= 'Z') ||
1268+
('0' <= c && c <= '9') ||
1269+
(c == '-')) {
1270+
return false
1271+
}
1272+
}
1273+
return true
1274+
}

pkg/blob/blob_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,3 +2132,65 @@ func TestIsSupportedPublicNetworkAccess(t *testing.T) {
21322132
}
21332133
}
21342134
}
2135+
2136+
func TestIsValidTokenFileName(t *testing.T) {
2137+
testCases := []struct {
2138+
name string
2139+
fileName string
2140+
expected bool
2141+
}{
2142+
{
2143+
name: "valid lowercase",
2144+
fileName: "token",
2145+
expected: true,
2146+
},
2147+
{
2148+
name: "valid uppercase",
2149+
fileName: "TOKEN",
2150+
expected: true,
2151+
},
2152+
{
2153+
name: "valid mixed alphanumeric with hyphen",
2154+
fileName: "Token-123",
2155+
expected: true,
2156+
},
2157+
{
2158+
name: "valid mixed alphanumeric with hyphen#2",
2159+
fileName: "0ab48765-efce-4799-8a9c-c3e1de2ee42eg",
2160+
expected: true,
2161+
},
2162+
{
2163+
name: "empty string",
2164+
fileName: "",
2165+
expected: false,
2166+
},
2167+
{
2168+
name: "contains underscore",
2169+
fileName: "token_file",
2170+
expected: false,
2171+
},
2172+
{
2173+
name: "contains dot",
2174+
fileName: "token.file",
2175+
expected: false,
2176+
},
2177+
{
2178+
name: "contains space",
2179+
fileName: "token file",
2180+
expected: false,
2181+
},
2182+
{
2183+
name: "contains slash",
2184+
fileName: "token/file",
2185+
expected: false,
2186+
},
2187+
}
2188+
2189+
for _, tc := range testCases {
2190+
t.Run(tc.name, func(t *testing.T) {
2191+
if got := isValidTokenFileName(tc.fileName); got != tc.expected {
2192+
t.Fatalf("isValidTokenFileName(%q) = %t, want %t", tc.fileName, got, tc.expected)
2193+
}
2194+
})
2195+
}
2196+
}

pkg/blob/nodeserver.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu
8080
context := req.GetVolumeContext()
8181
if context != nil {
8282
// token request
83-
if context[serviceAccountTokenField] != "" && getValueInMap(context, clientIDField) != "" {
83+
if context[serviceAccountTokenField] != "" && useWorkloadIdentity(context) {
8484
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, clientID: %s", volumeID, target, getValueInMap(context, clientIDField))
8585
_, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{
8686
StagingTargetPath: target,
@@ -261,7 +261,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
261261
attrib := req.GetVolumeContext()
262262
secrets := req.GetSecrets()
263263

264-
if getValueInMap(attrib, clientIDField) != "" && attrib[serviceAccountTokenField] == "" {
264+
if useWorkloadIdentity(attrib) && attrib[serviceAccountTokenField] == "" {
265265
klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID %s is provided but service account token is empty", volumeID, getValueInMap(attrib, clientIDField))
266266
return &csi.NodeStageVolumeResponse{}, nil
267267
}
@@ -733,3 +733,11 @@ func checkGidPresentInMountFlags(mountFlags []string) bool {
733733
}
734734
return false
735735
}
736+
737+
// useWorkloadIdentity checks whether workload identity is used based on the presence of clientID or mountWithWIToken in volume attributes
738+
func useWorkloadIdentity(attrib map[string]string) bool {
739+
if getValueInMap(attrib, clientIDField) != "" || getValueInMap(attrib, mountWithWITokenField) == trueValue {
740+
return true
741+
}
742+
return false
743+
}

pkg/blob/nodeserver_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,3 +1192,46 @@ func TestCheckGidPresentInMountFlags(t *testing.T) {
11921192
}
11931193
}
11941194
}
1195+
1196+
func TestUseWorkloadIdentity(t *testing.T) {
1197+
tests := []struct {
1198+
name string
1199+
attrib map[string]string
1200+
want bool
1201+
}{
1202+
{
1203+
name: "clientID present",
1204+
attrib: map[string]string{
1205+
clientIDField: "client-id",
1206+
},
1207+
want: true,
1208+
},
1209+
{
1210+
name: "mountWithWIToken true",
1211+
attrib: map[string]string{
1212+
mountWithWITokenField: trueValue,
1213+
},
1214+
want: true,
1215+
},
1216+
{
1217+
name: "mountWithWIToken false",
1218+
attrib: map[string]string{
1219+
mountWithWITokenField: "false",
1220+
},
1221+
want: false,
1222+
},
1223+
{
1224+
name: "no workload identity fields",
1225+
attrib: map[string]string{},
1226+
want: false,
1227+
},
1228+
}
1229+
1230+
for _, tt := range tests {
1231+
t.Run(tt.name, func(t *testing.T) {
1232+
if got := useWorkloadIdentity(tt.attrib); got != tt.want {
1233+
t.Errorf("useWorkloadIdentity() = %v, want %v", got, tt.want)
1234+
}
1235+
})
1236+
}
1237+
}

0 commit comments

Comments
 (0)