Skip to content

Commit 0d05475

Browse files
authored
Fix tolowerequalfold package resolution for alias imports and shadowed identifiers (#40622)
1 parent 28adc72 commit 0d05475

3 files changed

Lines changed: 109 additions & 17 deletions

File tree

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package tolowerequalfold
2+
3+
import str "strings"
4+
5+
func aliasImportExamples() {
6+
a := "Alice"
7+
b := "alice"
8+
9+
_ = str.ToLower(a) == str.ToLower(b) // want `use strings\.EqualFold`
10+
_ = str.ToUpper(a) == str.ToUpper(b) // want `use strings\.EqualFold`
11+
}
12+
13+
func aliasImportTrackedExamples() {
14+
a := "Alice"
15+
b := "alice"
16+
17+
x := str.ToLower(a)
18+
_ = x == "alice" // want `use strings\.EqualFold`
19+
20+
y := str.ToUpper(b)
21+
_ = "ALICE" == y // want `use strings\.EqualFold`
22+
}
23+
24+
type shadowStrings struct{}
25+
26+
func (shadowStrings) ToLower(s string) string {
27+
return s
28+
}
29+
30+
func shadowedIdentifierExample() {
31+
strings := shadowStrings{}
32+
a := "Alice"
33+
b := "alice"
34+
35+
_ = strings.ToLower(a) == strings.ToLower(b)
36+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package tolowerequalfold
2+
3+
import str "strings"
4+
5+
func aliasImportExamples() {
6+
a := "Alice"
7+
b := "alice"
8+
9+
_ = str.EqualFold(a, b) // want `use strings\.EqualFold`
10+
_ = str.EqualFold(a, b) // want `use strings\.EqualFold`
11+
}
12+
13+
func aliasImportTrackedExamples() {
14+
a := "Alice"
15+
b := "alice"
16+
17+
x := str.ToLower(a)
18+
_ = x == "alice" // want `use strings\.EqualFold`
19+
20+
y := str.ToUpper(b)
21+
_ = "ALICE" == y // want `use strings\.EqualFold`
22+
}
23+
24+
type shadowStrings struct{}
25+
26+
func (shadowStrings) ToLower(s string) string {
27+
return s
28+
}
29+
30+
func shadowedIdentifierExample() {
31+
strings := shadowStrings{}
32+
a := "Alice"
33+
b := "alice"
34+
35+
_ = strings.ToLower(a) == strings.ToLower(b)
36+
}

pkg/linters/tolowerequalfold/tolowerequalfold.go

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ func run(pass *analysis.Pass) (any, error) {
5151
return
5252
}
5353

54-
if arg, ok := caseConvArg(expr.X); ok && sameOperand(pass, arg, expr.Y) {
54+
if arg, ok := caseConvArg(pass, expr.X); ok && sameOperand(pass, arg, expr.Y) {
5555
return
5656
}
57-
if arg, ok := caseConvArg(expr.Y); ok && sameOperand(pass, expr.X, arg) {
57+
if arg, ok := caseConvArg(pass, expr.Y); ok && sameOperand(pass, expr.X, arg) {
5858
return
5959
}
6060
if arg, ok := caseConvAliasArg(pass, expr.X, caseConvAliases); ok && sameOperand(pass, arg, expr.Y) {
@@ -64,7 +64,7 @@ func run(pass *analysis.Pass) (any, error) {
6464
return
6565
}
6666

67-
if isCaseConvCall(expr.X) || isCaseConvCall(expr.Y) ||
67+
if isCaseConvCall(pass, expr.X) || isCaseConvCall(pass, expr.Y) ||
6868
(isCaseConvAlias(pass, expr.X, caseConvAliases) && astutil.IsStringLiteral(expr.Y)) ||
6969
(isCaseConvAlias(pass, expr.Y, caseConvAliases) && astutil.IsStringLiteral(expr.X)) {
7070
if nolint.HasDirective(pass.Fset.PositionFor(expr.Pos(), false), noLintLinesByFile) {
@@ -88,8 +88,8 @@ func run(pass *analysis.Pass) (any, error) {
8888
// an alias variable), since alias variables may be defined at a different
8989
// source location.
9090
func buildEqualFoldFix(pass *analysis.Pass, expr *ast.BinaryExpr) []analysis.SuggestedFix {
91-
leftArg, leftOK := caseConvArg(expr.X)
92-
rightArg, rightOK := caseConvArg(expr.Y)
91+
leftArg, leftOK := caseConvArg(pass, expr.X)
92+
rightArg, rightOK := caseConvArg(pass, expr.Y)
9393
if !leftOK && !rightOK {
9494
return nil
9595
}
@@ -106,12 +106,20 @@ func buildEqualFoldFix(pass *analysis.Pass, expr *ast.BinaryExpr) []analysis.Sug
106106
if text1 == "" || text2 == "" {
107107
return nil
108108
}
109-
call := fmt.Sprintf("strings.EqualFold(%s, %s)", text1, text2)
109+
equalFoldPkg := "strings"
110+
if leftOK {
111+
if pkgName, ok := caseConvPkgName(pass, expr.X); ok {
112+
equalFoldPkg = pkgName
113+
}
114+
} else if pkgName, ok := caseConvPkgName(pass, expr.Y); ok {
115+
equalFoldPkg = pkgName
116+
}
117+
call := fmt.Sprintf("%s.EqualFold(%s, %s)", equalFoldPkg, text1, text2)
110118
if expr.Op == token.NEQ {
111119
call = "!" + call
112120
}
113121
return []analysis.SuggestedFix{{
114-
Message: "Replace with strings.EqualFold",
122+
Message: fmt.Sprintf("Replace with %s.EqualFold", equalFoldPkg),
115123
TextEdits: []analysis.TextEdit{{
116124
Pos: expr.Pos(),
117125
End: expr.End(),
@@ -167,7 +175,7 @@ func collectAliasesFromAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, ali
167175
delete(aliases, obj)
168176
continue
169177
}
170-
if arg, ok := caseConvArg(rhs); ok {
178+
if arg, ok := caseConvArg(pass, rhs); ok {
171179
aliases[obj] = arg
172180
} else {
173181
delete(aliases, obj)
@@ -192,7 +200,7 @@ func collectAliasesFromValueSpec(pass *analysis.Pass, spec *ast.ValueSpec, alias
192200
delete(aliases, obj)
193201
continue
194202
}
195-
if arg, ok := caseConvArg(rhs); ok {
203+
if arg, ok := caseConvArg(pass, rhs); ok {
196204
aliases[obj] = arg
197205
} else {
198206
delete(aliases, obj)
@@ -209,8 +217,8 @@ func deleteAliasForExpr(pass *analysis.Pass, aliases map[types.Object]ast.Expr,
209217
}
210218

211219
// isCaseConvCall reports whether node is a call to strings.ToLower or strings.ToUpper.
212-
func isCaseConvCall(n ast.Node) bool {
213-
_, ok := caseConvArg(n)
220+
func isCaseConvCall(pass *analysis.Pass, n ast.Node) bool {
221+
_, ok := caseConvArg(pass, n)
214222
return ok
215223
}
216224

@@ -236,7 +244,7 @@ func caseConvAliasArg(pass *analysis.Pass, expr ast.Expr, aliases map[types.Obje
236244
}
237245

238246
// caseConvArg returns the argument when n is strings.ToLower/ToUpper(<arg>).
239-
func caseConvArg(n ast.Node) (ast.Expr, bool) {
247+
func caseConvArg(pass *analysis.Pass, n ast.Node) (ast.Expr, bool) {
240248
call, ok := n.(*ast.CallExpr)
241249
if !ok {
242250
return nil, false
@@ -248,11 +256,7 @@ func caseConvArg(n ast.Node) (ast.Expr, bool) {
248256
if !ok {
249257
return nil, false
250258
}
251-
ident, ok := sel.X.(*ast.Ident)
252-
if !ok {
253-
return nil, false
254-
}
255-
if ident.Name != "strings" {
259+
if !astutil.IsPkgSelector(pass, sel, "strings") {
256260
return nil, false
257261
}
258262
if sel.Sel.Name != "ToLower" && sel.Sel.Name != "ToUpper" {
@@ -261,6 +265,22 @@ func caseConvArg(n ast.Node) (ast.Expr, bool) {
261265
return call.Args[0], true
262266
}
263267

268+
func caseConvPkgName(pass *analysis.Pass, n ast.Node) (string, bool) {
269+
call, ok := n.(*ast.CallExpr)
270+
if !ok {
271+
return "", false
272+
}
273+
sel, ok := call.Fun.(*ast.SelectorExpr)
274+
if !ok || !astutil.IsPkgSelector(pass, sel, "strings") {
275+
return "", false
276+
}
277+
if sel.Sel.Name != "ToLower" && sel.Sel.Name != "ToUpper" {
278+
return "", false
279+
}
280+
pkgName := astutil.NodeText(pass.Fset, sel.X)
281+
return pkgName, pkgName != ""
282+
}
283+
264284
func sameOperand(pass *analysis.Pass, left ast.Expr, right ast.Expr) bool {
265285
leftIdent, leftOK := left.(*ast.Ident)
266286
rightIdent, rightOK := right.(*ast.Ident)

0 commit comments

Comments
 (0)