diff --git a/cmd/ks-iam/app/options/options.go b/cmd/ks-iam/app/options/options.go index 8b3e3c261..d986db334 100644 --- a/cmd/ks-iam/app/options/options.go +++ b/cmd/ks-iam/app/options/options.go @@ -27,6 +27,7 @@ import ( "kubesphere.io/kubesphere/pkg/simple/client/mysql" "kubesphere.io/kubesphere/pkg/simple/client/redis" "strings" + "time" ) type ServerRunOptions struct { @@ -37,9 +38,10 @@ type ServerRunOptions struct { MySQLOptions *mysql.MySQLOptions AdminEmail string AdminPassword string - TokenExpireTime string + TokenIdleTimeout time.Duration JWTSecret string AuthRateLimit string + EnableMultiLogin bool } func NewServerRunOptions() *ServerRunOptions { @@ -60,9 +62,10 @@ func (s *ServerRunOptions) Flags() (fss cliflag.NamedFlagSets) { s.GenericServerRunOptions.AddFlags(fs) fs.StringVar(&s.AdminEmail, "admin-email", "admin@kubesphere.io", "default administrator's email") fs.StringVar(&s.AdminPassword, "admin-password", "passw0rd", "default administrator's password") - fs.StringVar(&s.TokenExpireTime, "token-expire-time", "2h", "token expire time,valid time units are \"ns\",\"us\",\"ms\",\"s\",\"m\",\"h\"") + fs.DurationVar(&s.TokenIdleTimeout, "token-idle-timeout", 30*time.Minute, "tokens that are idle beyond that time will expire,0s means the token has no expiration time. valid time units are \"ns\",\"us\",\"ms\",\"s\",\"m\",\"h\"") fs.StringVar(&s.JWTSecret, "jwt-secret", "", "jwt secret") fs.StringVar(&s.AuthRateLimit, "auth-rate-limit", "5/30m", "specifies the maximum number of authentication attempts permitted and time interval,valid time units are \"s\",\"m\",\"h\"") + fs.BoolVar(&s.EnableMultiLogin, "enable-multi-login", false, "allow one account to have multiple sessions") s.KubernetesOptions.AddFlags(fss.FlagSet("kubernetes")) s.LdapOptions.AddFlags(fss.FlagSet("ldap")) diff --git a/cmd/ks-iam/app/server.go b/cmd/ks-iam/app/server.go index 14370e1cb..502714860 100644 --- a/cmd/ks-iam/app/server.go +++ b/cmd/ks-iam/app/server.go @@ -24,6 +24,7 @@ import ( cliflag "k8s.io/component-base/cli/flag" "k8s.io/klog" "kubesphere.io/kubesphere/cmd/ks-iam/app/options" + "kubesphere.io/kubesphere/pkg/apis" "kubesphere.io/kubesphere/pkg/apiserver/runtime" "kubesphere.io/kubesphere/pkg/informers" "kubesphere.io/kubesphere/pkg/models/iam" @@ -35,9 +36,6 @@ import ( "kubesphere.io/kubesphere/pkg/utils/signals" "kubesphere.io/kubesphere/pkg/utils/term" "net/http" - "time" - - "kubesphere.io/kubesphere/pkg/apis" ) func NewAPIServerCommand() *cobra.Command { @@ -94,15 +92,10 @@ func Run(s *options.ServerRunOptions, stopChan <-chan struct{}) error { client.NewClientSetFactory(csop, stopChan) - expireTime, err := time.ParseDuration(s.TokenExpireTime) - - if err != nil { - return err - } - waitForResourceSync(stopChan) - err = iam.Init(s.AdminEmail, s.AdminPassword, expireTime, s.AuthRateLimit) + err := iam.Init(s.AdminEmail, s.AdminPassword, s.AuthRateLimit, s.TokenIdleTimeout, s.EnableMultiLogin) + jwtutil.Setup(s.JWTSecret) if err != nil { diff --git a/pkg/apigateway/caddy-plugin/authenticate/authenticate.go b/pkg/apigateway/caddy-plugin/authenticate/authenticate.go index 3ee4fe11a..4adb6d730 100644 --- a/pkg/apigateway/caddy-plugin/authenticate/authenticate.go +++ b/pkg/apigateway/caddy-plugin/authenticate/authenticate.go @@ -23,10 +23,13 @@ import ( "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/klog" + "kubesphere.io/kubesphere/pkg/simple/client/redis" "log" "net/http" "strconv" "strings" + "time" "github.com/dgrijalva/jwt-go" "github.com/mholt/caddy/caddyhttp/httpserver" @@ -38,9 +41,12 @@ type Auth struct { } type Rule struct { - Secret []byte - Path string - ExceptedPath []string + Secret []byte + Path string + RedisOptions *redis.RedisOptions + TokenIdleTimeout time.Duration + RedisClient *redis.RedisClient + ExceptedPath []string } type User struct { @@ -87,7 +93,7 @@ func (h Auth) ServeHTTP(resp http.ResponseWriter, req *http.Request) (int, error func (h Auth) InjectContext(req *http.Request, token *jwt.Token) (*http.Request, error) { - payLoad, ok := token.Claims.(jwt.MapClaims) + payload, ok := token.Claims.(jwt.MapClaims) if !ok { return nil, errors.New("invalid payload") @@ -101,14 +107,14 @@ func (h Auth) InjectContext(req *http.Request, token *jwt.Token) (*http.Request, usr := &user.DefaultInfo{} - username, ok := payLoad["username"].(string) + username, ok := payload["username"].(string) if ok && username != "" { req.Header.Set("X-Token-Username", username) usr.Name = username } - uid := payLoad["uid"] + uid := payload["uid"] if uid != nil { switch uid.(type) { @@ -123,7 +129,7 @@ func (h Auth) InjectContext(req *http.Request, token *jwt.Token) (*http.Request, } } - groups, ok := payLoad["groups"].([]string) + groups, ok := payload["groups"].([]string) if ok && len(groups) > 0 { req.Header.Set("X-Token-Groups", strings.Join(groups, ",")) usr.Groups = groups @@ -160,10 +166,46 @@ func (h Auth) Validate(uToken string) (*jwt.Token, error) { token, err := jwt.Parse(uToken, h.ProvideKey) if err != nil { + klog.Errorln(err) return nil, err } - return token, nil + payload, ok := token.Claims.(jwt.MapClaims) + + if !ok { + err := fmt.Errorf("invalid payload") + klog.Errorln(err) + return nil, err + } + + username, ok := payload["username"].(string) + + if !ok { + err := fmt.Errorf("invalid payload") + klog.Errorln(err) + return nil, err + } + + if _, ok = payload["exp"]; ok { + // allow static token has expiration time + return token, nil + } + + tokenKey := fmt.Sprintf("kubesphere:users:%s:token:%s", username, uToken) + + exist, err := h.Rule.RedisClient.Redis().Exists(tokenKey).Result() + if err != nil { + klog.Error(err) + return nil, err + } + + if exist == 1 { + // reset expiration time if token exist + h.Rule.RedisClient.Redis().Expire(tokenKey, h.Rule.TokenIdleTimeout) + return token, nil + } else { + return nil, errors.New("illegal token") + } } func (h Auth) HandleUnauthorized(w http.ResponseWriter, err error) int { diff --git a/pkg/apigateway/caddy-plugin/authenticate/auto_load.go b/pkg/apigateway/caddy-plugin/authenticate/auto_load.go index 53d47466b..3cb2f4c79 100644 --- a/pkg/apigateway/caddy-plugin/authenticate/auto_load.go +++ b/pkg/apigateway/caddy-plugin/authenticate/auto_load.go @@ -19,7 +19,10 @@ package authenticate import ( "fmt" + "kubesphere.io/kubesphere/pkg/simple/client/redis" + "strings" + "time" "github.com/mholt/caddy" "github.com/mholt/caddy/caddyhttp/httpserver" @@ -34,16 +37,25 @@ func Setup(c *caddy.Controller) error { } c.OnStartup(func() error { + rule.RedisClient, err = redis.NewRedisClient(rule.RedisOptions, nil) + if err != nil { + return err + } fmt.Println("Authenticate middleware is initiated") return nil }) + c.OnShutdown(func() error { + return rule.RedisClient.Redis().Close() + }) + httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { return &Auth{Next: next, Rule: rule} }) return nil } + func parse(c *caddy.Controller) (Rule, error) { rule := Rule{ExceptedPath: make([]string, 0)} @@ -61,6 +73,36 @@ func parse(c *caddy.Controller) (Rule, error) { rule.Path = c.Val() + if c.NextArg() { + return rule, c.ArgErr() + } + case "token-idle-timeout": + if !c.NextArg() { + return rule, c.ArgErr() + } + + if timeout, err := time.ParseDuration(c.Val()); err != nil { + return rule, c.ArgErr() + } else { + rule.TokenIdleTimeout = timeout + } + + if c.NextArg() { + return rule, c.ArgErr() + } + case "redis-url": + if !c.NextArg() { + return rule, c.ArgErr() + } + + options := &redis.RedisOptions{RedisURL: c.Val()} + + if err := options.Validate(); len(err) > 0 { + return rule, c.ArgErr() + } else { + rule.RedisOptions = options + } + if c.NextArg() { return rule, c.ArgErr() } diff --git a/pkg/apis/iam/v1alpha2/register.go b/pkg/apis/iam/v1alpha2/register.go index 6ce498af2..3bbd13af8 100644 --- a/pkg/apis/iam/v1alpha2/register.go +++ b/pkg/apis/iam/v1alpha2/register.go @@ -130,7 +130,13 @@ func addWebService(c *restful.Container) error { To(iam.Login). Doc("KubeSphere APIs support token-based authentication via the Authtoken request header. The POST Login API is used to retrieve the authentication token. After the authentication token is obtained, it must be inserted into the Authtoken header for all requests."). Reads(iam.LoginRequest{}). - Returns(http.StatusOK, ok, models.Token{}). + Returns(http.StatusOK, ok, models.AuthGrantResponse{}). + Metadata(restfulspec.KeyOpenAPITags, []string{constants.IdentityManagementTag})) + ws.Route(ws.POST("/token"). + To(iam.OAuth). + Doc("OAuth API,only support resource owner password credentials grant"). + Reads(iam.LoginRequest{}). + Returns(http.StatusOK, ok, models.AuthGrantResponse{}). Metadata(restfulspec.KeyOpenAPITags, []string{constants.IdentityManagementTag})) ws.Route(ws.GET("/users/{user}"). To(iam.DescribeUser). diff --git a/pkg/apiserver/iam/auth.go b/pkg/apiserver/iam/auth.go index f93855a17..07892b8aa 100644 --- a/pkg/apiserver/iam/auth.go +++ b/pkg/apiserver/iam/auth.go @@ -18,15 +18,16 @@ package iam import ( + "fmt" "github.com/dgrijalva/jwt-go" "github.com/emicklei/go-restful" "k8s.io/klog" + "kubesphere.io/kubesphere/pkg/models" + "kubesphere.io/kubesphere/pkg/models/iam" + "kubesphere.io/kubesphere/pkg/server/errors" "kubesphere.io/kubesphere/pkg/utils/iputil" "kubesphere.io/kubesphere/pkg/utils/jwtutil" "net/http" - - "kubesphere.io/kubesphere/pkg/models/iam" - "kubesphere.io/kubesphere/pkg/server/errors" ) type Spec struct { @@ -50,8 +51,14 @@ type LoginRequest struct { Password string `json:"password" description:"password"` } +type OAuthRequest struct { + GrantType string `json:"grant_type"` + Username string `json:"username,omitempty" description:"username"` + Password string `json:"password,omitempty" description:"password"` + RefreshToken string `json:"refresh_token,omitempty"` +} + const ( - APIVersion = "authentication.k8s.io/v1beta1" KindTokenReview = "TokenReview" ) @@ -81,6 +88,39 @@ func Login(req *restful.Request, resp *restful.Response) { resp.WriteAsJson(token) } +func OAuth(req *restful.Request, resp *restful.Response) { + + authRequest := &OAuthRequest{} + + err := req.ReadEntity(authRequest) + + if err != nil { + resp.WriteHeaderAndEntity(http.StatusBadRequest, errors.Wrap(err)) + return + } + var result *models.AuthGrantResponse + switch authRequest.GrantType { + case "refresh_token": + result, err = iam.RefreshToken(authRequest.RefreshToken) + case "password": + ip := iputil.RemoteIp(req.Request) + result, err = iam.PasswordCredentialGrant(authRequest.Username, authRequest.Password, ip) + default: + resp.Header().Set("WWW-Authenticate", "grant_type is not supported") + resp.WriteHeaderAndEntity(http.StatusUnauthorized, errors.Wrap(fmt.Errorf("grant_type is not supported"))) + return + } + + if err != nil { + resp.Header().Set("WWW-Authenticate", err.Error()) + resp.WriteHeaderAndEntity(http.StatusUnauthorized, errors.Wrap(err)) + return + } + + resp.WriteEntity(result) + +} + // k8s token review func TokenReviewHandler(req *restful.Request, resp *restful.Response) { var tokenReview TokenReview @@ -103,7 +143,7 @@ func TokenReviewHandler(req *restful.Request, resp *restful.Response) { if err != nil { klog.Errorln("token review failed", uToken, err) - failed := TokenReview{APIVersion: APIVersion, + failed := TokenReview{APIVersion: tokenReview.APIVersion, Kind: KindTokenReview, Status: &Status{ Authenticated: false, @@ -138,7 +178,7 @@ func TokenReviewHandler(req *restful.Request, resp *restful.Response) { user.Groups = groups - success := TokenReview{APIVersion: APIVersion, + success := TokenReview{APIVersion: tokenReview.APIVersion, Kind: KindTokenReview, Status: &Status{ Authenticated: true, diff --git a/pkg/models/iam/im.go b/pkg/models/iam/im.go index 34f7a1665..ec98d7f2d 100644 --- a/pkg/models/iam/im.go +++ b/pkg/models/iam/im.go @@ -53,10 +53,11 @@ import ( var ( adminEmail string adminPassword string - tokenExpireTime time.Duration + tokenIdleTimeout time.Duration maxAuthFailed int authTimeInterval time.Duration initUsers []initUser + enableMultiLogin bool ) type initUser struct { @@ -71,11 +72,12 @@ const ( defaultAuthTimeInterval = 30 * time.Minute ) -func Init(email, password string, expireTime time.Duration, authRateLimit string) error { +func Init(email, password, authRateLimit string, idleTimeout time.Duration, multiLogin bool) error { adminEmail = email adminPassword = password - tokenExpireTime = expireTime + tokenIdleTimeout = idleTimeout maxAuthFailed, authTimeInterval = parseAuthRateLimit(authRateLimit) + enableMultiLogin = multiLogin err := checkAndCreateDefaultUser() @@ -216,6 +218,9 @@ func createUserBaseDN() error { return err } conn, err := client.NewConn() + if err != nil { + return err + } defer conn.Close() groupsCreateRequest := ldap.NewAddRequest(client.UserSearchBase(), nil) groupsCreateRequest.Attribute("objectClass", []string{"organizationalUnit", "top"}) @@ -230,6 +235,9 @@ func createGroupsBaseDN() error { return err } conn, err := client.NewConn() + if err != nil { + return err + } defer conn.Close() groupsCreateRequest := ldap.NewAddRequest(client.GroupSearchBase(), nil) groupsCreateRequest.Attribute("objectClass", []string{"organizationalUnit", "top"}) @@ -237,8 +245,151 @@ func createGroupsBaseDN() error { return conn.Add(groupsCreateRequest) } +func RefreshToken(refreshToken string) (*models.AuthGrantResponse, error) { + validRefreshToken, err := jwtutil.ValidateToken(refreshToken) + if err != nil { + klog.Error(err) + return nil, err + } + + payload, ok := validRefreshToken.Claims.(jwt.MapClaims) + + if !ok { + err = errors.New("invalid payload") + klog.Error(err) + return nil, err + } + + claims := jwt.MapClaims{} + + // token with expiration time will not auto sliding + claims["username"] = payload["username"] + claims["email"] = payload["email"] + claims["iat"] = time.Now().Unix() + claims["exp"] = time.Now().Add(tokenIdleTimeout * 4).Unix() + + token := jwtutil.MustSigned(claims) + + claims = jwt.MapClaims{} + claims["username"] = payload["username"] + claims["email"] = payload["email"] + claims["iat"] = time.Now().Unix() + claims["type"] = "refresh_token" + claims["exp"] = time.Now().Add(tokenIdleTimeout * 5).Unix() + + refreshToken = jwtutil.MustSigned(claims) + + return &models.AuthGrantResponse{TokenType: "jwt", Token: token, RefreshToken: refreshToken, ExpiresIn: (tokenIdleTimeout * 4).Seconds()}, nil +} + +func PasswordCredentialGrant(username, password, ip string) (*models.AuthGrantResponse, error) { + redisClient, err := clientset.ClientSets().Redis() + if err != nil { + return nil, err + } + + records, err := redisClient.Keys(fmt.Sprintf("kubesphere:authfailed:%s:*", username)).Result() + + if err != nil { + klog.Error(err) + return nil, err + } + + if len(records) >= maxAuthFailed { + return nil, restful.NewError(http.StatusTooManyRequests, "auth rate limit exceeded") + } + + client, err := clientset.ClientSets().Ldap() + if err != nil { + return nil, err + } + conn, err := client.NewConn() + if err != nil { + return nil, err + } + defer conn.Close() + + userSearchRequest := ldap.NewSearchRequest( + client.UserSearchBase(), + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(&(objectClass=inetOrgPerson)(|(uid=%s)(mail=%s)))", username, username), + []string{"uid", "mail"}, + nil, + ) + + result, err := conn.Search(userSearchRequest) + + if err != nil { + return nil, err + } + + if len(result.Entries) != 1 { + return nil, ldap.NewError(ldap.LDAPResultInvalidCredentials, errors.New("incorrect password")) + } + + uid := result.Entries[0].GetAttributeValue("uid") + email := result.Entries[0].GetAttributeValue("mail") + dn := result.Entries[0].DN + + // bind as the user to verify their password + err = conn.Bind(dn, password) + + if err != nil { + klog.Infoln("auth failed", username, err) + + if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) { + loginFailedRecord := fmt.Sprintf("kubesphere:authfailed:%s:%d", uid, time.Now().UnixNano()) + redisClient.Set(loginFailedRecord, "", authTimeInterval) + } + + return nil, err + } + + claims := jwt.MapClaims{} + + // token with expiration time will not auto sliding + claims["username"] = uid + claims["email"] = email + claims["iat"] = time.Now().Unix() + claims["exp"] = time.Now().Add(tokenIdleTimeout * 4).Unix() + + token := jwtutil.MustSigned(claims) + + if !enableMultiLogin { + // multi login not allowed, remove the previous token + sessions, err := redisClient.Keys(fmt.Sprintf("kubesphere:users:%s:token:*", uid)).Result() + + if err != nil { + klog.Errorln(err) + return nil, err + } + + if len(sessions) > 0 { + klog.V(4).Infoln("revoke token", sessions) + err = redisClient.Del(sessions...).Err() + if err != nil { + klog.Errorln(err) + return nil, err + } + } + } + + claims = jwt.MapClaims{} + claims["username"] = uid + claims["email"] = email + claims["iat"] = time.Now().Unix() + claims["type"] = "refresh_token" + claims["exp"] = time.Now().Add(tokenIdleTimeout * 5).Unix() + + refreshToken := jwtutil.MustSigned(claims) + + loginLog(uid, ip) + + return &models.AuthGrantResponse{TokenType: "jwt", Token: token, RefreshToken: refreshToken, ExpiresIn: (tokenIdleTimeout * 4).Seconds()}, nil +} + // User login -func Login(username string, password string, ip string) (*models.Token, error) { +func Login(username, password, ip string) (*models.AuthGrantResponse, error) { redisClient, err := clientset.ClientSets().Redis() if err != nil { @@ -295,7 +446,7 @@ func Login(username string, password string, ip string) (*models.Token, error) { klog.Infoln("auth failed", username, err) if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) { - loginFailedRecord := fmt.Sprintf("kubesphere:authfailed:%s:%d", username, time.Now().UnixNano()) + loginFailedRecord := fmt.Sprintf("kubesphere:authfailed:%s:%d", uid, time.Now().UnixNano()) redisClient.Set(loginFailedRecord, "", authTimeInterval) } @@ -304,17 +455,41 @@ func Login(username string, password string, ip string) (*models.Token, error) { claims := jwt.MapClaims{} - if tokenExpireTime > 0 { - claims["exp"] = time.Now().Add(tokenExpireTime).Unix() - } + // token without expiration time will auto sliding claims["username"] = uid claims["email"] = email + claims["iat"] = time.Now().Unix() token := jwtutil.MustSigned(claims) + if !enableMultiLogin { + // multi login not allowed, remove the previous token + sessions, err := redisClient.Keys(fmt.Sprintf("kubesphere:users:%s:token:*", uid)).Result() + + if err != nil { + klog.Errorln(err) + return nil, err + } + + if len(sessions) > 0 { + klog.V(4).Infoln("revoke token", sessions) + err = redisClient.Del(sessions...).Err() + if err != nil { + klog.Errorln(err) + return nil, err + } + } + } + + // cache token with expiration time + if err = redisClient.Set(fmt.Sprintf("kubesphere:users:%s:token:%s", uid, token), token, tokenIdleTimeout).Err(); err != nil { + klog.Errorln(err) + return nil, err + } + loginLog(uid, ip) - return &models.Token{Token: token}, nil + return &models.AuthGrantResponse{Token: token}, nil } func loginLog(uid, ip string) { @@ -443,7 +618,6 @@ func ListUsers(conditions *params.Conditions, orderBy string, reverse bool, limi if i >= offset && len(items) < limit { - user.AvatarUrl = getAvatar(user.Username) user.LastLoginTime = getLastLoginTime(user.Username) clusterRole, err := GetUserClusterRole(user.Username) if err != nil { @@ -480,8 +654,6 @@ func DescribeUser(username string) (*models.User, error) { user.Groups = groups } - user.AvatarUrl = getAvatar(username) - return user, nil } @@ -582,37 +754,6 @@ func getLastLoginTime(username string) string { return "" } -func setAvatar(username, avatar string) error { - redis, err := clientset.ClientSets().Redis() - if err != nil { - return err - } - - _, err = redis.HMSet("kubesphere:users:avatar", map[string]interface{}{"username": avatar}).Result() - return err -} - -func getAvatar(username string) string { - redis, err := clientset.ClientSets().Redis() - if err != nil { - return "" - } - - avatar, err := redis.HMGet("kubesphere:users:avatar", username).Result() - - if err != nil { - return "" - } - - if len(avatar) > 0 { - if url, ok := avatar[0].(string); ok { - return url - } - } - - return "" -} - func DeleteUser(username string) error { client, err := clientset.ClientSets().Ldap() @@ -876,10 +1017,6 @@ func CreateUser(user *models.User) (*models.User, error) { return nil, err } - if user.AvatarUrl != "" { - setAvatar(user.Username, user.AvatarUrl) - } - if user.ClusterRole != "" { err := CreateClusterRoleBinding(user.Username, user.ClusterRole) @@ -1022,15 +1159,6 @@ func UpdateUser(user *models.User) (*models.User, error) { userModifyRequest.Replace("userPassword", []string{user.Password}) } - if user.AvatarUrl != "" { - err = setAvatar(user.Username, user.AvatarUrl) - } - - if err != nil { - klog.Error(err) - return nil, err - } - err = conn.Modify(userModifyRequest) if err != nil { diff --git a/pkg/models/types.go b/pkg/models/types.go index 3916d6cb7..9144a20fb 100644 --- a/pkg/models/types.go +++ b/pkg/models/types.go @@ -107,8 +107,11 @@ type PodInfo struct { Container string `json:"container" description:"container name"` } -type Token struct { - Token string `json:"access_token" description:"access token"` +type AuthGrantResponse struct { + TokenType string `json:"token_type,omitempty"` + Token string `json:"access_token" description:"access token"` + ExpiresIn float64 `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` } type ResourceQuota struct { diff --git a/pkg/server/config/config.go b/pkg/server/config/config.go index ef6329d91..fcba02cc4 100644 --- a/pkg/server/config/config.go +++ b/pkg/server/config/config.go @@ -258,7 +258,7 @@ func (c *Config) stripEmptyOptions() { c.MySQLOptions = nil } - if c.RedisOptions != nil && c.RedisOptions.Host == "" { + if c.RedisOptions != nil && c.RedisOptions.RedisURL == "" { c.RedisOptions = nil } diff --git a/pkg/server/config/config_test.go b/pkg/server/config/config_test.go index e2c310367..f4f266c87 100644 --- a/pkg/server/config/config_test.go +++ b/pkg/server/config/config_test.go @@ -63,10 +63,7 @@ func newTestConfig() *Config { GroupSearchBase: "ou=Groups,dc=example,dc=org", }, RedisOptions: &redis.RedisOptions{ - Host: "10.10.111.110", - Port: 6379, - Password: "", - DB: 0, + RedisURL: "redis://:qwerty@localhost:6379/1", }, S3Options: &s2is3.S3Options{ Endpoint: "http://minio.openpitrix-system.svc", diff --git a/pkg/simple/client/factory.go b/pkg/simple/client/factory.go index 3d9f273a9..6fadb2c58 100644 --- a/pkg/simple/client/factory.go +++ b/pkg/simple/client/factory.go @@ -181,7 +181,7 @@ func (cs *ClientSet) MySQL() (*mysql.Database, error) { func (cs *ClientSet) Redis() (*goredis.Client, error) { var err error - if cs.csoptions.redisOptions == nil || cs.csoptions.redisOptions.Host == "" { + if cs.csoptions.redisOptions == nil || cs.csoptions.redisOptions.RedisURL == "" { return nil, ClientSetNotEnabledError{} } diff --git a/pkg/simple/client/redis/options.go b/pkg/simple/client/redis/options.go index a89b5bac9..a3e9d25de 100644 --- a/pkg/simple/client/redis/options.go +++ b/pkg/simple/client/redis/options.go @@ -1,27 +1,20 @@ package redis import ( - "fmt" + "github.com/go-redis/redis" "github.com/spf13/pflag" - "kubesphere.io/kubesphere/pkg/utils/net" "kubesphere.io/kubesphere/pkg/utils/reflectutils" ) type RedisOptions struct { - Host string - Port int - Password string - DB int + RedisURL string } // NewRedisOptions returns options points to nowhere, // because redis is not required for some components func NewRedisOptions() *RedisOptions { return &RedisOptions{ - Host: "", - Port: 6379, - Password: "", - DB: 0, + RedisURL: "", } } @@ -29,14 +22,10 @@ func NewRedisOptions() *RedisOptions { func (r *RedisOptions) Validate() []error { errors := make([]error, 0) - if r.Host != "" { - if !net.IsValidPort(r.Port) { - errors = append(errors, fmt.Errorf("--redis-port is out of range")) - } - } + _, err := redis.ParseURL(r.RedisURL) - if r.DB < 0 { - errors = append(errors, fmt.Errorf("--redis-db is less than 0")) + if err != nil { + errors = append(errors, err) } return errors @@ -44,7 +33,7 @@ func (r *RedisOptions) Validate() []error { // ApplyTo apply to another options if it's a enabled option(non empty host) func (r *RedisOptions) ApplyTo(options *RedisOptions) { - if r.Host != "" { + if r.RedisURL != "" { reflectutils.Override(options, r) } } @@ -52,16 +41,6 @@ func (r *RedisOptions) ApplyTo(options *RedisOptions) { // AddFlags add option flags to command line flags, // if redis-host left empty, the following options will be ignored. func (r *RedisOptions) AddFlags(fs *pflag.FlagSet) { - fs.StringVar(&r.Host, "redis-host", r.Host, ""+ - "Redis service host address. If left blank, means redis is unnecessary, "+ - "redis will be disabled") - - fs.IntVar(&r.Port, "redis-port", r.Port, ""+ - "Redis service port number.") - - fs.StringVar(&r.Password, "redis-password", r.Password, ""+ - "Redis service password if necessary, default to empty") - - fs.IntVar(&r.DB, "redis-db", r.DB, ""+ - "Redis service database index, default to 0.") + fs.StringVar(&r.RedisURL, "redis-url", "", "Redis connection URL. If left blank, means redis is unnecessary, "+ + "redis will be disabled. e.g. redis://:password@host:port/db") } diff --git a/pkg/simple/client/redis/redis.go b/pkg/simple/client/redis/redis.go index aafade8d5..4a1fb83b3 100644 --- a/pkg/simple/client/redis/redis.go +++ b/pkg/simple/client/redis/redis.go @@ -18,7 +18,6 @@ package redis import ( - "fmt" "github.com/go-redis/redis" "k8s.io/klog" ) @@ -39,11 +38,14 @@ func NewRedisClientOrDie(options *RedisOptions, stopCh <-chan struct{}) *RedisCl func NewRedisClient(option *RedisOptions, stopCh <-chan struct{}) (*RedisClient, error) { var r RedisClient - r.client = redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%d", option.Host, option.Port), - Password: option.Password, - DB: option.DB, - }) + options, err := redis.ParseURL(option.RedisURL) + + if err != nil { + klog.Error(err) + return nil, err + } + + r.client = redis.NewClient(options) if err := r.client.Ping().Err(); err != nil { klog.Error("unable to reach redis host", err) @@ -51,12 +53,14 @@ func NewRedisClient(option *RedisOptions, stopCh <-chan struct{}) (*RedisClient, return nil, err } - go func() { - <-stopCh - if err := r.client.Close(); err != nil { - klog.Error(err) - } - }() + if stopCh != nil { + go func() { + <-stopCh + if err := r.client.Close(); err != nil { + klog.Error(err) + } + }() + } return &r, nil }