Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pkg/microservice/user/core/handler/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func ListUsers(c *gin.Context) {
}

if len(args.UIDs) > 0 {
ctx.Resp, ctx.RespErr = permission.SearchUsersByUIDs(args.UIDs, ctx.Logger)
ctx.Resp, ctx.RespErr = permission.SearchUsersByUIDs(args.UIDs, args.MFAEnabled, ctx.Logger)
} else if len(args.Account) > 0 {
if len(args.IdentityType) == 0 {
args.IdentityType = config.SystemIdentityType
Expand Down Expand Up @@ -321,6 +321,7 @@ func OpenAPIListUsersBrief(c *gin.Context) {
Roles: args.Roles,
Project: args.Project,
IdentityType: args.IdentityType,
MFAEnabled: args.MFAEnabled,
}

var resp *types.UsersResp
Expand Down Expand Up @@ -383,7 +384,7 @@ func ListUsersBrief(c *gin.Context) {

var resp *types.UsersResp
if len(args.UIDs) > 0 {
resp, err = permission.SearchUsersByUIDs(args.UIDs, ctx.Logger)
resp, err = permission.SearchUsersByUIDs(args.UIDs, args.MFAEnabled, ctx.Logger)
} else if len(args.Account) > 0 {
if len(args.IdentityType) == 0 {
args.IdentityType = config.SystemIdentityType
Expand Down
1 change: 1 addition & 0 deletions pkg/microservice/user/core/repository/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type User struct {
Email string `json:"email"`
Phone string `json:"phone"`
Account string `json:"account"`
MFAEnabled bool `gorm:"->;column:mfa_enabled;-:migration" json:"mfa_enabled"`
APIToken string `gorm:"api_token" json:"api_token"`
APITokenEnabled bool `gorm:"column:api_token_enabled;default:0" json:"api_token_enabled"`

Expand Down
44 changes: 44 additions & 0 deletions pkg/microservice/user/core/repository/orm/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,50 @@ func ListRoleByUIDAndNamespace(uid, namespace string, db *gorm.DB) ([]*models.Ne
return resp, nil
}

// ListRoleByUIDsAndNamespace lists roles for the given users in a namespace with a single query.
func ListRoleByUIDsAndNamespace(uids []string, namespace string, db *gorm.DB) (map[string][]*models.NewRole, error) {
if len(uids) == 0 {
return map[string][]*models.NewRole{}, nil
}

type uidRole struct {
UID string `gorm:"column:uid"`
ID uint `gorm:"column:id"`
Name string `gorm:"column:name"`
Description string `gorm:"column:description"`
Type int64 `gorm:"column:type"`
Namespace string `gorm:"column:namespace"`
GlobalReadOnly bool `gorm:"column:global_read_only"`
}

rows := make([]*uidRole, 0)
err := db.Table("role").
Select("role_binding.uid, role.id, role.name, role.description, role.type, role.namespace, role.global_read_only").
Joins("INNER JOIN role_binding ON role.id = role_binding.role_id").
Where("role.namespace = ?", namespace).
Where("role_binding.uid IN ?", uids).
Order("role_binding.uid ASC").
Order("role.id ASC").
Scan(&rows).Error
if err != nil {
return nil, err
}

resp := make(map[string][]*models.NewRole, len(rows))
for _, row := range rows {
resp[row.UID] = append(resp[row.UID], &models.NewRole{
ID: row.ID,
Name: row.Name,
Description: row.Description,
Type: row.Type,
Namespace: row.Namespace,
GlobalReadOnly: row.GlobalReadOnly,
})
}

return resp, nil
}

// ListRoleByUID list a set of roles that is used by specific user in ALL namespace
func ListRoleByUID(uid string, db *gorm.DB) ([]*models.NewRole, error) {
resp := make([]*models.NewRole, 0)
Expand Down
135 changes: 84 additions & 51 deletions pkg/microservice/user/core/repository/orm/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ import (
"github.com/koderover/zadig/v2/pkg/types"
)

const (
userMFAJoinClause = "LEFT JOIN user_mfa ON user_mfa.uid = user.uid"
userMFAEnabledSelectExpr = "IFNULL(user_mfa.enabled, 0) AS mfa_enabled"
)

// CreateUser create a user
func CreateUser(user *models.User, db *gorm.DB) error {
if err := db.Create(&user).Error; err != nil {
Expand All @@ -46,6 +51,22 @@ func GetUser(account string, identityType string, db *gorm.DB) (*models.User, er
return &user, nil
}

func GetUserByAccountAndMFAEnabled(account string, identityType string, mfaEnabled *bool, db *gorm.DB) (*models.User, error) {
var user models.User
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("account = ? and identity_type = ?", account, identityType)
err := applyMFAEnabledJoinFilter(query, mfaEnabled).
First(&user).Error
if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
}
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return &user, nil
}

// GetUserByUid Get a user based on uid
func GetUserByUid(uid string, db *gorm.DB) (*models.User, error) {
var user models.User
Expand Down Expand Up @@ -80,13 +101,17 @@ func ListAllUsers(db *gorm.DB) ([]*models.User, error) {
}

// ListUsers gets a list of users based on paging constraints
func ListUsers(page int, perPage int, name string, db *gorm.DB) ([]models.User, error) {
func ListUsers(page int, perPage int, name string, mfaEnabled *bool, db *gorm.DB) ([]models.User, error) {
var (
users []models.User
err error
)

err = db.Where("name LIKE ?", "%"+name+"%").Order("account ASC").Offset((page - 1) * perPage).Limit(perPage).Find(&users).Error
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("name LIKE ?", "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Order("account ASC").Offset((page - 1) * perPage).Limit(perPage).Find(&users).Error

if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
Expand All @@ -95,15 +120,18 @@ func ListUsers(page int, perPage int, name string, db *gorm.DB) ([]models.User,
return users, nil
}

func ListUsersByLoginTime(page int, perPage int, name string, order setting.ListUserOrder, db *gorm.DB) ([]models.UserWithLoginTime, error) {
func ListUsersByLoginTime(page int, perPage int, name string, order setting.ListUserOrder, mfaEnabled *bool, db *gorm.DB) ([]models.UserWithLoginTime, error) {
var (
users []models.UserWithLoginTime
err error
)

err = db.Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, IFNULL(user_login.last_login_time, 0) as last_login_time").
query := db.Model(&models.User{}).
Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, "+userMFAEnabledSelectExpr+", IFNULL(user_login.last_login_time, 0) as last_login_time").
Where("user.name LIKE ?", "%"+name+"%").
Joins("LEFT JOIN user_login on user_login.uid = user.uid").
Joins("LEFT JOIN user_login on user_login.uid = user.uid")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.
Order("IFNULL(user_login.last_login_time, 0) " + string(order)).
Offset((page - 1) * perPage).
Limit(perPage).
Expand All @@ -117,39 +145,24 @@ func ListUsersByLoginTime(page int, perPage int, name string, order setting.List
return users, nil
}

// listUIDsByRoles returns distinct user uids that have any of the given role names within the namespace.
func listUIDsByRoles(roles []string, namespace string, db *gorm.DB) ([]string, error) {
var uids []string
err := db.Table("role_binding").
Distinct("role_binding.uid").
func uidSubQueryByRoles(roles []string, namespace string, db *gorm.DB) *gorm.DB {
return db.Table("role_binding").
Select("DISTINCT role_binding.uid").
Joins("INNER JOIN role ON role.id = role_binding.role_id").
Where("role.name IN ? AND role.namespace = ?", roles, namespace).
Pluck("role_binding.uid", &uids).Error

if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
}
return uids, nil
Where("role.name IN ? AND role.namespace = ?", roles, namespace)
}

// ListUsersByNameAndRoleWithLoginTime gets a list of users filtered by name and roles,
// ordered by last_login_time with pagination. It is implemented in two simple steps:
// 1. Find the uids of users that have any of the given roles (role_binding + role) within the namespace.
// 2. Query user + user_login for those uids, filter by name, order by last_login_time and paginate.
func ListUsersByNameAndRoleWithLoginTime(page int, perPage int, name string, roles []string, namespace string, order setting.ListUserOrder, db *gorm.DB) ([]models.UserWithLoginTime, error) {
uids, err := listUIDsByRoles(roles, namespace, db)
if err != nil {
return nil, err
}
if len(uids) == 0 {
return []models.UserWithLoginTime{}, nil
}

// ordered by last_login_time with pagination.
func ListUsersByNameAndRoleWithLoginTime(page int, perPage int, name string, roles []string, namespace string, order setting.ListUserOrder, mfaEnabled *bool, db *gorm.DB) ([]models.UserWithLoginTime, error) {
var users []models.UserWithLoginTime
err = db.Table("user").
Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, IFNULL(user_login.last_login_time, 0) AS last_login_time").
roleUIDs := uidSubQueryByRoles(roles, namespace, db)
query := db.Model(&models.User{}).
Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, "+userMFAEnabledSelectExpr+", IFNULL(user_login.last_login_time, 0) AS last_login_time").
Joins("LEFT JOIN user_login ON user_login.uid = user.uid").
Where("user.uid IN ? AND user.name LIKE ?", uids, "%"+name+"%").
Where("user.uid IN (?) AND user.name LIKE ?", roleUIDs, "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err := query.
Order("last_login_time " + string(order)).
Offset((page - 1) * perPage).
Limit(perPage).
Expand All @@ -162,16 +175,18 @@ func ListUsersByNameAndRoleWithLoginTime(page int, perPage int, name string, rol
}

// ListUsersByNameAndRole gets a list of users based on paging constraints, the name of the user, the roles, and namespace
func ListUsersByNameAndRole(page int, perPage int, name string, roles []string, namespace string, db *gorm.DB) ([]models.User, error) {
func ListUsersByNameAndRole(page int, perPage int, name string, roles []string, namespace string, mfaEnabled *bool, db *gorm.DB) ([]models.User, error) {
var (
users []models.User
err error
)

err = db.Where("user.name LIKE ? AND role.name IN ? AND role.namespace = ?", "%"+name+"%", roles, namespace).
Joins("INNER JOIN role_binding on role_binding.uid = user.uid").
Joins("INNER JOIN role on role_binding.role_id = role.id").Order("account ASC").Offset((page - 1) * perPage).
Group("user.uid").
roleUIDs := uidSubQueryByRoles(roles, namespace, db)
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("user.uid IN (?) AND user.name LIKE ?", roleUIDs, "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Order("account ASC").Offset((page - 1) * perPage).
Limit(perPage).
Find(&users).
Error
Expand All @@ -183,6 +198,10 @@ func ListUsersByNameAndRole(page int, perPage int, name string, roles []string,
return users, nil
}

func joinUserMFA(db *gorm.DB) *gorm.DB {
return db.Joins(userMFAJoinClause)
}

func ListUsersByGroup(groupID string, db *gorm.DB) ([]*models.User, error) {
resp := make([]*models.User, 0)

Expand All @@ -199,13 +218,16 @@ func ListUsersByGroup(groupID string, db *gorm.DB) ([]*models.User, error) {
}

// ListUsersByUIDs gets a list of users based on paging constraints
func ListUsersByUIDs(uids []string, db *gorm.DB) ([]models.User, error) {
func ListUsersByUIDs(uids []string, mfaEnabled *bool, db *gorm.DB) ([]models.User, error) {
var (
users []models.User
err error
)

err = db.Find(&users, "uid in ?", uids).Error
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("user.uid in ?", uids)
err = applyMFAEnabledJoinFilter(query, mfaEnabled).Find(&users).Error

if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
Expand Down Expand Up @@ -251,14 +273,15 @@ func DeleteUserByUid(uid string, db *gorm.DB) error {
}

// GetUsersCount gets user count
func GetUsersCount(name string) (int64, error) {
func GetUsersCount(name string, mfaEnabled *bool) (int64, error) {
var (
users []models.User
err error
count int64
)

err = repository.DB.Where("name LIKE ?", "%"+name+"%").Find(&users).Count(&count).Error
query := repository.DB.Model(&models.User{}).Where("name LIKE ?", "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Count(&count).Error

if err != nil {
return 0, err
Expand All @@ -268,20 +291,17 @@ func GetUsersCount(name string) (int64, error) {
}

// GetUsersCountByRoles gets user count filtered by roles and namespace
func GetUsersCountByRoles(name string, roles []string, namespace string) (int64, error) {
func GetUsersCountByRoles(name string, roles []string, namespace string, mfaEnabled *bool) (int64, error) {
var (
users []models.User
err error
count int64
)

err = repository.DB.Where("user.name LIKE ? AND role.name IN ? AND role.namespace = ?", "%"+name+"%", roles, namespace).
Joins("INNER JOIN role_binding on role_binding.uid = user.uid").
Joins("INNER JOIN role on role_binding.role_id = role.id").
Group("user.uid").
Find(&users).
Count(&count).
Error
roleUIDs := uidSubQueryByRoles(roles, namespace, repository.DB)
query := repository.DB.Model(&models.User{}).
Where("user.uid IN (?) AND user.name LIKE ?", roleUIDs, "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Count(&count).Error

if err != nil {
return 0, err
Expand All @@ -290,6 +310,19 @@ func GetUsersCountByRoles(name string, roles []string, namespace string) (int64,
return count, nil
}

func applyMFAEnabledJoinFilter(db *gorm.DB, mfaEnabled *bool) *gorm.DB {
db = joinUserMFA(db)
if mfaEnabled == nil {
return db
}

if *mfaEnabled {
return db.Where("user_mfa.enabled = ?", true)
}

return db.Where("user_mfa.enabled IS NULL OR user_mfa.enabled = ?", false)
}

// UpdateUser update user info
func UpdateUser(uid string, user *models.User, db *gorm.DB) error {
if err := db.Model(&models.User{}).Where("uid = ?", uid).Updates(user).Error; err != nil {
Expand Down
11 changes: 0 additions & 11 deletions pkg/microservice/user/core/repository/orm/user_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,6 @@ func GetUserMFA(uid string, db *gorm.DB) (*models.UserMFA, error) {
return res, nil
}

func ListUserMFAsByUIDs(uids []string, db *gorm.DB) ([]*models.UserMFA, error) {
if len(uids) == 0 {
return []*models.UserMFA{}, nil
}
res := make([]*models.UserMFA, 0)
if err := db.Where("uid IN ?", uids).Find(&res).Error; err != nil {
return nil, err
}
return res, nil
}

// EnableUserMFA enables MFA for a user without allowing overwrite of an already-enabled MFA config.
func EnableUserMFA(uid, secretCipher, recoveryCodesJSON string, db *gorm.DB) error {
now := time.Now().Unix()
Expand Down
26 changes: 26 additions & 0 deletions pkg/microservice/user/core/service/permission/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,32 @@ func ListRolesByNamespaceAndUserID(projectName, uid string, log *zap.SugaredLogg
return resp, nil
}

func ListRolesByNamespaceAndUserIDs(projectName string, uids []string, log *zap.SugaredLogger) (map[string][]*types.Role, error) {
rolesByUID, err := orm.ListRoleByUIDsAndNamespace(uids, projectName, repository.DB)
if err != nil {
log.Errorf("failed to list roles in project: %s, error: %s", projectName, err)
return nil, fmt.Errorf("failed to list roles in project: %s, error: %s", projectName, err)
}

resp := make(map[string][]*types.Role, len(rolesByUID))
for uid, roles := range rolesByUID {
roleList := make([]*types.Role, 0, len(roles))
for _, role := range roles {
roleList = append(roleList, &types.Role{
ID: role.ID,
Name: role.Name,
Namespace: role.Namespace,
Description: role.Description,
Type: convertDBRoleType(role.Type),
GlobalReadOnly: role.GlobalReadOnly,
})
}
resp[uid] = roleList
}

return resp, nil
}

func GetRole(ns, name string, log *zap.SugaredLogger) (*types.DetailedRole, error) {
role, err := orm.GetRole(name, ns, repository.DB)
if err != nil {
Expand Down
Loading
Loading