feat: support resource owner password credentials grant
Signed-off-by: hongming <talonwan@yunify.com>
This commit is contained in:
@@ -27,6 +27,7 @@ import (
|
||||
"kubesphere.io/kubesphere/pkg/simple/client/mysql"
|
||||
"kubesphere.io/kubesphere/pkg/simple/client/redis"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ServerRunOptions struct {
|
||||
@@ -37,7 +38,7 @@ type ServerRunOptions struct {
|
||||
MySQLOptions *mysql.MySQLOptions
|
||||
AdminEmail string
|
||||
AdminPassword string
|
||||
TokenIdleTimeout string
|
||||
TokenIdleTimeout time.Duration
|
||||
JWTSecret string
|
||||
AuthRateLimit string
|
||||
EnableMultiLogin bool
|
||||
@@ -61,7 +62,7 @@ func (s *ServerRunOptions) Flags() (fss cliflag.NamedFlagSets) {
|
||||
s.GenericServerRunOptions.AddFlags(fs)
|
||||
fs.StringVar(&s.AdminEmail, "admin-email", "admin@kubesphere.io", "default administrator's email")
|
||||
fs.StringVar(&s.AdminPassword, "admin-password", "passw0rd", "default administrator's password")
|
||||
fs.StringVar(&s.TokenIdleTimeout, "token-idle-timeout", "30m", "tokens that are idle beyond that time will expire,0s means the token has no expiration time. valid time units are \"ns\",\"us\",\"ms\",\"s\",\"m\",\"h\"")
|
||||
fs.DurationVar(&s.TokenIdleTimeout, "token-idle-timeout", 30*time.Minute, "tokens that are idle beyond that time will expire,0s means the token has no expiration time. valid time units are \"ns\",\"us\",\"ms\",\"s\",\"m\",\"h\"")
|
||||
fs.StringVar(&s.JWTSecret, "jwt-secret", "", "jwt secret")
|
||||
fs.StringVar(&s.AuthRateLimit, "auth-rate-limit", "5/30m", "specifies the maximum number of authentication attempts permitted and time interval,valid time units are \"s\",\"m\",\"h\"")
|
||||
fs.BoolVar(&s.EnableMultiLogin, "enable-multi-login", false, "allow one account to have multiple sessions")
|
||||
|
||||
@@ -94,7 +94,7 @@ func Run(s *options.ServerRunOptions, stopChan <-chan struct{}) error {
|
||||
|
||||
waitForResourceSync(stopChan)
|
||||
|
||||
err := iam.Init(s.AdminEmail, s.AdminPassword, s.TokenIdleTimeout, s.AuthRateLimit, s.EnableMultiLogin)
|
||||
err := iam.Init(s.AdminEmail, s.AdminPassword, s.AuthRateLimit, s.TokenIdleTimeout, s.EnableMultiLogin)
|
||||
|
||||
jwtutil.Setup(s.JWTSecret)
|
||||
|
||||
|
||||
@@ -20,11 +20,11 @@ package authenticate
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/apiserver/pkg/authentication/user"
|
||||
"k8s.io/apiserver/pkg/endpoints/request"
|
||||
"k8s.io/klog"
|
||||
"kubesphere.io/kubesphere/pkg/simple/client/redis"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -43,9 +43,9 @@ type Auth struct {
|
||||
type Rule struct {
|
||||
Secret []byte
|
||||
Path string
|
||||
RedisOptions *redis.Options
|
||||
RedisOptions *redis.RedisOptions
|
||||
TokenIdleTimeout time.Duration
|
||||
RedisClient *redis.Client
|
||||
RedisClient *redis.RedisClient
|
||||
ExceptedPath []string
|
||||
}
|
||||
|
||||
@@ -187,13 +187,13 @@ func (h Auth) Validate(uToken string) (*jwt.Token, error) {
|
||||
}
|
||||
|
||||
if _, ok = payload["exp"]; ok {
|
||||
// allow static token when contain expiration time
|
||||
// allow static token has expiration time
|
||||
return token, nil
|
||||
}
|
||||
|
||||
tokenKey := fmt.Sprintf("kubesphere:users:%s:token:%s", username, uToken)
|
||||
|
||||
exist, err := h.Rule.RedisClient.Exists(tokenKey).Result()
|
||||
exist, err := h.Rule.RedisClient.Redis().Exists(tokenKey).Result()
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
return nil, err
|
||||
@@ -201,7 +201,7 @@ func (h Auth) Validate(uToken string) (*jwt.Token, error) {
|
||||
|
||||
if exist == 1 {
|
||||
// reset expiration time if token exist
|
||||
h.Rule.RedisClient.Expire(tokenKey, h.Rule.TokenIdleTimeout)
|
||||
h.Rule.RedisClient.Redis().Expire(tokenKey, h.Rule.TokenIdleTimeout)
|
||||
return token, nil
|
||||
} else {
|
||||
return nil, errors.New("illegal token")
|
||||
|
||||
@@ -19,7 +19,8 @@ package authenticate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-redis/redis"
|
||||
"kubesphere.io/kubesphere/pkg/simple/client/redis"
|
||||
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -35,10 +36,9 @@ func Setup(c *caddy.Controller) error {
|
||||
return err
|
||||
}
|
||||
|
||||
rule.RedisClient = redis.NewClient(rule.RedisOptions)
|
||||
|
||||
c.OnStartup(func() error {
|
||||
if err := rule.RedisClient.Ping().Err(); err != nil {
|
||||
rule.RedisClient, err = redis.NewRedisClient(rule.RedisOptions, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Authenticate middleware is initiated")
|
||||
@@ -46,7 +46,7 @@ func Setup(c *caddy.Controller) error {
|
||||
})
|
||||
|
||||
c.OnShutdown(func() error {
|
||||
return rule.RedisClient.Close()
|
||||
return rule.RedisClient.Redis().Close()
|
||||
})
|
||||
|
||||
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||
@@ -95,10 +95,12 @@ func parse(c *caddy.Controller) (Rule, error) {
|
||||
return rule, c.ArgErr()
|
||||
}
|
||||
|
||||
if redisOptions, err := redis.ParseURL(c.Val()); err != nil {
|
||||
options := &redis.RedisOptions{RedisURL: c.Val()}
|
||||
|
||||
if err := options.Validate(); len(err) > 0 {
|
||||
return rule, c.ArgErr()
|
||||
} else {
|
||||
rule.RedisOptions = redisOptions
|
||||
rule.RedisOptions = options
|
||||
}
|
||||
|
||||
if c.NextArg() {
|
||||
|
||||
@@ -130,7 +130,13 @@ func addWebService(c *restful.Container) error {
|
||||
To(iam.Login).
|
||||
Doc("KubeSphere APIs support token-based authentication via the Authtoken request header. The POST Login API is used to retrieve the authentication token. After the authentication token is obtained, it must be inserted into the Authtoken header for all requests.").
|
||||
Reads(iam.LoginRequest{}).
|
||||
Returns(http.StatusOK, ok, models.Token{}).
|
||||
Returns(http.StatusOK, ok, models.AuthGrantResponse{}).
|
||||
Metadata(restfulspec.KeyOpenAPITags, []string{constants.IdentityManagementTag}))
|
||||
ws.Route(ws.POST("/token").
|
||||
To(iam.OAuth).
|
||||
Doc("OAuth API,only support resource owner password credentials grant").
|
||||
Reads(iam.LoginRequest{}).
|
||||
Returns(http.StatusOK, ok, models.AuthGrantResponse{}).
|
||||
Metadata(restfulspec.KeyOpenAPITags, []string{constants.IdentityManagementTag}))
|
||||
ws.Route(ws.GET("/users/{user}").
|
||||
To(iam.DescribeUser).
|
||||
|
||||
@@ -18,15 +18,16 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/emicklei/go-restful"
|
||||
"k8s.io/klog"
|
||||
"kubesphere.io/kubesphere/pkg/models"
|
||||
"kubesphere.io/kubesphere/pkg/models/iam"
|
||||
"kubesphere.io/kubesphere/pkg/server/errors"
|
||||
"kubesphere.io/kubesphere/pkg/utils/iputil"
|
||||
"kubesphere.io/kubesphere/pkg/utils/jwtutil"
|
||||
"net/http"
|
||||
|
||||
"kubesphere.io/kubesphere/pkg/models/iam"
|
||||
"kubesphere.io/kubesphere/pkg/server/errors"
|
||||
)
|
||||
|
||||
type Spec struct {
|
||||
@@ -50,8 +51,14 @@ type LoginRequest struct {
|
||||
Password string `json:"password" description:"password"`
|
||||
}
|
||||
|
||||
type OAuthRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
Username string `json:"username,omitempty" description:"username"`
|
||||
Password string `json:"password,omitempty" description:"password"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
APIVersion = "authentication.k8s.io/v1beta1"
|
||||
KindTokenReview = "TokenReview"
|
||||
)
|
||||
|
||||
@@ -81,6 +88,39 @@ func Login(req *restful.Request, resp *restful.Response) {
|
||||
resp.WriteAsJson(token)
|
||||
}
|
||||
|
||||
func OAuth(req *restful.Request, resp *restful.Response) {
|
||||
|
||||
authRequest := &OAuthRequest{}
|
||||
|
||||
err := req.ReadEntity(authRequest)
|
||||
|
||||
if err != nil {
|
||||
resp.WriteHeaderAndEntity(http.StatusBadRequest, errors.Wrap(err))
|
||||
return
|
||||
}
|
||||
var result *models.AuthGrantResponse
|
||||
switch authRequest.GrantType {
|
||||
case "refresh_token":
|
||||
result, err = iam.RefreshToken(authRequest.RefreshToken)
|
||||
case "password":
|
||||
ip := iputil.RemoteIp(req.Request)
|
||||
result, err = iam.PasswordCredentialGrant(authRequest.Username, authRequest.Password, ip)
|
||||
default:
|
||||
resp.Header().Set("WWW-Authenticate", "grant_type is not supported")
|
||||
resp.WriteHeaderAndEntity(http.StatusUnauthorized, errors.Wrap(fmt.Errorf("grant_type is not supported")))
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
resp.Header().Set("WWW-Authenticate", err.Error())
|
||||
resp.WriteHeaderAndEntity(http.StatusUnauthorized, errors.Wrap(err))
|
||||
return
|
||||
}
|
||||
|
||||
resp.WriteEntity(result)
|
||||
|
||||
}
|
||||
|
||||
// k8s token review
|
||||
func TokenReviewHandler(req *restful.Request, resp *restful.Response) {
|
||||
var tokenReview TokenReview
|
||||
@@ -103,7 +143,7 @@ func TokenReviewHandler(req *restful.Request, resp *restful.Response) {
|
||||
|
||||
if err != nil {
|
||||
klog.Errorln("token review failed", uToken, err)
|
||||
failed := TokenReview{APIVersion: APIVersion,
|
||||
failed := TokenReview{APIVersion: tokenReview.APIVersion,
|
||||
Kind: KindTokenReview,
|
||||
Status: &Status{
|
||||
Authenticated: false,
|
||||
@@ -138,7 +178,7 @@ func TokenReviewHandler(req *restful.Request, resp *restful.Response) {
|
||||
|
||||
user.Groups = groups
|
||||
|
||||
success := TokenReview{APIVersion: APIVersion,
|
||||
success := TokenReview{APIVersion: tokenReview.APIVersion,
|
||||
Kind: KindTokenReview,
|
||||
Status: &Status{
|
||||
Authenticated: true,
|
||||
|
||||
@@ -70,13 +70,12 @@ const (
|
||||
authRateLimitRegex = `(\d+)/(\d+[s|m|h])`
|
||||
defaultMaxAuthFailed = 5
|
||||
defaultAuthTimeInterval = 30 * time.Minute
|
||||
defaultTokenIdleTimeout = 30 * time.Minute
|
||||
)
|
||||
|
||||
func Init(email, password, idleTimeout, authRateLimit string, multiLogin bool) error {
|
||||
func Init(email, password, authRateLimit string, idleTimeout time.Duration, multiLogin bool) error {
|
||||
adminEmail = email
|
||||
adminPassword = password
|
||||
tokenIdleTimeout = parseTokenIdleTimeout(idleTimeout)
|
||||
tokenIdleTimeout = idleTimeout
|
||||
maxAuthFailed, authTimeInterval = parseAuthRateLimit(authRateLimit)
|
||||
enableMultiLogin = multiLogin
|
||||
|
||||
@@ -97,15 +96,6 @@ func Init(email, password, idleTimeout, authRateLimit string, multiLogin bool) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseTokenIdleTimeout(tokenExpirationTime string) time.Duration {
|
||||
duration, err := time.ParseDuration(tokenExpirationTime)
|
||||
if err != nil {
|
||||
return defaultTokenIdleTimeout
|
||||
} else {
|
||||
return duration
|
||||
}
|
||||
}
|
||||
|
||||
func parseAuthRateLimit(authRateLimit string) (int, time.Duration) {
|
||||
regex := regexp.MustCompile(authRateLimitRegex)
|
||||
groups := regex.FindStringSubmatch(authRateLimit)
|
||||
@@ -255,8 +245,151 @@ func createGroupsBaseDN() error {
|
||||
return conn.Add(groupsCreateRequest)
|
||||
}
|
||||
|
||||
func RefreshToken(refreshToken string) (*models.AuthGrantResponse, error) {
|
||||
validRefreshToken, err := jwtutil.ValidateToken(refreshToken)
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, ok := validRefreshToken.Claims.(jwt.MapClaims)
|
||||
|
||||
if !ok {
|
||||
err = errors.New("invalid payload")
|
||||
klog.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
|
||||
// token with expiration time will not auto sliding
|
||||
claims["username"] = payload["username"]
|
||||
claims["email"] = payload["email"]
|
||||
claims["iat"] = time.Now().Unix()
|
||||
claims["exp"] = time.Now().Add(tokenIdleTimeout * 4).Unix()
|
||||
|
||||
token := jwtutil.MustSigned(claims)
|
||||
|
||||
claims = jwt.MapClaims{}
|
||||
claims["username"] = payload["username"]
|
||||
claims["email"] = payload["email"]
|
||||
claims["iat"] = time.Now().Unix()
|
||||
claims["type"] = "refresh_token"
|
||||
claims["exp"] = time.Now().Add(tokenIdleTimeout * 5).Unix()
|
||||
|
||||
refreshToken = jwtutil.MustSigned(claims)
|
||||
|
||||
return &models.AuthGrantResponse{TokenType: "jwt", Token: token, RefreshToken: refreshToken, ExpiresIn: (tokenIdleTimeout * 4).Seconds()}, nil
|
||||
}
|
||||
|
||||
func PasswordCredentialGrant(username, password, ip string) (*models.AuthGrantResponse, error) {
|
||||
redisClient, err := clientset.ClientSets().Redis()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
records, err := redisClient.Keys(fmt.Sprintf("kubesphere:authfailed:%s:*", username)).Result()
|
||||
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(records) >= maxAuthFailed {
|
||||
return nil, restful.NewError(http.StatusTooManyRequests, "auth rate limit exceeded")
|
||||
}
|
||||
|
||||
client, err := clientset.ClientSets().Ldap()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := client.NewConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
userSearchRequest := ldap.NewSearchRequest(
|
||||
client.UserSearchBase(),
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf("(&(objectClass=inetOrgPerson)(|(uid=%s)(mail=%s)))", username, username),
|
||||
[]string{"uid", "mail"},
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := conn.Search(userSearchRequest)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(result.Entries) != 1 {
|
||||
return nil, ldap.NewError(ldap.LDAPResultInvalidCredentials, errors.New("incorrect password"))
|
||||
}
|
||||
|
||||
uid := result.Entries[0].GetAttributeValue("uid")
|
||||
email := result.Entries[0].GetAttributeValue("mail")
|
||||
dn := result.Entries[0].DN
|
||||
|
||||
// bind as the user to verify their password
|
||||
err = conn.Bind(dn, password)
|
||||
|
||||
if err != nil {
|
||||
klog.Infoln("auth failed", username, err)
|
||||
|
||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
|
||||
loginFailedRecord := fmt.Sprintf("kubesphere:authfailed:%s:%d", uid, time.Now().UnixNano())
|
||||
redisClient.Set(loginFailedRecord, "", authTimeInterval)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
|
||||
// token with expiration time will not auto sliding
|
||||
claims["username"] = uid
|
||||
claims["email"] = email
|
||||
claims["iat"] = time.Now().Unix()
|
||||
claims["exp"] = time.Now().Add(tokenIdleTimeout * 4).Unix()
|
||||
|
||||
token := jwtutil.MustSigned(claims)
|
||||
|
||||
if !enableMultiLogin {
|
||||
// multi login not allowed, remove the previous token
|
||||
sessions, err := redisClient.Keys(fmt.Sprintf("kubesphere:users:%s:token:*", uid)).Result()
|
||||
|
||||
if err != nil {
|
||||
klog.Errorln(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(sessions) > 0 {
|
||||
klog.V(4).Infoln("revoke token", sessions)
|
||||
err = redisClient.Del(sessions...).Err()
|
||||
if err != nil {
|
||||
klog.Errorln(err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
claims = jwt.MapClaims{}
|
||||
claims["username"] = uid
|
||||
claims["email"] = email
|
||||
claims["iat"] = time.Now().Unix()
|
||||
claims["type"] = "refresh_token"
|
||||
claims["exp"] = time.Now().Add(tokenIdleTimeout * 5).Unix()
|
||||
|
||||
refreshToken := jwtutil.MustSigned(claims)
|
||||
|
||||
loginLog(uid, ip)
|
||||
|
||||
return &models.AuthGrantResponse{TokenType: "jwt", Token: token, RefreshToken: refreshToken, ExpiresIn: (tokenIdleTimeout * 4).Seconds()}, nil
|
||||
}
|
||||
|
||||
// User login
|
||||
func Login(username string, password string, ip string) (*models.Token, error) {
|
||||
func Login(username, password, ip string) (*models.AuthGrantResponse, error) {
|
||||
|
||||
redisClient, err := clientset.ClientSets().Redis()
|
||||
if err != nil {
|
||||
@@ -322,7 +455,7 @@ func Login(username string, password string, ip string) (*models.Token, error) {
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
|
||||
// do not set expiration time
|
||||
// token without expiration time will auto sliding
|
||||
claims["username"] = uid
|
||||
claims["email"] = email
|
||||
claims["iat"] = time.Now().Unix()
|
||||
@@ -356,7 +489,7 @@ func Login(username string, password string, ip string) (*models.Token, error) {
|
||||
|
||||
loginLog(uid, ip)
|
||||
|
||||
return &models.Token{Token: token}, nil
|
||||
return &models.AuthGrantResponse{Token: token}, nil
|
||||
}
|
||||
|
||||
func loginLog(uid, ip string) {
|
||||
|
||||
@@ -107,8 +107,11 @@ type PodInfo struct {
|
||||
Container string `json:"container" description:"container name"`
|
||||
}
|
||||
|
||||
type Token struct {
|
||||
Token string `json:"access_token" description:"access token"`
|
||||
type AuthGrantResponse struct {
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
Token string `json:"access_token" description:"access token"`
|
||||
ExpiresIn float64 `json:"expires_in,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
type ResourceQuota struct {
|
||||
|
||||
@@ -229,7 +229,7 @@ func (c *Config) stripEmptyOptions() {
|
||||
c.MySQLOptions = nil
|
||||
}
|
||||
|
||||
if c.RedisOptions != nil && c.RedisOptions.Host == "" {
|
||||
if c.RedisOptions != nil && c.RedisOptions.RedisURL == "" {
|
||||
c.RedisOptions = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -63,10 +63,7 @@ func newTestConfig() *Config {
|
||||
GroupSearchBase: "ou=Groups,dc=example,dc=org",
|
||||
},
|
||||
RedisOptions: &redis.RedisOptions{
|
||||
Host: "10.10.111.110",
|
||||
Port: 6379,
|
||||
Password: "",
|
||||
DB: 0,
|
||||
RedisURL: "redis://:qwerty@localhost:6379/1",
|
||||
},
|
||||
S3Options: &s2is3.S3Options{
|
||||
Endpoint: "http://minio.openpitrix-system.svc",
|
||||
|
||||
@@ -181,7 +181,7 @@ func (cs *ClientSet) MySQL() (*mysql.Database, error) {
|
||||
func (cs *ClientSet) Redis() (*goredis.Client, error) {
|
||||
var err error
|
||||
|
||||
if cs.csoptions.redisOptions == nil || cs.csoptions.redisOptions.Host == "" {
|
||||
if cs.csoptions.redisOptions == nil || cs.csoptions.redisOptions.RedisURL == "" {
|
||||
return nil, ClientSetNotEnabledError{}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,27 +1,20 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-redis/redis"
|
||||
"github.com/spf13/pflag"
|
||||
"kubesphere.io/kubesphere/pkg/utils/net"
|
||||
"kubesphere.io/kubesphere/pkg/utils/reflectutils"
|
||||
)
|
||||
|
||||
type RedisOptions struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int
|
||||
RedisURL string
|
||||
}
|
||||
|
||||
// NewRedisOptions returns options points to nowhere,
|
||||
// because redis is not required for some components
|
||||
func NewRedisOptions() *RedisOptions {
|
||||
return &RedisOptions{
|
||||
Host: "",
|
||||
Port: 6379,
|
||||
Password: "",
|
||||
DB: 0,
|
||||
RedisURL: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,14 +22,10 @@ func NewRedisOptions() *RedisOptions {
|
||||
func (r *RedisOptions) Validate() []error {
|
||||
errors := make([]error, 0)
|
||||
|
||||
if r.Host != "" {
|
||||
if !net.IsValidPort(r.Port) {
|
||||
errors = append(errors, fmt.Errorf("--redis-port is out of range"))
|
||||
}
|
||||
}
|
||||
_, err := redis.ParseURL(r.RedisURL)
|
||||
|
||||
if r.DB < 0 {
|
||||
errors = append(errors, fmt.Errorf("--redis-db is less than 0"))
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
return errors
|
||||
@@ -44,7 +33,7 @@ func (r *RedisOptions) Validate() []error {
|
||||
|
||||
// ApplyTo apply to another options if it's a enabled option(non empty host)
|
||||
func (r *RedisOptions) ApplyTo(options *RedisOptions) {
|
||||
if r.Host != "" {
|
||||
if r.RedisURL != "" {
|
||||
reflectutils.Override(options, r)
|
||||
}
|
||||
}
|
||||
@@ -52,16 +41,6 @@ func (r *RedisOptions) ApplyTo(options *RedisOptions) {
|
||||
// AddFlags add option flags to command line flags,
|
||||
// if redis-host left empty, the following options will be ignored.
|
||||
func (r *RedisOptions) AddFlags(fs *pflag.FlagSet) {
|
||||
fs.StringVar(&r.Host, "redis-host", r.Host, ""+
|
||||
"Redis service host address. If left blank, means redis is unnecessary, "+
|
||||
"redis will be disabled")
|
||||
|
||||
fs.IntVar(&r.Port, "redis-port", r.Port, ""+
|
||||
"Redis service port number.")
|
||||
|
||||
fs.StringVar(&r.Password, "redis-password", r.Password, ""+
|
||||
"Redis service password if necessary, default to empty")
|
||||
|
||||
fs.IntVar(&r.DB, "redis-db", r.DB, ""+
|
||||
"Redis service database index, default to 0.")
|
||||
fs.StringVar(&r.RedisURL, "redis-url", "", "Redis connection URL. If left blank, means redis is unnecessary, "+
|
||||
"redis will be disabled. e.g. redis://:password@host:port/db")
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-redis/redis"
|
||||
"k8s.io/klog"
|
||||
)
|
||||
@@ -39,11 +38,14 @@ func NewRedisClientOrDie(options *RedisOptions, stopCh <-chan struct{}) *RedisCl
|
||||
func NewRedisClient(option *RedisOptions, stopCh <-chan struct{}) (*RedisClient, error) {
|
||||
var r RedisClient
|
||||
|
||||
r.client = redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", option.Host, option.Port),
|
||||
Password: option.Password,
|
||||
DB: option.DB,
|
||||
})
|
||||
options, err := redis.ParseURL(option.RedisURL)
|
||||
|
||||
if err != nil {
|
||||
klog.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.client = redis.NewClient(options)
|
||||
|
||||
if err := r.client.Ping().Err(); err != nil {
|
||||
klog.Error("unable to reach redis host", err)
|
||||
@@ -51,12 +53,14 @@ func NewRedisClient(option *RedisOptions, stopCh <-chan struct{}) (*RedisClient,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-stopCh
|
||||
if err := r.client.Close(); err != nil {
|
||||
klog.Error(err)
|
||||
}
|
||||
}()
|
||||
if stopCh != nil {
|
||||
go func() {
|
||||
<-stopCh
|
||||
if err := r.client.Close(); err != nil {
|
||||
klog.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user