Merge pull request #362 from runzexia/devops2

add devops project & members & pipeline api
This commit is contained in:
zryfish
2019-04-28 13:53:27 +08:00
committed by GitHub
71 changed files with 10187 additions and 6932 deletions

44
Gopkg.lock generated
View File

@@ -17,6 +17,14 @@
revision = "1a8911d1ed007260465c3bfbbc785ac6915a0bb8"
version = "v0.4.12"
[[projects]]
digest = "1:d043e5ae59276188d876090c261c311e146792c0908e08e67e766b1020a72d00"
name = "github.com/PuerkitoBio/goquery"
packages = ["."]
pruneopts = "NUT"
revision = "2d2796f41742ece03e8086188fa4db16a3a0b458"
version = "v1.5.0"
[[projects]]
digest = "1:0a111edd8693fd977f42a0c4f199a0efb13c20aec9da99ad8830c7bb6a87e8d6"
name = "github.com/PuerkitoBio/purell"
@@ -52,6 +60,14 @@
pruneopts = "NUT"
revision = "8b13a72661dae6e9e5dea04f344f0dc95ea29547"
[[projects]]
digest = "1:fc86904a62ac4bfff8cfbe94f42231ce3d8cea8fe2506d5293eaef468f8eaecf"
name = "github.com/andybalholm/cascadia"
packages = ["."]
pruneopts = "NUT"
revision = "901648c87902174f774fac311d7f176f8647bdaa"
version = "v1.0.0"
[[projects]]
digest = "1:680b63a131506e668818d630d3ca36123ff290afa0afc9f4be21940adca3f27d"
name = "github.com/appscode/jsonpatch"
@@ -68,6 +84,14 @@
revision = "ccb8e960c48f04d6935e72476ae4a51028f9e22f"
version = "v9"
[[projects]]
digest = "1:15e3271f463f2f40d98bf426aabb86941fc66b10272ccfdfebe548683e37acb1"
name = "github.com/beevik/etree"
packages = ["."]
pruneopts = "NUT"
revision = "8aee6516be3b1163bb6450c35c50e4969e3a3aa8"
version = "v1.1.0"
[[projects]]
branch = "master"
digest = "1:707ebe952a8b3d00b343c01536c79c73771d100f63ec6babeaed5c79e2b8a8dd"
@@ -562,22 +586,6 @@
pruneopts = "NUT"
revision = "3eca13d6893afd7ecabe15f4445f5d2872a1b012"
[[projects]]
digest = "1:2ddfc1382a659966038282873c9e33e7694fa503130d445e97c4fdc3b8c5db66"
name = "github.com/jinzhu/gorm"
packages = ["."]
pruneopts = "NUT"
revision = "472c70caa40267cb89fd8facb07fe6454b578626"
version = "v1.9.2"
[[projects]]
branch = "master"
digest = "1:802f75230c29108e787d40679f9bf5da1a5673eaf5c10eb89afd993e18972909"
name = "github.com/jinzhu/inflection"
packages = ["."]
pruneopts = "NUT"
revision = "04140366298a54a039076d798123ffa108fff46c"
[[projects]]
digest = "1:da62aa6632d04e080b8a8b85a59ed9ed1550842a0099a55f3ae3a20d02a3745a"
name = "github.com/joho/godotenv"
@@ -2010,7 +2018,9 @@
analyzer-name = "dep"
analyzer-version = 1
input-imports = [
"github.com/PuerkitoBio/goquery",
"github.com/asaskevich/govalidator",
"github.com/beevik/etree",
"github.com/dgrijalva/jwt-go",
"github.com/docker/docker/api/types",
"github.com/docker/docker/client",
@@ -2026,7 +2036,6 @@
"github.com/golang/example/stringutil",
"github.com/golang/glog",
"github.com/google/uuid",
"github.com/jinzhu/gorm",
"github.com/json-iterator/go",
"github.com/kiali/kiali/config",
"github.com/kiali/kiali/handlers",
@@ -2059,7 +2068,6 @@
"gopkg.in/src-d/go-git.v4/storage/memory",
"gopkg.in/yaml.v2",
"k8s.io/api/apps/v1",
"k8s.io/api/apps/v1beta2",
"k8s.io/api/batch/v1",
"k8s.io/api/batch/v1beta1",
"k8s.io/api/core/v1",

View File

@@ -15,6 +15,7 @@
limitations under the License.
*/
package install
import (
@@ -28,6 +29,6 @@ func init() {
Install(runtime.Container)
}
func Install(c *restful.Container) {
urlruntime.Must(devopsv1alpha2.AddToContainer(c))
func Install(container *restful.Container) {
urlruntime.Must(devopsv1alpha2.AddToContainer(container))
}

View File

@@ -15,6 +15,7 @@
limitations under the License.
*/
package v1alpha2
import (
@@ -24,12 +25,14 @@ import (
devopsapi "kubesphere.io/kubesphere/pkg/apiserver/devops"
"kubesphere.io/kubesphere/pkg/apiserver/runtime"
"kubesphere.io/kubesphere/pkg/models/devops"
"kubesphere.io/kubesphere/pkg/params"
"net/http"
)
const (
GroupName = "devops.kubesphere.io"
RespMessage = "ok"
RespOK = "ok"
)
var GroupVersion = schema.GroupVersion{Group: GroupName, Version: "v1alpha2"}
@@ -40,9 +43,187 @@ var (
)
func addWebService(c *restful.Container) error {
webservice := runtime.NewWebService(GroupVersion)
tags := []string{"devops"}
tags := []string{"DevOps"}
webservice.Route(webservice.GET("/devops/{devops}").
To(devopsapi.GetDevOpsProjectHandler).
Doc("get devops project").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Returns(http.StatusOK, RespOK, devops.DevOpsProject{}).
Writes(devops.DevOpsProject{}))
webservice.Route(webservice.PATCH("/devops/{devops}").
To(devopsapi.UpdateProjectHandler).
Doc("get devops project").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Returns(http.StatusOK, RespOK, devops.DevOpsProject{}).
Writes(devops.DevOpsProject{}))
webservice.Route(webservice.GET("/devops/{devops}/defaultroles").
To(devopsapi.GetDevOpsProjectDefaultRoles).
Doc("get devops project defaultroles").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Returns(http.StatusOK, RespOK, []devops.Role{}).
Writes([]devops.Role{}))
webservice.Route(webservice.GET("/devops/{devops}/members").
To(devopsapi.GetDevOpsProjectMembersHandler).
Doc("get devops project members").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.QueryParameter(params.PagingParam, "page").
Required(false).
DataFormat("limit=%d,page=%d").
DefaultValue("limit=10,page=1")).
Param(webservice.QueryParameter(params.ConditionsParam, "query conditions").
Required(false).
DataFormat("key=%s,key~%s")).
Returns(http.StatusOK, RespOK, []devops.DevOpsProjectMembership{}).
Writes([]devops.DevOpsProjectMembership{}))
webservice.Route(webservice.GET("/devops/{devops}/members/{members}").
To(devopsapi.GetDevOpsProjectMemberHandler).
Doc("get devops project member").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("members", "member's username")).
Returns(http.StatusOK, RespOK, devops.DevOpsProjectMembership{}).
Writes(devops.DevOpsProjectMembership{}))
webservice.Route(webservice.POST("/devops/{devops}/members").
To(devopsapi.AddDevOpsProjectMemberHandler).
Doc("add devops project members").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Returns(http.StatusOK, RespOK, devops.DevOpsProjectMembership{}).
Writes(devops.DevOpsProjectMembership{}))
webservice.Route(webservice.PATCH("/devops/{devops}/members/{members}").
To(devopsapi.UpdateDevOpsProjectMemberHandler).
Doc("update devops project members").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("members", "member's username")).
Reads(devops.DevOpsProjectMembership{}).
Writes(devops.DevOpsProjectMembership{}))
webservice.Route(webservice.DELETE("/devops/{devops}/members/{members}").
To(devopsapi.DeleteDevOpsProjectMemberHandler).
Doc("delete devops project members").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("members", "member's username")).
Writes(devops.DevOpsProjectMembership{}))
webservice.Route(webservice.POST("/devops/{devops}/pipelines").
To(devopsapi.CreateDevOpsProjectPipelineHandler).
Doc("add devops project pipeline").
Param(webservice.PathParameter("devops", "devops project's Id")).
Metadata(restfulspec.KeyOpenAPITags, tags).
Returns(http.StatusOK, RespOK, devops.ProjectPipeline{}).
Writes(devops.ProjectPipeline{}).
Reads(devops.ProjectPipeline{}))
webservice.Route(webservice.PUT("/devops/{devops}/pipelines/{pipelines}").
To(devopsapi.UpdateDevOpsProjectPipelineHandler).
Doc("update devops project pipeline").
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("pipelines", "pipeline name")).
Metadata(restfulspec.KeyOpenAPITags, tags).
Returns(http.StatusOK, RespOK, devops.ProjectPipeline{}).
Writes(devops.ProjectPipeline{}).
Reads(devops.ProjectPipeline{}))
webservice.Route(webservice.GET("/devops/{devops}/pipelines/{pipelines}/config").
To(devopsapi.GetDevOpsProjectPipelineHandler).
Doc("get devops project pipeline config").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("pipelines", "pipeline name")).
Returns(http.StatusOK, RespOK, devops.ProjectPipeline{}).
Writes(devops.ProjectPipeline{}))
webservice.Route(webservice.GET("/devops/{devops}/pipelines/{pipelines}/sonarStatus").
To(devopsapi.GetPipelineSonarStatusHandler).
Doc("get devops project pipeline sonarStatus").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("pipelines", "pipeline name")).
Returns(http.StatusOK, RespOK, []devops.SonarStatus{}).
Writes([]devops.SonarStatus{}))
webservice.Route(webservice.GET("/devops/{devops}/pipelines/{pipelines}/branches/{branches}/sonarStatus").
To(devopsapi.GetMultiBranchesPipelineSonarStatusHandler).
Doc("get devops project pipeline sonarStatus").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("pipelines", "pipeline name")).
Param(webservice.PathParameter("branches", "branch name")).
Returns(http.StatusOK, RespOK, []devops.SonarStatus{}).
Writes([]devops.SonarStatus{}))
webservice.Route(webservice.DELETE("/devops/{devops}/pipelines/{pipelines}").
To(devopsapi.DeleteDevOpsProjectPipelineHandler).
Doc("delete devops project pipeline").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("pipelines", "pipeline name")))
webservice.Route(webservice.PUT("/devops/{devops}/pipelines").
To(devopsapi.CreateDevOpsProjectPipelineHandler).
Doc("update devops project pipeline").
Param(webservice.PathParameter("devops", "devops project's Id")).
Metadata(restfulspec.KeyOpenAPITags, tags).
Reads(devops.ProjectPipeline{}))
webservice.Route(webservice.POST("/devops/{devops}/credentials").
To(devopsapi.CreateDevOpsProjectCredentialHandler).
Doc("add project credential pipeline").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Reads(devops.JenkinsCredential{}))
webservice.Route(webservice.PUT("/devops/{devops}/credentials/{credentials}").
To(devopsapi.UpdateDevOpsProjectCredentialHandler).
Doc("update project credential pipeline").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("credentials", "credential's Id")).
Reads(devops.JenkinsCredential{}))
webservice.Route(webservice.DELETE("/devops/{devops}/credentials/{credentials}").
To(devopsapi.DeleteDevOpsProjectCredentialHandler).
Doc("delete project credential pipeline").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("credentials", "credential's Id")))
webservice.Route(webservice.GET("/devops/{devops}/credentials/{credentials}").
To(devopsapi.GetDevOpsProjectCredentialHandler).
Doc("get project credential pipeline").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("credentials", "credential's Id")).
Param(webservice.QueryParameter("domain", "credential's domain")).
Param(webservice.QueryParameter("content", "get additional content")).
Returns(http.StatusOK, RespOK, devops.JenkinsCredential{}).
Reads(devops.JenkinsCredential{}))
webservice.Route(webservice.GET("/devops/{devops}/credentials").
To(devopsapi.GetDevOpsProjectCredentialsHandler).
Doc("get project credential pipeline").
Metadata(restfulspec.KeyOpenAPITags, tags).
Param(webservice.PathParameter("devops", "devops project's Id")).
Param(webservice.PathParameter("credentials", "credential's Id")).
Param(webservice.QueryParameter("domain", "credential's domain")).
Returns(http.StatusOK, RespOK, []devops.JenkinsCredential{}).
Reads([]devops.JenkinsCredential{}))
// match Jenkisn api "/blue/rest/organizations/jenkins/pipelines/{projectName}/{pipelineName}"
webservice.Route(webservice.GET("/devops/{projectName}/pipelines/{pipelineName}").
@@ -51,7 +232,7 @@ func addWebService(c *restful.Container) error {
Doc("Get DevOps Pipelines.").
Param(webservice.PathParameter("pipelineName", "pipeline name")).
Param(webservice.PathParameter("projectName", "devops project name")).
Returns(http.StatusOK, RespMessage, devops.Pipeline{}).
Returns(http.StatusOK, RespOK, devops.Pipeline{}).
Writes(devops.Pipeline{}))
// match Jenkisn api: "jenkins_api/blue/rest/search"
@@ -71,7 +252,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.QueryParameter("limit", "limit count").
Required(false).
DataFormat("limit=%d")).
Returns(http.StatusOK, RespMessage, []devops.Pipeline{}).
Returns(http.StatusOK, RespOK, []devops.Pipeline{}).
Writes([]devops.Pipeline{}))
// match Jenkisn api "/blue/rest/organizations/jenkins/pipelines/{projectName}/{pipelineName}/runs/"
@@ -90,7 +271,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.QueryParameter("branch", "branch ").
Required(false).
DataFormat("branch=%s")).
Returns(http.StatusOK, RespMessage, []devops.PipelineRun{}).
Returns(http.StatusOK, RespOK, []devops.PipelineRun{}).
Writes([]devops.PipelineRun{}))
// match Jenkins api "/blue/rest/organizations/jenkins/pipelines/{projectName}/{pipelineName}/branches/{branchName}/runs/{runId}/"
@@ -105,7 +286,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.QueryParameter("start", "start").
Required(false).
DataFormat("start=%d")).
Returns(http.StatusOK, RespMessage, devops.PipelineRun{}).
Returns(http.StatusOK, RespOK, devops.PipelineRun{}).
Writes(devops.PipelineRun{}))
// match Jenkins api "/blue/rest/organizations/jenkins/pipelines/{projectName}/{pipelineName}/branches/{branchName}/runs/{runId}/nodes"
@@ -121,7 +302,7 @@ func addWebService(c *restful.Container) error {
Required(false).
DataFormat("limit=%d").
DefaultValue("limit=10000")).
Returns(http.StatusOK, RespMessage, []devops.Nodes{}).
Returns(http.StatusOK, RespOK, []devops.Nodes{}).
Writes([]devops.Nodes{}))
// match "/blue/rest/organizations/jenkins/pipelines/{projectName}/{pipelineName}/branches/{branchName}/runs/{runId}/nodes/{nodeId}/steps/{stepId}/log/?start=0"
@@ -147,7 +328,7 @@ func addWebService(c *restful.Container) error {
Metadata(restfulspec.KeyOpenAPITags, tags).
Doc("Validate Github personal access token.").
Param(webservice.PathParameter("scmId", "SCM id")).
Returns(http.StatusOK, RespMessage, devops.Validates{}).
Returns(http.StatusOK, RespOK, devops.Validates{}).
Writes(devops.Validates{}))
// match "/blue/rest/organizations/jenkins/scm/{scmId}/organizations/?credentialId=github"
@@ -159,7 +340,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.QueryParameter("credentialId", "credential id for SCM").
Required(true).
DataFormat("credentialId=%s")).
Returns(http.StatusOK, RespMessage, []devops.SCMOrg{}).
Returns(http.StatusOK, RespOK, []devops.SCMOrg{}).
Writes([]devops.SCMOrg{}))
// match "/blue/rest/organizations/jenkins/scm/{scmId}/organizations/{organizationId}/repositories/?credentialId=&pageNumber&pageSize="
@@ -178,7 +359,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.QueryParameter("pageSize", "page size").
Required(true).
DataFormat("pageSize=%d")).
Returns(http.StatusOK, RespMessage, []devops.OrgRepo{}).
Returns(http.StatusOK, RespOK, []devops.OrgRepo{}).
Writes([]devops.OrgRepo{}))
// match /blue/rest/organizations/jenkins/pipelines/{projectName}/pipelines/{pipelineName}/branches/{branchName}/runs/{runId}/stop/
@@ -198,7 +379,7 @@ func addWebService(c *restful.Container) error {
Required(false).
DataFormat("timeOutInSecs=%d").
DefaultValue("timeOutInSecs=10")).
Returns(http.StatusOK, RespMessage, devops.StopPipe{}).
Returns(http.StatusOK, RespOK, devops.StopPipe{}).
Writes(devops.StopPipe{}))
// match /blue/rest/organizations/jenkins/pipelines/{projectName}/pipelines/{pipelineName}/branches/{branchName}/runs/{runId}/replay/
@@ -210,7 +391,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.PathParameter("pipelineName", "pipeline name")).
Param(webservice.PathParameter("branchName", "pipeline branch name")).
Param(webservice.PathParameter("runId", "pipeline runs id")).
Returns(http.StatusOK, RespMessage, devops.ReplayPipe{}).
Returns(http.StatusOK, RespOK, devops.ReplayPipe{}).
Writes(devops.ReplayPipe{}))
// match /blue/rest/organizations/jenkins/pipelines/{projectName}/{pipelineName}/branches/{branchName}/runs/{runId}/log/?start=0
@@ -262,7 +443,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.QueryParameter("limit", "limit count").
Required(true).
DataFormat("limit=%d")).
Returns(http.StatusOK, RespMessage, []devops.PipeBranch{}).
Returns(http.StatusOK, RespOK, []devops.PipeBranch{}).
Writes([]devops.PipeBranch{}))
// /blue/rest/organizations/jenkins/pipelines/{projectName}/pipelines/{pipelineName}/branches/{branchName}/runs/{runId}/nodes/{nodeId}/steps/{stepId}
@@ -309,7 +490,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.PathParameter("projectName", "devops project name")).
Param(webservice.PathParameter("pipelineName", "pipeline name")).
Param(webservice.PathParameter("branchName", "pipeline branch name")).
Returns(http.StatusOK, RespMessage, devops.QueuedBlueRun{}).
Returns(http.StatusOK, RespOK, devops.QueuedBlueRun{}).
Writes(devops.QueuedBlueRun{}))
// match /pipeline_status/blue/rest/organizations/jenkins/pipelines/{projectName}/{pipelineName}/branches/{branchName}/runs/{runId}/nodes/?limit=
@@ -325,7 +506,7 @@ func addWebService(c *restful.Container) error {
Param(webservice.QueryParameter("limit", "limit count").
Required(true).
DataFormat("limit=%d")).
Returns(http.StatusOK, RespMessage, []devops.QueuedBlueRun{}).
Returns(http.StatusOK, RespOK, []devops.QueuedBlueRun{}).
Writes([]devops.QueuedBlueRun{}))
// match /crumbIssuer/api/json/
@@ -333,7 +514,7 @@ func addWebService(c *restful.Container) error {
To(devopsapi.GetCrumb).
Metadata(restfulspec.KeyOpenAPITags, tags).
Doc("Get crumb").
Returns(http.StatusOK, RespMessage, devops.Crumb{}).
Returns(http.StatusOK, RespOK, devops.Crumb{}).
Writes(devops.Crumb{}))
c.Add(webservice)

View File

@@ -0,0 +1,196 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"fmt"
"github.com/asaskevich/govalidator"
"github.com/emicklei/go-restful"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/constants"
"kubesphere.io/kubesphere/pkg/errors"
"kubesphere.io/kubesphere/pkg/models/devops"
"kubesphere.io/kubesphere/pkg/params"
"kubesphere.io/kubesphere/pkg/utils/reflectutils"
"net/http"
)
func GetDevOpsProjectMembersHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
err := devops.CheckProjectUserInRole(username, projectId, devops.AllRoleSlice)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
orderBy := request.QueryParameter(params.OrderByParam)
reverse := params.ParseReverse(request)
limit, offset := params.ParsePaging(request.QueryParameter(params.PagingParam))
conditions, err := params.ParseConditions(request.QueryParameter(params.ConditionsParam))
project, err := devops.GetProjectMembers(projectId, conditions, orderBy, reverse, limit, offset)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(project)
return
}
func GetDevOpsProjectMemberHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
member := request.PathParameter("members")
err := devops.CheckProjectUserInRole(username, projectId, devops.AllRoleSlice)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
project, err := devops.GetProjectMember(projectId, member)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(project)
return
}
func AddDevOpsProjectMemberHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
member := &devops.DevOpsProjectMembership{}
err := request.ReadEntity(&member)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
if govalidator.IsNull(member.Username) {
err := fmt.Errorf("error need username")
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
if !reflectutils.In(member.Role, devops.AllRoleSlice) {
err := fmt.Errorf("err role [%s] not in [%s]", member.Role,
devops.AllRoleSlice)
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
project, err := devops.AddProjectMember(projectId, username, member)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(project)
return
}
func UpdateDevOpsProjectMemberHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
member := &devops.DevOpsProjectMembership{}
err := request.ReadEntity(&member)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
member.Username = request.PathParameter("members")
if govalidator.IsNull(member.Username) {
err := fmt.Errorf("error need username")
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
if username == member.Username {
err := fmt.Errorf("you can not change your role")
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
if !reflectutils.In(member.Role, devops.AllRoleSlice) {
err := fmt.Errorf("err role [%s] not in [%s]", member.Role,
devops.AllRoleSlice)
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
project, err := devops.UpdateProjectMember(projectId, username, member)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(project)
return
}
func DeleteDevOpsProjectMemberHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
member := request.PathParameter("members")
err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
username, err = devops.DeleteProjectMember(projectId, member)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(struct {
Username string `json:"username"`
}{Username: username})
return
}

View File

@@ -0,0 +1,82 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"github.com/emicklei/go-restful"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/constants"
"kubesphere.io/kubesphere/pkg/errors"
"kubesphere.io/kubesphere/pkg/models/devops"
"net/http"
)
func GetDevOpsProjectHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
err := devops.CheckProjectUserInRole(username, projectId, devops.AllRoleSlice)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
project, err := devops.GetProject(projectId)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(project)
return
}
func UpdateProjectHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
var project *devops.DevOpsProject
err := request.ReadEntity(&project)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
project.ProjectId = projectId
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
project, err = devops.UpdateProject(project)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(project)
return
}
func GetDevOpsProjectDefaultRoles(request *restful.Request, resp *restful.Response) {
resp.WriteAsJson(devops.DefaultRoles)
return
}

View File

@@ -0,0 +1,165 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"github.com/emicklei/go-restful"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/constants"
"kubesphere.io/kubesphere/pkg/errors"
"kubesphere.io/kubesphere/pkg/models/devops"
"net/http"
)
func CreateDevOpsProjectCredentialHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
var credential *devops.JenkinsCredential
err := request.ReadEntity(&credential)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
credentialId, err := devops.CreateProjectCredential(projectId, username, credential)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(struct {
Name string `json:"name"`
}{Name: credentialId})
return
}
func UpdateDevOpsProjectCredentialHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
credentialId := request.PathParameter("credentials")
var credential *devops.JenkinsCredential
err := request.ReadEntity(&credential)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
credentialId, err = devops.UpdateProjectCredential(projectId, credentialId, credential)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(struct {
Name string `json:"name"`
}{Name: credentialId})
return
}
func DeleteDevOpsProjectCredentialHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
credentialId := request.PathParameter("credentials")
var credential *devops.JenkinsCredential
err := request.ReadEntity(&credential)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
credentialId, err = devops.DeleteProjectCredential(projectId, credentialId, credential)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(struct {
Name string `json:"name"`
}{Name: credentialId})
return
}
func GetDevOpsProjectCredentialHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
credentialId := request.PathParameter("credentials")
getContent := request.QueryParameter("content")
domain := request.QueryParameter("domain")
err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
response, err := devops.GetProjectCredential(projectId, credentialId, domain, getContent)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(response)
return
}
func GetDevOpsProjectCredentialsHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
domain := request.QueryParameter("domain")
err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
jenkinsCredentials, err := devops.GetProjectCredentials(projectId, domain)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(jenkinsCredentials)
return
}

View File

@@ -0,0 +1,174 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"github.com/emicklei/go-restful"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/constants"
"kubesphere.io/kubesphere/pkg/errors"
"kubesphere.io/kubesphere/pkg/models/devops"
"net/http"
)
func CreateDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
var pipeline *devops.ProjectPipeline
err := request.ReadEntity(&pipeline)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
pipelineName, err := devops.CreateProjectPipeline(projectId, pipeline)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(struct {
Name string `json:"name"`
}{Name: pipelineName})
return
}
func DeleteDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
pipelineId := request.PathParameter("pipelines")
err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
pipelineName, err := devops.DeleteProjectPipeline(projectId, pipelineId)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(struct {
Name string `json:"name"`
}{Name: pipelineName})
return
}
func UpdateDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
pipelineId := request.PathParameter("pipelines")
var pipeline *devops.ProjectPipeline
err := request.ReadEntity(&pipeline)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
pipelineName, err := devops.UpdateProjectPipeline(projectId, pipelineId, pipeline)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(struct {
Name string `json:"name"`
}{Name: pipelineName})
return
}
func GetDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
pipelineId := request.PathParameter("pipelines")
err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer})
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
pipeline, err := devops.GetProjectPipeline(projectId, pipelineId)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(pipeline)
return
}
func GetPipelineSonarStatusHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
pipelineId := request.PathParameter("pipelines")
err := devops.CheckProjectUserInRole(username, projectId, devops.AllRoleSlice)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
sonarStatus, err := devops.GetPipelineSonar(projectId, pipelineId)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(sonarStatus)
}
func GetMultiBranchesPipelineSonarStatusHandler(request *restful.Request, resp *restful.Response) {
projectId := request.PathParameter("devops")
username := request.HeaderParameter(constants.UserNameHeader)
pipelineId := request.PathParameter("pipelines")
branchId := request.PathParameter("branches")
err := devops.CheckProjectUserInRole(username, projectId, devops.AllRoleSlice)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp)
return
}
sonarStatus, err := devops.GetMultiBranchPipelineSonar(projectId, pipelineId, branchId)
if err != nil {
glog.Errorf("%+v", err)
errors.ParseSvcErr(err, resp)
return
}
resp.WriteAsJson(sonarStatus)
}

View File

@@ -217,7 +217,7 @@ func ListDevopsProjects(req *restful.Request, resp *restful.Response) {
if err != nil {
glog.Errorf("%+v", err)
resp.WriteHeaderAndEntity(http.StatusBadRequest, errors.Wrap(err))
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
@@ -225,7 +225,7 @@ func ListDevopsProjects(req *restful.Request, resp *restful.Response) {
if err != nil {
glog.Errorf("%+v", err)
resp.WriteHeaderAndEntity(http.StatusInternalServerError, errors.Wrap(err))
errors.ParseSvcErr(err, resp)
return
}
@@ -233,7 +233,7 @@ func ListDevopsProjects(req *restful.Request, resp *restful.Response) {
}
func DeleteDevopsProject(req *restful.Request, resp *restful.Response) {
devops := req.PathParameter("id")
projectId := req.PathParameter("id")
workspaceName := req.PathParameter("workspace")
username := req.HeaderParameter(constants.UserNameHeader)
@@ -241,15 +241,15 @@ func DeleteDevopsProject(req *restful.Request, resp *restful.Response) {
if err != nil {
glog.Errorf("%+v", err)
resp.WriteHeaderAndEntity(http.StatusBadRequest, errors.Wrap(err))
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
err, code := tenant.DeleteDevOpsProject(devops, username)
err = tenant.DeleteDevOpsProject(projectId, username)
if err != nil {
glog.Errorf("%+v", err)
resp.WriteHeaderAndEntity(code, errors.Wrap(err))
errors.ParseSvcErr(err, resp)
return
}
@@ -267,16 +267,16 @@ func CreateDevopsProject(req *restful.Request, resp *restful.Response) {
if err != nil {
glog.Infof("%+v", err)
resp.WriteHeaderAndEntity(http.StatusBadRequest, errors.Wrap(err))
errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp)
return
}
glog.Infoln("create workspace", username, workspaceName, devops)
project, err, code := tenant.CreateDevopsProject(username, workspaceName, &devops)
project, err := tenant.CreateDevopsProject(username, workspaceName, &devops)
if err != nil {
glog.Errorf("%+v", err)
resp.WriteHeaderAndEntity(code, errors.Wrap(err))
errors.ParseSvcErr(err, resp)
return
}
@@ -302,11 +302,11 @@ func ListDevopsRules(req *restful.Request, resp *restful.Response) {
devops := req.PathParameter("devops")
username := req.HeaderParameter(constants.UserNameHeader)
rules, err, code := tenant.GetUserDevopsSimpleRules(username, devops)
rules, err := tenant.GetUserDevopsSimpleRules(username, devops)
if err != nil {
glog.Errorf("%+v", err)
resp.WriteError(code, err)
errors.ParseSvcErr(err, resp)
return
}

View File

@@ -20,6 +20,8 @@ package errors
import (
"encoding/json"
"errors"
"github.com/emicklei/go-restful"
"net/http"
)
type Error struct {
@@ -53,3 +55,11 @@ func Parse(data []byte) error {
return errors.New(string(data))
}
}
func ParseSvcErr(err error, resp *restful.Response) {
if svcErr, ok := err.(restful.ServiceError); ok {
resp.WriteServiceError(svcErr.Code, svcErr)
} else {
resp.WriteHeaderAndEntity(http.StatusInternalServerError, Wrap(err))
}
}

View File

@@ -93,13 +93,13 @@ type CredentialResponse struct {
Name string `json:"name,omitempty"`
Ranges struct {
Ranges []*struct {
Start int `json:"start"`
End int `json:"end"`
} `json:"ranges"`
} `json:"ranges"`
Start int `json:"start,omitempty"`
End int `json:"end,omitempty"`
} `json:"ranges,omitempty"`
} `json:"ranges,omitempty"`
} `json:"usage,omitempty"`
} `json:"fingerprint,omitempty"`
Description string `json:"description"`
Description string `json:"description,omitempty"`
Domain string `json:"domain"`
}

View File

@@ -1,7 +1,25 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"fmt"
"github.com/fatih/structs"
"kubesphere.io/kubesphere/pkg/db"
"kubesphere.io/kubesphere/pkg/gojenkins"
"kubesphere.io/kubesphere/pkg/simple/client/devops_mysql"
"kubesphere.io/kubesphere/pkg/utils/reflectutils"
"kubesphere.io/kubesphere/pkg/utils/stringutils"
)
@@ -59,3 +77,270 @@ const (
const (
JenkinsAllUserRoleName = "kubesphere-user"
)
type Role struct {
Name string `json:"name"`
Description string `json:"description"`
}
var DefaultRoles = []*Role{
{
Name: ProjectOwner,
Description: "Owner have access to do all the operations of a DevOps project and own the highest permissions as well.",
},
{
Name: ProjectMaintainer,
Description: "Maintainer have access to manage pipeline and credential configuration in a DevOps project.",
},
{
Name: ProjectDeveloper,
Description: "Developer is able to view and trigger the pipeline.",
},
{
Name: ProjectReporter,
Description: "Reporter is only allowed to view the status of the pipeline.",
},
}
var AllRoleSlice = []string{ProjectDeveloper, ProjectReporter, ProjectMaintainer, ProjectOwner}
var JenkinsOwnerProjectPermissionIds = &gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
}
var JenkinsProjectPermissionMap = map[string]gojenkins.ProjectPermissionIds{
ProjectOwner: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectMaintainer: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: false,
ItemCreate: true,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectDeveloper: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: false,
},
ProjectReporter: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: false,
ItemCancel: false,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: false,
RunDelete: false,
RunReplay: false,
RunUpdate: false,
SCMTag: false,
},
}
var JenkinsPipelinePermissionMap = map[string]gojenkins.ProjectPermissionIds{
ProjectOwner: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectMaintainer: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectDeveloper: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: false,
},
ProjectReporter: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: false,
ItemCancel: false,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: false,
RunDelete: false,
RunReplay: false,
RunUpdate: false,
SCMTag: false,
},
}
func GetProjectRoleName(projectId, role string) string {
return fmt.Sprintf("%s-%s-project", projectId, role)
}
func GetPipelineRoleName(projectId, role string) string {
return fmt.Sprintf("%s-%s-pipeline", projectId, role)
}
func GetProjectRolePattern(projectId string) string {
return fmt.Sprintf("^%s$", projectId)
}
func GetPipelineRolePattern(projectId string) string {
return fmt.Sprintf("^%s/.*", projectId)
}
func CheckProjectUserInRole(username, projectId string, roles []string) error {
if username == KS_ADMIN {
return nil
}
dbconn := devops_mysql.OpenDatabase()
membership := &DevOpsProjectMembership{}
err := dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipUsernameColumn, username),
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId))).LoadOne(membership)
if err != nil {
return err
}
if !reflectutils.In(membership.Role, roles) {
return fmt.Errorf("user [%s] in project [%s] role is not in %s", username, projectId, roles)
}
return nil
}
func GetProjectUserRole(username, projectId string) (string, error) {
if username == KS_ADMIN {
return ProjectOwner, nil
}
dbconn := devops_mysql.OpenDatabase()
membership := &DevOpsProjectMembership{}
err := dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipUsernameColumn, username),
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId))).LoadOne(membership)
if err != nil {
return "", err
}
return membership.Role, nil
}

View File

@@ -1,3 +1,16 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
const (

View File

@@ -1,3 +1,16 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
@@ -21,12 +34,12 @@ const (
type DevOpsProject struct {
ProjectId string `json:"project_id" db:"project_id"`
Name string `json:"name"`
Description string `json:"description"`
Description string `json:"description,omitempty"`
Creator string `json:"creator"`
CreateTime time.Time `json:"create_time"`
Status string `json:"status"`
Visibility string `json:"visibility"`
Extra string `json:"extra"`
Visibility string `json:"visibility,omitempty"`
Extra string `json:"extra,omitempty"`
Workspace string `json:"workspace"`
}

View File

@@ -0,0 +1,117 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"github.com/asaskevich/govalidator"
"time"
)
const (
CredentialTypeUsernamePassword = "username_password"
CredentialTypeSsh = "ssh"
CredentialTypeSecretText = "secret_text"
CredentialTypeKubeConfig = "kubeconfig"
)
type JenkinsCredential struct {
Id string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name,omitempty"`
Fingerprint *struct {
FileName string `json:"file_name,omitempty"`
Hash string `json:"hash,omitempty"`
Usage []*struct {
Name string `json:"name,omitempty"`
Ranges struct {
Ranges []*struct {
Start int `json:"start,omitempty"`
End int `json:"end,omitempty"`
} `json:"ranges,omitempty"`
} `json:"ranges,omitempty"`
} `json:"usage,omitempty"`
} `json:"fingerprint,omitempty"`
Description string `json:"description,omitempty"`
Domain string `json:"domain,omitempty"`
CreateTime *time.Time `json:"create_time,omitempty"`
Creator string `json:"creator,omitempty"`
UsernamePasswordCredential *UsernamePasswordCredential `json:"username_password,omitempty"`
SshCredential *SshCredential `json:"ssh,omitempty"`
SecretTextCredential *SecretTextCredential `json:"secret_text,omitempty"`
KubeconfigCredential *KubeconfigCredential `json:"kubeconfig,omitempty"`
}
type UsernamePasswordCredential struct {
Id string `json:"id"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Description string `json:"description,omitempty"`
}
type SshCredential struct {
Id string `json:"id"`
Username string `json:"username,omitempty"`
Passphrase string `json:"passphrase,omitempty"`
PrivateKey string `json:"private_key,omitempty" mapstructure:"private_key"`
Description string `json:"description,omitempty"`
}
type SecretTextCredential struct {
Id string `json:"id"`
Secret string `json:"secret,omitempty"`
Description string `json:"description,omitempty"`
}
type KubeconfigCredential struct {
Id string `json:"id"`
Content string `json:"content,omitempty"`
Description string `json:"description,omitempty"`
}
const (
ProjectCredentialTableName = "project_credential"
ProjectCredentialIdColumn = "credential_id"
ProjectCredentialDomainColumn = "domain"
ProjectCredentialProjectIdColumn = "project_id"
)
var CredentialTypeMap = map[string]string{
"SSH Username with private key": CredentialTypeSsh,
"Username with password": CredentialTypeUsernamePassword,
"Secret text": CredentialTypeSecretText,
"Kubernetes configuration (kubeconfig)": CredentialTypeKubeConfig,
}
type ProjectCredential struct {
ProjectId string `json:"project_id"`
CredentialId string `json:"credential_id"`
Domain string `json:"domain"`
Creator string `json:"creator"`
CreateTime time.Time `json:"create_time"`
}
var ProjectCredentialColumns = GetColumnsFromStruct(&ProjectCredential{})
func NewProjectCredential(projectId, credentialId, domain, creator string) *ProjectCredential {
if govalidator.IsNull(domain) {
domain = "_"
}
return &ProjectCredential{
ProjectId: projectId,
CredentialId: credentialId,
Domain: domain,
Creator: creator,
CreateTime: time.Now(),
}
}

View File

@@ -0,0 +1,481 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"fmt"
"github.com/PuerkitoBio/goquery"
"github.com/asaskevich/govalidator"
"github.com/emicklei/go-restful"
"github.com/gocraft/dbr"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/db"
"kubesphere.io/kubesphere/pkg/gojenkins"
"kubesphere.io/kubesphere/pkg/gojenkins/utils"
"kubesphere.io/kubesphere/pkg/simple/client/admin_jenkins"
"kubesphere.io/kubesphere/pkg/simple/client/devops_mysql"
"net/http"
"strings"
)
func CreateProjectCredential(projectId, username string, credentialRequest *JenkinsCredential) (string, error) {
jenkinsClient := admin_jenkins.Client()
switch credentialRequest.Type {
case CredentialTypeUsernamePassword:
err := checkJenkinsCredentialExists(projectId, credentialRequest.Domain, credentialRequest.UsernamePasswordCredential.Id)
if err != nil {
glog.Errorf("%+v", err)
return "", err
}
if credentialRequest.UsernamePasswordCredential == nil {
err := fmt.Errorf("usename_password should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.CreateUsernamePasswordCredentialInFolder(credentialRequest.Domain,
credentialRequest.UsernamePasswordCredential.Id,
credentialRequest.UsernamePasswordCredential.Username,
credentialRequest.UsernamePasswordCredential.Password,
credentialRequest.UsernamePasswordCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = insertCredentialToDb(projectId, *credentialId, credentialRequest.Domain, username)
if err != nil {
glog.Errorf("%+v", err)
return "", err
}
return *credentialId, nil
case CredentialTypeSsh:
err := checkJenkinsCredentialExists(projectId, credentialRequest.Domain, credentialRequest.SshCredential.Id)
if err != nil {
glog.Errorf("%+v", err)
return "", err
}
if credentialRequest.SshCredential == nil {
err := fmt.Errorf("ssh should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.CreateSshCredentialInFolder(credentialRequest.Domain,
credentialRequest.SshCredential.Id,
credentialRequest.SshCredential.Username,
credentialRequest.SshCredential.Passphrase,
credentialRequest.SshCredential.PrivateKey,
credentialRequest.SshCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = insertCredentialToDb(projectId, *credentialId, credentialRequest.Domain, username)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
return *credentialId, nil
case CredentialTypeSecretText:
err := checkJenkinsCredentialExists(projectId, credentialRequest.Domain, credentialRequest.SecretTextCredential.Id)
if err != nil {
glog.Errorf("%+v", err)
return "", err
}
if credentialRequest.SecretTextCredential == nil {
err := fmt.Errorf("secret_text should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.CreateSecretTextCredentialInFolder(credentialRequest.Domain,
credentialRequest.SecretTextCredential.Id,
credentialRequest.SecretTextCredential.Secret,
credentialRequest.SecretTextCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = insertCredentialToDb(projectId, *credentialId, credentialRequest.Domain, username)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
return *credentialId, nil
case CredentialTypeKubeConfig:
err := checkJenkinsCredentialExists(projectId, credentialRequest.Domain, credentialRequest.KubeconfigCredential.Id)
if err != nil {
glog.Errorf("%+v", err)
return "", err
}
if credentialRequest.KubeconfigCredential == nil {
err := fmt.Errorf("kubeconfig should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.CreateKubeconfigCredentialInFolder(credentialRequest.Domain,
credentialRequest.KubeconfigCredential.Id,
credentialRequest.KubeconfigCredential.Content,
credentialRequest.KubeconfigCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = insertCredentialToDb(projectId, *credentialId, credentialRequest.Domain, username)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
return *credentialId, nil
default:
err := fmt.Errorf("error unsupport credential type")
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
}
func UpdateProjectCredential(projectId, credentialId string, credentialRequest *JenkinsCredential) (string, error) {
jenkinsClient := admin_jenkins.Client()
jenkinsCredential, err := jenkinsClient.GetCredentialInFolder(credentialRequest.Domain,
credentialId,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
credentialType := CredentialTypeMap[jenkinsCredential.TypeName]
switch credentialType {
case CredentialTypeUsernamePassword:
if credentialRequest.UsernamePasswordCredential == nil {
err := fmt.Errorf("usename_password should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.UpdateUsernamePasswordCredentialInFolder(credentialRequest.Domain,
credentialId,
credentialRequest.UsernamePasswordCredential.Username,
credentialRequest.UsernamePasswordCredential.Password,
credentialRequest.UsernamePasswordCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return *credentialId, nil
case CredentialTypeSsh:
if credentialRequest.SshCredential == nil {
err := fmt.Errorf("ssh should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.UpdateSshCredentialInFolder(credentialRequest.Domain,
credentialId,
credentialRequest.SshCredential.Username,
credentialRequest.SshCredential.Passphrase,
credentialRequest.SshCredential.PrivateKey,
credentialRequest.SshCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return *credentialId, nil
case CredentialTypeSecretText:
if credentialRequest.SecretTextCredential == nil {
err := fmt.Errorf("secret_text should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.UpdateSecretTextCredentialInFolder(credentialRequest.Domain,
credentialId,
credentialRequest.SecretTextCredential.Secret,
credentialRequest.SecretTextCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return *credentialId, nil
case CredentialTypeKubeConfig:
if credentialRequest.KubeconfigCredential == nil {
err := fmt.Errorf("kubeconfig should not be nil")
glog.Error(err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
credentialId, err := jenkinsClient.UpdateKubeconfigCredentialInFolder(credentialRequest.Domain,
credentialId,
credentialRequest.KubeconfigCredential.Content,
credentialRequest.KubeconfigCredential.Description,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return *credentialId, nil
default:
err := fmt.Errorf("error unsupport credential type")
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
}
func DeleteProjectCredential(projectId, credentialId string, credentialRequest *JenkinsCredential) (string, error) {
jenkinsClient := admin_jenkins.Client()
dbClient := devops_mysql.OpenDatabase()
_, err := jenkinsClient.GetCredentialInFolder(credentialRequest.Domain,
credentialId,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
id, err := jenkinsClient.DeleteCredentialInFolder(credentialRequest.Domain, credentialId, projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
deleteConditions := append(make([]dbr.Builder, 0), db.Eq(ProjectCredentialProjectIdColumn, projectId))
deleteConditions = append(deleteConditions, db.Eq(ProjectCredentialIdColumn, credentialId))
if !govalidator.IsNull(credentialRequest.Domain) {
deleteConditions = append(deleteConditions, db.Eq(ProjectCredentialDomainColumn, credentialRequest.Domain))
} else {
deleteConditions = append(deleteConditions, db.Eq(ProjectCredentialDomainColumn, "_"))
}
_, err = dbClient.DeleteFrom(ProjectCredentialTableName).
Where(db.And(deleteConditions...)).Exec()
if err != nil && err != db.ErrNotFound {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
return *id, nil
}
func GetProjectCredential(projectId, credentialId, domain, getContent string) (*JenkinsCredential, error) {
jenkinsClient := admin_jenkins.Client()
dbClient := devops_mysql.OpenDatabase()
jenkinsResponse, err := jenkinsClient.GetCredentialInFolder(domain,
credentialId,
projectId)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
projectCredential := &ProjectCredential{}
err = dbClient.Select(ProjectCredentialColumns...).
From(ProjectCredentialTableName).Where(
db.And(db.Eq(ProjectCredentialProjectIdColumn, projectId),
db.Eq(ProjectCredentialIdColumn, credentialId),
db.Eq(ProjectCredentialDomainColumn, jenkinsResponse.Domain))).LoadOne(projectCredential)
if err != nil && err != db.ErrNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
response := formatCredentialResponse(jenkinsResponse, projectCredential)
if getContent != "" {
stringBody, err := jenkinsClient.GetCredentialContentInFolder(jenkinsResponse.Domain, credentialId, projectId)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
stringReader := strings.NewReader(stringBody)
doc, err := goquery.NewDocumentFromReader(stringReader)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
switch response.Type {
case CredentialTypeKubeConfig:
content := &KubeconfigCredential{}
doc.Find("textarea[name*=content]").Each(func(i int, selection *goquery.Selection) {
value := selection.Text()
content.Content = value
})
doc.Find("input[name*=id][type=text]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Id = value
})
doc.Find("input[name*=description]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Description = value
})
response.KubeconfigCredential = content
case CredentialTypeUsernamePassword:
content := &UsernamePasswordCredential{}
doc.Find("input[name*=username]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Username = value
})
doc.Find("input[name*=id][type=text]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Id = value
})
doc.Find("input[name*=description]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Description = value
})
response.UsernamePasswordCredential = content
case CredentialTypeSsh:
content := &SshCredential{}
doc.Find("input[name*=username]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Username = value
})
doc.Find("input[name*=id][type=text]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Id = value
})
doc.Find("input[name*=description]").Each(func(i int, selection *goquery.Selection) {
value, _ := selection.Attr("value")
content.Description = value
})
doc.Find("textarea[name*=privateKey]").Each(func(i int, selection *goquery.Selection) {
value := selection.Text()
content.PrivateKey = value
})
response.SshCredential = content
}
}
return response, nil
}
func GetProjectCredentials(projectId, domain string) ([]*JenkinsCredential, error) {
jenkinsClient := admin_jenkins.Client()
dbClient := devops_mysql.OpenDatabase()
jenkinsCredentialResponses, err := jenkinsClient.GetCredentialsInFolder(domain, projectId)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
selectCondition := db.Eq(ProjectCredentialProjectIdColumn, projectId)
if !govalidator.IsNull(domain) {
selectCondition = db.And(selectCondition, db.Eq(ProjectCredentialDomainColumn, domain))
}
projectCredentials := make([]*ProjectCredential, 0)
_, err = dbClient.Select(ProjectCredentialColumns...).
From(ProjectCredentialTableName).Where(selectCondition).Load(&projectCredentials)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
response := formatCredentialsResponse(jenkinsCredentialResponses, projectCredentials)
return response, nil
}
func insertCredentialToDb(projectId, credentialId, domain, username string) error {
dbClient := devops_mysql.OpenDatabase()
projectCredential := NewProjectCredential(projectId, credentialId, domain, username)
_, err := dbClient.InsertInto(ProjectCredentialTableName).Columns(ProjectCredentialColumns...).
Record(projectCredential).Exec()
if err != nil {
glog.Errorf("%+v", err)
return restful.NewError(http.StatusInternalServerError, err.Error())
}
return nil
}
func checkJenkinsCredentialExists(projectId, domain, credentialId string) error {
jenkinsClient := admin_jenkins.Client()
credential, err := jenkinsClient.GetCredentialInFolder(domain, credentialId, projectId)
if credential != nil {
err := fmt.Errorf("credential id [%s] has been used", credential.Id)
glog.Warning(err.Error())
return restful.NewError(http.StatusConflict, err.Error())
}
if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound {
glog.Errorf("%+v", err)
return restful.NewError(http.StatusBadRequest, err.Error())
}
return nil
}
func formatCredentialResponse(
jenkinsCredentialResponse *gojenkins.CredentialResponse,
dbCredentialResponse *ProjectCredential) *JenkinsCredential {
response := &JenkinsCredential{}
response.Id = jenkinsCredentialResponse.Id
response.Description = jenkinsCredentialResponse.Description
response.DisplayName = jenkinsCredentialResponse.DisplayName
if jenkinsCredentialResponse.Fingerprint != nil && jenkinsCredentialResponse.Fingerprint.Hash != "" {
response.Fingerprint = &struct {
FileName string `json:"file_name,omitempty"`
Hash string `json:"hash,omitempty"`
Usage []*struct {
Name string `json:"name,omitempty"`
Ranges struct {
Ranges []*struct {
Start int `json:"start,omitempty"`
End int `json:"end,omitempty"`
} `json:"ranges,omitempty"`
} `json:"ranges,omitempty"`
} `json:"usage,omitempty"`
}{}
response.Fingerprint.FileName = jenkinsCredentialResponse.Fingerprint.FileName
response.Fingerprint.Hash = jenkinsCredentialResponse.Fingerprint.Hash
for _, usage := range jenkinsCredentialResponse.Fingerprint.Usage {
response.Fingerprint.Usage = append(response.Fingerprint.Usage, usage)
}
}
response.Domain = jenkinsCredentialResponse.Domain
if dbCredentialResponse != nil {
response.CreateTime = &dbCredentialResponse.CreateTime
response.Creator = dbCredentialResponse.Creator
}
credentialType, ok := CredentialTypeMap[jenkinsCredentialResponse.TypeName]
if ok {
response.Type = credentialType
return response
}
response.Type = jenkinsCredentialResponse.TypeName
return response
}
func formatCredentialsResponse(jenkinsCredentialsResponse []*gojenkins.CredentialResponse,
projectCredentials []*ProjectCredential) []*JenkinsCredential {
responseSlice := make([]*JenkinsCredential, 0)
for _, jenkinsCredential := range jenkinsCredentialsResponse {
var dbCredential *ProjectCredential = nil
for _, projectCredential := range projectCredentials {
if projectCredential.CredentialId == jenkinsCredential.Id &&
projectCredential.Domain == jenkinsCredential.Domain {
dbCredential = projectCredential
}
}
responseSlice = append(responseSlice, formatCredentialResponse(jenkinsCredential, dbCredential))
}
return responseSlice
}

View File

@@ -0,0 +1,77 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"github.com/asaskevich/govalidator"
"github.com/emicklei/go-restful"
"github.com/gocraft/dbr"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/db"
"kubesphere.io/kubesphere/pkg/simple/client/devops_mysql"
"net/http"
)
func GetProject(projectId string) (*DevOpsProject, error) {
dbconn := devops_mysql.OpenDatabase()
project := &DevOpsProject{}
err := dbconn.Select(DevOpsProjectColumns...).
From(DevOpsProjectTableName).
Where(db.Eq(DevOpsProjectIdColumn, projectId)).
LoadOne(project)
if err != nil && err != dbr.ErrNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
if err == dbr.ErrNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusNotFound, err.Error())
}
return project, nil
}
func UpdateProject(project *DevOpsProject) (*DevOpsProject, error) {
dbconn := devops_mysql.OpenDatabase()
query := dbconn.Update(DevOpsProjectTableName)
if !govalidator.IsNull(project.Description) {
query.Set(DevOpsProjectDescriptionColumn, project.Description)
}
if !govalidator.IsNull(project.Extra) {
query.Set(DevOpsProjectExtraColumn, project.Extra)
}
if !govalidator.IsNull(project.Name) {
query.Set(DevOpsProjectNameColumn, project.Extra)
}
if len(query.UpdateStmt.Value) > 0 {
_, err := query.
Where(db.Eq(DevOpsProjectIdColumn, project.ProjectId)).Exec()
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
}
newProject := &DevOpsProject{}
err := dbconn.Select(DevOpsProjectColumns...).
From(DevOpsProjectTableName).
Where(db.Eq(DevOpsProjectIdColumn, project.ProjectId)).
LoadOne(newProject)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
return newProject, nil
}

View File

@@ -0,0 +1,328 @@
/*
Copyright 2018 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"fmt"
"net/http"
"github.com/emicklei/go-restful"
"github.com/gocraft/dbr"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/db"
"kubesphere.io/kubesphere/pkg/gojenkins"
"kubesphere.io/kubesphere/pkg/gojenkins/utils"
"kubesphere.io/kubesphere/pkg/models"
"kubesphere.io/kubesphere/pkg/params"
"kubesphere.io/kubesphere/pkg/simple/client/admin_jenkins"
"kubesphere.io/kubesphere/pkg/simple/client/devops_mysql"
)
func GetProjectMembers(projectId string, conditions *params.Conditions, orderBy string, reverse bool, limit int, offset int) (*models.PageableResponse, error) {
dbconn := devops_mysql.OpenDatabase()
memberships := make([]*DevOpsProjectMembership, 0)
var sqconditions []dbr.Builder
sqconditions = append(sqconditions, db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId))
if keyword := conditions.Match["keyword"]; keyword != "" {
sqconditions = append(sqconditions, db.Like(DevOpsProjectMembershipUsernameColumn, keyword))
}
query := dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName)
switch orderBy {
case "name":
if reverse {
query.OrderDesc(DevOpsProjectMembershipUsernameColumn)
} else {
query.OrderAsc(DevOpsProjectMembershipUsernameColumn)
}
default:
if reverse {
query.OrderDesc(DevOpsProjectMembershipRoleColumn)
} else {
query.OrderAsc(DevOpsProjectMembershipRoleColumn)
}
}
query.Limit(uint64(limit))
query.Offset(uint64(offset))
if len(sqconditions) > 1 {
query.Where(db.And(sqconditions...))
} else {
query.Where(sqconditions[0])
}
_, err := query.Load(&memberships)
if err != nil && err != dbr.ErrNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
count, err := query.Count()
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
result := make([]interface{}, 0)
for _, v := range memberships {
result = append(result, v)
}
return &models.PageableResponse{Items: result, TotalCount: int(count)}, nil
}
func GetProjectMember(projectId, username string) (*DevOpsProjectMembership, error) {
dbconn := devops_mysql.OpenDatabase()
member := &DevOpsProjectMembership{}
err := dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName).
Where(db.And(db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId),
db.Eq(DevOpsProjectMembershipUsernameColumn, username))).
LoadOne(&member)
if err != nil && err != dbr.ErrNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
if err == dbr.ErrNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusNotFound, err.Error())
}
return member, nil
}
func AddProjectMember(projectId, operator string, member *DevOpsProjectMembership) (*DevOpsProjectMembership, error) {
dbconn := devops_mysql.OpenDatabase()
jenkinsClinet := admin_jenkins.Client()
membership := &DevOpsProjectMembership{}
err := dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipUsernameColumn, member.Username),
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId))).LoadOne(membership)
// if user could be founded in db, user have been added to project
if err == nil {
err = fmt.Errorf("user [%s] have been added to project", member.Username)
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusBadRequest, err.Error())
}
if err != nil && err != db.ErrNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
globalRole, err := jenkinsClinet.GetGlobalRole(JenkinsAllUserRoleName)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
if globalRole == nil {
_, err := jenkinsClinet.AddGlobalRole(JenkinsAllUserRoleName, gojenkins.GlobalPermissionIds{
GlobalRead: true,
}, true)
if err != nil {
glog.Errorf("failed to create jenkins global role %+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
}
err = globalRole.AssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
projectRole, err := jenkinsClinet.GetProjectRole(GetProjectRoleName(projectId, member.Role))
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = projectRole.AssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
pipelineRole, err := jenkinsClinet.GetProjectRole(GetPipelineRoleName(projectId, member.Role))
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = pipelineRole.AssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
projectMembership := NewDevOpsProjectMemberShip(member.Username, projectId, member.Role, operator)
_, err = dbconn.
InsertInto(DevOpsProjectMembershipTableName).
Columns(DevOpsProjectMembershipColumns...).
Record(projectMembership).Exec()
if err != nil {
glog.Errorf("%+v", err)
err = projectRole.UnAssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = pipelineRole.UnAssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
return projectMembership, nil
}
func UpdateProjectMember(projectId, operator string, member *DevOpsProjectMembership) (*DevOpsProjectMembership, error) {
dbconn := devops_mysql.OpenDatabase()
jenkinsClinet := admin_jenkins.Client()
oldMembership := &DevOpsProjectMembership{}
err := dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipUsernameColumn, member.Username),
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId),
)).LoadOne(oldMembership)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusBadRequest, err.Error())
}
oldProjectRole, err := jenkinsClinet.GetProjectRole(GetProjectRoleName(projectId, oldMembership.Role))
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = oldProjectRole.UnAssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
oldPipelineRole, err := jenkinsClinet.GetProjectRole(GetPipelineRoleName(projectId, oldMembership.Role))
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = oldPipelineRole.UnAssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
projectRole, err := jenkinsClinet.GetProjectRole(GetProjectRoleName(projectId, member.Role))
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = projectRole.AssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
pipelineRole, err := jenkinsClinet.GetProjectRole(GetPipelineRoleName(projectId, member.Role))
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = pipelineRole.AssignRole(member.Username)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
_, err = dbconn.Update(DevOpsProjectMembershipTableName).
Set(DevOpsProjectMembershipRoleColumn, member.Role).
Where(db.And(
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId),
db.Eq(DevOpsProjectMembershipUsernameColumn, member.Username),
)).Exec()
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
responseMembership := &DevOpsProjectMembership{}
err = dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipUsernameColumn, member.Username),
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId),
)).LoadOne(responseMembership)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
return responseMembership, nil
}
func DeleteProjectMember(projectId, username string) (string, error) {
dbconn := devops_mysql.OpenDatabase()
jenkinsClient := admin_jenkins.Client()
oldMembership := &DevOpsProjectMembership{}
err := dbconn.Select(DevOpsProjectMembershipColumns...).
From(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipUsernameColumn, username),
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId),
)).LoadOne(oldMembership)
if err != nil && err != db.ErrNotFound {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
if err == db.ErrNotFound {
glog.Warningf("user [%s] not found in project", username)
return username, nil
}
if oldMembership.Role == ProjectOwner {
count, err := dbconn.Select(DevOpsProjectMembershipProjectIdColumn).
From(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId),
db.Eq(DevOpsProjectMembershipRoleColumn, ProjectOwner))).Count()
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
if count == 1 {
err = fmt.Errorf("project must has at least one admin")
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
}
oldProjectRole, err := jenkinsClient.GetProjectRole(GetProjectRoleName(projectId, oldMembership.Role))
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = oldProjectRole.UnAssignRole(username)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
oldPipelineRole, err := jenkinsClient.GetProjectRole(GetPipelineRoleName(projectId, oldMembership.Role))
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = oldPipelineRole.UnAssignRole(username)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
_, err = dbconn.DeleteFrom(DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(DevOpsProjectMembershipProjectIdColumn, projectId),
db.Eq(DevOpsProjectMembershipUsernameColumn, username),
)).Exec()
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
return username, nil
}

View File

@@ -0,0 +1,888 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"fmt"
"github.com/beevik/etree"
"github.com/golang/glog"
"github.com/kubesphere/sonargo/sonar"
"kubesphere.io/kubesphere/pkg/gojenkins"
"kubesphere.io/kubesphere/pkg/simple/client/sonarqube"
"strconv"
"strings"
"time"
)
const (
NoScmPipelineType = "pipeline"
MultiBranchPipelineType = "multi-branch-pipeline"
)
var ParameterTypeMap = map[string]string{
"hudson.model.StringParameterDefinition": "string",
"hudson.model.ChoiceParameterDefinition": "choice",
"hudson.model.TextParameterDefinition": "text",
"hudson.model.BooleanParameterDefinition": "boolean",
"hudson.model.FileParameterDefinition": "file",
"hudson.model.PasswordParameterDefinition": "password",
}
const (
SonarAnalysisActionClass = "hudson.plugins.sonar.action.SonarAnalysisAction"
SonarMetricKeys = "alert_status,quality_gate_details,bugs,new_bugs,reliability_rating,new_reliability_rating,vulnerabilities,new_vulnerabilities,security_rating,new_security_rating,code_smells,new_code_smells,sqale_rating,new_maintainability_rating,sqale_index,new_technical_debt,coverage,new_coverage,new_lines_to_cover,tests,duplicated_lines_density,new_duplicated_lines_density,duplicated_blocks,ncloc,ncloc_language_distribution,projects,new_lines"
SonarAdditionalFields = "metrics,periods"
)
type SonarStatus struct {
Measures *sonargo.MeasuresComponentObject `json:"measures,omitempty"`
Issues *sonargo.IssuesSearchObject `json:"issues,omitempty"`
JenkinsAction *gojenkins.GeneralObj `json:"jenkinsAction,omitempty"`
Task *sonargo.CeTaskObject `json:"task,omitempty"`
}
type ProjectPipeline struct {
Type string `json:"type"`
Pipeline *NoScmPipeline `json:"pipeline,omitempty"`
MultiBranchPipeline *MultiBranchPipeline `json:"multi_branch_pipeline,omitempty"`
}
type NoScmPipeline struct {
Name string `json:"name"`
Description string `json:"descriptio,omitempty"`
Discarder *DiscarderProperty `json:"discarder,omitempty"`
Parameters []*Parameter `json:"parameters,omitempty"`
DisableConcurrent bool `json:"disable_concurrent,omitempty" mapstructure:"disable_concurrent"`
TimerTrigger *TimerTrigger `json:"timer_trigger,omitempty" mapstructure:"timer_trigger"`
RemoteTrigger *RemoteTrigger `json:"remote_trigger,omitempty" mapstructure:"remote_trigger"`
Jenkinsfile string `json:"jenkinsfile,omitempty"`
}
type MultiBranchPipeline struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Discarder *DiscarderProperty `json:"discarder,omitempty"`
TimerTrigger *TimerTrigger `json:"timer_trigger,omitempty" mapstructure:"timer_trigger"`
SourceType string `json:"source_type"`
GitSource *GitSource `json:"git_source,omitempty"`
GitHubSource *GithubSource `json:"github_source,omitempty"`
SvnSource *SvnSource `json:"svn_source,omitempty"`
SingleSvnSource *SingleSvnSource `json:"single_svn_source,omitempty"`
ScriptPath string `json:"script_path" mapstructure:"script_path"`
}
type GitSource struct {
Url string `json:"url,omitempty" mapstructure:"url"`
CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"`
DiscoverBranches bool `json:"discover_branches,omitempty" mapstructure:"discover_branches"`
CloneOption *GitCloneOption `json:"git_clone_option,omitempty" mapstructure:"git_clone_option"`
RegexFilter string `json:"regex_filter,omitempty" mapstructure:"regex_filter"`
}
type GithubSource struct {
Owner string `json:"owner,omitempty" mapstructure:"owner"`
Repo string `json:"repo,omitempty" mapstructure:"repo"`
CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"`
ApiUri string `json:"api_uri,omitempty" mapstructure:"api_uri"`
DiscoverBranches int `json:"discover_branches,omitempty" mapstructure:"discover_branches"`
DiscoverPRFromOrigin int `json:"discover_pr_from_origin,omitempty" mapstructure:"discover_pr_from_origin"`
DiscoverPRFromForks *GithubDiscoverPRFromForks `json:"discover_pr_from_forks,omitempty" mapstructure:"discover_pr_from_forks"`
CloneOption *GitCloneOption `json:"git_clone_option,omitempty" mapstructure:"git_clone_option"`
RegexFilter string `json:"regex_filter,omitempty" mapstructure:"regex_filter"`
}
type GitCloneOption struct {
Shallow bool `json:"shallow,omitempty" mapstructure:"shallow"`
Timeout int `json:"timeout,omitempty" mapstructure:"timeout"`
Depth int `json:"depth,omitempty" mapstructure:"depth"`
}
type SvnSource struct {
Remote string `json:"remote,omitempty"`
CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"`
Includes string `json:"includes,omitempty"`
Excludes string `json:"excludes,omitempty"`
}
type SingleSvnSource struct {
Remote string `json:"remote,omitempty"`
CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"`
}
type GithubDiscoverPRFromForks struct {
Strategy int `json:"strategy,omitempty" mapstructure:"strategy"`
Trust int `json:"trust,omitempty" mapstructure:"trust"`
}
type DiscarderProperty struct {
DaysToKeep string `json:"days_to_keep,omitempty" mapstructure:"days_to_keep"`
NumToKeep string `json:"num_to_keep,omitempty" mapstructure:"num_to_keep"`
}
type Parameter struct {
Name string `json:"name"`
DefaultValue string `json:"default_value,omitempty" mapstructure:"default_value"`
Type string `json:"type"`
Description string `json:"description,omitempty"`
}
type TimerTrigger struct {
// user in no scm job
Cron string `json:"cron,omitempty"`
// use in multi-branch job
Interval string `json:"interval,omitempty"`
}
type RemoteTrigger struct {
Token string `json:"token,omitempty"`
}
func replaceXmlVersion(config, oldVersion, targetVersion string) string {
lines := strings.Split(string(config), "\n")
lines[0] = strings.Replace(lines[0], oldVersion, targetVersion, -1)
output := strings.Join(lines, "\n")
return output
}
func createPipelineConfigXml(pipeline *NoScmPipeline) (string, error) {
doc := etree.NewDocument()
xmlString := `<?xml version='1.0' encoding='UTF-8'?>
<flow-definition plugin="workflow-job">
<actions>
<org.jenkinsci.plugins.pipeline.modeldefinition.actions.DeclarativeJobAction plugin="pipeline-model-definition"/>
<org.jenkinsci.plugins.pipeline.modeldefinition.actions.DeclarativeJobPropertyTrackerAction plugin="pipeline-model-definition">
<jobProperties/>
<triggers/>
<parameters/>
<options/>
</org.jenkinsci.plugins.pipeline.modeldefinition.actions.DeclarativeJobPropertyTrackerAction>
</actions>
</flow-definition>
`
doc.ReadFromString(xmlString)
flow := doc.SelectElement("flow-definition")
flow.CreateElement("description").SetText(pipeline.Description)
properties := flow.CreateElement("properties")
if pipeline.DisableConcurrent {
properties.CreateElement("org.jenkinsci.plugins.workflow.job.properties.DisableConcurrentBuildsJobProperty")
}
if pipeline.Discarder != nil {
discarder := properties.CreateElement("jenkins.model.BuildDiscarderProperty")
strategy := discarder.CreateElement("strategy")
strategy.CreateAttr("class", "hudson.tasks.LogRotator")
strategy.CreateElement("daysToKeep").SetText(pipeline.Discarder.DaysToKeep)
strategy.CreateElement("numToKeep").SetText(pipeline.Discarder.NumToKeep)
strategy.CreateElement("artifactDaysToKeep").SetText("-1")
strategy.CreateElement("artifactNumToKeep").SetText("-1")
}
if pipeline.Parameters != nil {
parameterDefinitions := properties.CreateElement("hudson.model.ParametersDefinitionProperty").
CreateElement("parameterDefinitions")
for _, parameter := range pipeline.Parameters {
for className, typeName := range ParameterTypeMap {
if typeName == parameter.Type {
paramDefine := parameterDefinitions.CreateElement(className)
paramDefine.CreateElement("name").SetText(parameter.Name)
paramDefine.CreateElement("description").SetText(parameter.Description)
switch parameter.Type {
case "choice":
choices := paramDefine.CreateElement("choices")
choices.CreateAttr("class", "java.util.Arrays$ArrayList")
a := choices.CreateElement("a")
a.CreateAttr("class", "string-array")
choiceValues := strings.Split(parameter.DefaultValue, "\n")
for _, choiceValue := range choiceValues {
a.CreateElement("string").SetText(choiceValue)
}
case "file":
break
default:
paramDefine.CreateElement("defaultValue").SetText(parameter.DefaultValue)
}
}
}
}
}
if pipeline.TimerTrigger != nil {
triggers := properties.
CreateElement("org.jenkinsci.plugins.workflow.job.properties.PipelineTriggersJobProperty").
CreateElement("triggers")
triggers.CreateElement("hudson.triggers.TimerTrigger").CreateElement("spec").SetText(pipeline.TimerTrigger.Cron)
}
pipelineDefine := flow.CreateElement("definition")
pipelineDefine.CreateAttr("class", "org.jenkinsci.plugins.workflow.cps.CpsFlowDefinition")
pipelineDefine.CreateAttr("plugin", "workflow-cps")
pipelineDefine.CreateElement("script").SetText(pipeline.Jenkinsfile)
pipelineDefine.CreateElement("sandbox").SetText("true")
flow.CreateElement("triggers")
if pipeline.RemoteTrigger != nil {
flow.CreateElement("authToken").SetText(pipeline.RemoteTrigger.Token)
}
flow.CreateElement("disabled").SetText("false")
doc.Indent(2)
stringXml, err := doc.WriteToString()
if err != nil {
return "", err
}
return replaceXmlVersion(stringXml, "1.0", "1.1"), err
}
func parsePipelineConfigXml(config string) (*NoScmPipeline, error) {
pipeline := &NoScmPipeline{}
config = replaceXmlVersion(config, "1.1", "1.0")
doc := etree.NewDocument()
err := doc.ReadFromString(config)
if err != nil {
return nil, err
}
flow := doc.SelectElement("flow-definition")
if flow == nil {
return nil, fmt.Errorf("can not find pipeline definition")
}
pipeline.Description = flow.SelectElement("description").Text()
properties := flow.SelectElement("properties")
if properties.
SelectElement(
"org.jenkinsci.plugins.workflow.job.properties.DisableConcurrentBuildsJobProperty") != nil {
pipeline.DisableConcurrent = true
}
if properties.SelectElement("jenkins.model.BuildDiscarderProperty") != nil {
strategy := properties.
SelectElement("jenkins.model.BuildDiscarderProperty").
SelectElement("strategy")
pipeline.Discarder = &DiscarderProperty{
DaysToKeep: strategy.SelectElement("daysToKeep").Text(),
NumToKeep: strategy.SelectElement("numToKeep").Text(),
}
}
if parametersProperty := properties.SelectElement("hudson.model.ParametersDefinitionProperty"); parametersProperty != nil {
params := parametersProperty.SelectElement("parameterDefinitions").ChildElements()
for _, param := range params {
switch param.Tag {
case "hudson.model.StringParameterDefinition":
pipeline.Parameters = append(pipeline.Parameters, &Parameter{
Name: param.SelectElement("name").Text(),
Description: param.SelectElement("description").Text(),
DefaultValue: param.SelectElement("defaultValue").Text(),
Type: ParameterTypeMap["hudson.model.StringParameterDefinition"],
})
case "hudson.model.BooleanParameterDefinition":
pipeline.Parameters = append(pipeline.Parameters, &Parameter{
Name: param.SelectElement("name").Text(),
Description: param.SelectElement("description").Text(),
DefaultValue: param.SelectElement("defaultValue").Text(),
Type: ParameterTypeMap["hudson.model.BooleanParameterDefinition"],
})
case "hudson.model.TextParameterDefinition":
pipeline.Parameters = append(pipeline.Parameters, &Parameter{
Name: param.SelectElement("name").Text(),
Description: param.SelectElement("description").Text(),
DefaultValue: param.SelectElement("defaultValue").Text(),
Type: ParameterTypeMap["hudson.model.TextParameterDefinition"],
})
case "hudson.model.FileParameterDefinition":
pipeline.Parameters = append(pipeline.Parameters, &Parameter{
Name: param.SelectElement("name").Text(),
Description: param.SelectElement("description").Text(),
Type: ParameterTypeMap["hudson.model.FileParameterDefinition"],
})
case "hudson.model.PasswordParameterDefinition":
pipeline.Parameters = append(pipeline.Parameters, &Parameter{
Name: param.SelectElement("name").Text(),
Description: param.SelectElement("description").Text(),
DefaultValue: param.SelectElement("name").Text(),
Type: ParameterTypeMap["hudson.model.PasswordParameterDefinition"],
})
case "hudson.model.ChoiceParameterDefinition":
choiceParameter := &Parameter{
Name: param.SelectElement("name").Text(),
Description: param.SelectElement("description").Text(),
Type: ParameterTypeMap["hudson.model.ChoiceParameterDefinition"],
}
choices := param.SelectElement("choices").SelectElement("a").SelectElements("string")
for _, choice := range choices {
choiceParameter.DefaultValue += fmt.Sprintf("%s\n", choice.Text())
}
choiceParameter.DefaultValue = strings.TrimSpace(choiceParameter.DefaultValue)
pipeline.Parameters = append(pipeline.Parameters, choiceParameter)
default:
pipeline.Parameters = append(pipeline.Parameters, &Parameter{
Name: param.SelectElement("name").Text(),
Description: param.SelectElement("description").Text(),
DefaultValue: "unknown",
Type: param.Tag,
})
}
}
}
if triggerProperty := properties.
SelectElement(
"org.jenkinsci.plugins.workflow.job.properties.PipelineTriggersJobProperty"); triggerProperty != nil {
triggers := triggerProperty.SelectElement("triggers")
if timerTrigger := triggers.SelectElement("hudson.triggers.TimerTrigger"); timerTrigger != nil {
pipeline.TimerTrigger = &TimerTrigger{
Cron: timerTrigger.SelectElement("spec").Text(),
}
}
}
if authToken := flow.SelectElement("authToken"); authToken != nil {
pipeline.RemoteTrigger = &RemoteTrigger{
Token: authToken.Text(),
}
}
if definition := flow.SelectElement("definition"); definition != nil {
if script := definition.SelectElement("script"); script != nil {
pipeline.Jenkinsfile = script.Text()
}
}
return pipeline, nil
}
func createMultiBranchPipelineConfigXml(projectName string, pipeline *MultiBranchPipeline) (string, error) {
doc := etree.NewDocument()
xmlString := `
<?xml version='1.0' encoding='UTF-8'?>
<org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject plugin="workflow-multibranch">
<actions/>
<properties>
<org.jenkinsci.plugins.pipeline.modeldefinition.config.FolderConfig plugin="pipeline-model-definition">
<dockerLabel></dockerLabel>
<registry plugin="docker-commons"/>
</org.jenkinsci.plugins.pipeline.modeldefinition.config.FolderConfig>
</properties>
<folderViews class="jenkins.branch.MultiBranchProjectViewHolder" plugin="branch-api">
<owner class="org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject" reference="../.."/>
</folderViews>
<healthMetrics>
<com.cloudbees.hudson.plugins.folder.health.WorstChildHealthMetric plugin="cloudbees-folder">
<nonRecursive>false</nonRecursive>
</com.cloudbees.hudson.plugins.folder.health.WorstChildHealthMetric>
</healthMetrics>
<icon class="jenkins.branch.MetadataActionFolderIcon" plugin="branch-api">
<owner class="org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject" reference="../.."/>
</icon>
</org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject>`
err := doc.ReadFromString(xmlString)
if err != nil {
return "", err
}
project := doc.SelectElement("org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject")
project.CreateElement("description").SetText(pipeline.Description)
if pipeline.Discarder != nil {
discarder := project.CreateElement("orphanedItemStrategy")
discarder.CreateAttr("class", "com.cloudbees.hudson.plugins.folder.computed.DefaultOrphanedItemStrategy")
discarder.CreateAttr("plugin", "cloudbees-folder")
discarder.CreateElement("pruneDeadBranches").SetText("true")
discarder.CreateElement("daysToKeep").SetText(pipeline.Discarder.DaysToKeep)
discarder.CreateElement("numToKeep").SetText(pipeline.Discarder.NumToKeep)
}
triggers := project.CreateElement("triggers")
if pipeline.TimerTrigger != nil {
timeTrigger := triggers.CreateElement(
"com.cloudbees.hudson.plugins.folder.computed.PeriodicFolderTrigger")
timeTrigger.CreateAttr("plugin", "cloudbees-folder")
millis, err := strconv.ParseInt(pipeline.TimerTrigger.Interval, 10, 64)
if err != nil {
return "", err
}
timeTrigger.CreateElement("spec").SetText(toCrontab(millis))
timeTrigger.CreateElement("interval").SetText(pipeline.TimerTrigger.Interval)
triggers.CreateElement("disabled").SetText("false")
}
sources := project.CreateElement("sources")
sources.CreateAttr("class", "jenkins.branch.MultiBranchProject$BranchSourceList")
sources.CreateAttr("plugin", "branch-api")
sourcesOwner := sources.CreateElement("owner")
sourcesOwner.CreateAttr("class", "org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject")
sourcesOwner.CreateAttr("reference", "../..")
branchSource := sources.CreateElement("data").CreateElement("jenkins.branch.BranchSource")
branchSourceStrategy := branchSource.CreateElement("strategy")
branchSourceStrategy.CreateAttr("class", "jenkins.branch.NamedExceptionsBranchPropertyStrategy")
branchSourceStrategy.CreateElement("defaultProperties").CreateAttr("class", "empty-list")
branchSourceStrategy.CreateElement("namedExceptions").CreateAttr("class", "empty-list")
switch pipeline.SourceType {
case "git":
gitDefine := pipeline.GitSource
gitSource := branchSource.CreateElement("source")
gitSource.CreateAttr("class", "jenkins.plugins.git.GitSCMSource")
gitSource.CreateAttr("plugin", "git")
gitSource.CreateElement("id").SetText(projectName + pipeline.Name)
gitSource.CreateElement("remote").SetText(gitDefine.Url)
if gitDefine.CredentialId != "" {
gitSource.CreateElement("credentialsId").SetText(gitDefine.CredentialId)
}
traits := gitSource.CreateElement("traits")
if gitDefine.DiscoverBranches {
traits.CreateElement("jenkins.plugins.git.traits.BranchDiscoveryTrait")
}
if gitDefine.CloneOption != nil {
cloneExtension := traits.CreateElement("jenkins.plugins.git.traits.CloneOptionTrait").CreateElement("extension")
cloneExtension.CreateAttr("class", "hudson.plugins.git.extensions.impl.CloneOption")
cloneExtension.CreateElement("shallow").SetText(strconv.FormatBool(gitDefine.CloneOption.Shallow))
cloneExtension.CreateElement("noTags").SetText(strconv.FormatBool(false))
cloneExtension.CreateElement("reference")
if gitDefine.CloneOption.Timeout >= 0 {
cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(gitDefine.CloneOption.Timeout))
} else {
cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(10))
}
if gitDefine.CloneOption.Depth >= 0 {
cloneExtension.CreateElement("depth").SetText(strconv.Itoa(gitDefine.CloneOption.Depth))
} else {
cloneExtension.CreateElement("depth").SetText(strconv.Itoa(1))
}
}
if gitDefine.RegexFilter != "" {
regexTraits := traits.CreateElement("jenkins.scm.impl.trait.RegexSCMHeadFilterTrait")
regexTraits.CreateAttr("plugin", "scm-api@2.4.0")
regexTraits.CreateElement("regex").SetText(gitDefine.RegexFilter)
}
case "github":
githubDefine := pipeline.GitHubSource
githubSource := branchSource.CreateElement("source")
githubSource.CreateAttr("class", "org.jenkinsci.plugins.github_branch_source.GitHubSCMSource")
githubSource.CreateAttr("plugin", "github-branch-source")
githubSource.CreateElement("id").SetText(projectName + pipeline.Name)
githubSource.CreateElement("credentialsId").SetText(githubDefine.CredentialId)
githubSource.CreateElement("repoOwner").SetText(githubDefine.Owner)
githubSource.CreateElement("repository").SetText(githubDefine.Repo)
if githubDefine.ApiUri != "" {
githubSource.CreateElement("apiUri").SetText(githubDefine.ApiUri)
}
traits := githubSource.CreateElement("traits")
if githubDefine.DiscoverBranches != 0 {
traits.CreateElement("org.jenkinsci.plugins.github__branch__source.BranchDiscoveryTrait").
CreateElement("strategyId").SetText(strconv.Itoa(githubDefine.DiscoverBranches))
}
if githubDefine.DiscoverPRFromOrigin != 0 {
traits.CreateElement("org.jenkinsci.plugins.github__branch__source.OriginPullRequestDiscoveryTrait").
CreateElement("strategyId").SetText(strconv.Itoa(githubDefine.DiscoverPRFromOrigin))
}
if githubDefine.DiscoverPRFromForks != nil {
forkTrait := traits.CreateElement("org.jenkinsci.plugins.github__branch__source.ForkPullRequestDiscoveryTrait")
forkTrait.CreateElement("strategyId").SetText(strconv.Itoa(githubDefine.DiscoverPRFromForks.Strategy))
trustClass := "org.jenkinsci.plugins.github_branch_source.ForkPullRequestDiscoveryTrait$"
switch githubDefine.DiscoverPRFromForks.Trust {
case 1:
trustClass += "TrustContributors"
case 2:
trustClass += "TrustEveryone"
case 3:
trustClass += "TrustPermission"
case 4:
trustClass += "TrustNobody"
default:
return "", fmt.Errorf("unsupport trust choice")
}
forkTrait.CreateElement("trust").CreateAttr("class", trustClass)
}
if githubDefine.CloneOption != nil {
cloneExtension := traits.CreateElement("jenkins.plugins.git.traits.CloneOptionTrait").CreateElement("extension")
cloneExtension.CreateAttr("class", "hudson.plugins.git.extensions.impl.CloneOption")
cloneExtension.CreateElement("shallow").SetText(strconv.FormatBool(githubDefine.CloneOption.Shallow))
cloneExtension.CreateElement("noTags").SetText(strconv.FormatBool(false))
cloneExtension.CreateElement("reference")
if githubDefine.CloneOption.Timeout >= 0 {
cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(githubDefine.CloneOption.Timeout))
} else {
cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(10))
}
if githubDefine.CloneOption.Depth >= 0 {
cloneExtension.CreateElement("depth").SetText(strconv.Itoa(githubDefine.CloneOption.Depth))
} else {
cloneExtension.CreateElement("depth").SetText(strconv.Itoa(1))
}
}
if githubDefine.RegexFilter != "" {
regexTraits := traits.CreateElement("jenkins.scm.impl.trait.RegexSCMHeadFilterTrait")
regexTraits.CreateAttr("plugin", "scm-api@2.4.0")
regexTraits.CreateElement("regex").SetText(githubDefine.RegexFilter)
}
case "svn":
svnDefine := pipeline.SvnSource
svnSource := branchSource.CreateElement("source")
svnSource.CreateAttr("class", "jenkins.scm.impl.subversion.SubversionSCMSource")
svnSource.CreateAttr("plugin", "subversion")
svnSource.CreateElement("id").SetText(projectName + pipeline.Name)
if svnDefine.CredentialId != "" {
svnSource.CreateElement("credentialsId").SetText(svnDefine.CredentialId)
}
if svnDefine.Remote != "" {
svnSource.CreateElement("remoteBase").SetText(svnDefine.Remote)
}
if svnDefine.Includes != "" {
svnSource.CreateElement("includes").SetText(svnDefine.Includes)
}
if svnDefine.Excludes != "" {
svnSource.CreateElement("excludes").SetText(svnDefine.Excludes)
}
case "single_svn":
singleSvnDefine := pipeline.SingleSvnSource
if err != nil {
return "", err
}
svnSource := branchSource.CreateElement("source")
svnSource.CreateAttr("class", "jenkins.scm.impl.SingleSCMSource")
svnSource.CreateAttr("plugin", "scm-api")
svnSource.CreateElement("id").SetText(projectName + pipeline.Name)
svnSource.CreateElement("name").SetText("master")
scm := svnSource.CreateElement("scm")
scm.CreateAttr("class", "hudson.scm.SubversionSCM")
scm.CreateAttr("plugin", "subversion")
location := scm.CreateElement("locations").CreateElement("hudson.scm.SubversionSCM_-ModuleLocation")
if singleSvnDefine.Remote != "" {
location.CreateElement("remote").SetText(singleSvnDefine.Remote)
}
if singleSvnDefine.CredentialId != "" {
location.CreateElement("credentialsId").SetText(singleSvnDefine.CredentialId)
}
location.CreateElement("local").SetText(".")
location.CreateElement("depthOption").SetText("infinity")
location.CreateElement("ignoreExternalsOption").SetText("true")
location.CreateElement("cancelProcessOnExternalsFail").SetText("true")
svnSource.CreateElement("excludedRegions")
svnSource.CreateElement("includedRegions")
svnSource.CreateElement("excludedUsers")
svnSource.CreateElement("excludedRevprop")
svnSource.CreateElement("excludedCommitMessages")
svnSource.CreateElement("workspaceUpdater").CreateAttr("class", "hudson.scm.subversion.UpdateUpdater")
svnSource.CreateElement("ignoreDirPropChanges").SetText("false")
svnSource.CreateElement("filterChangelog").SetText("false")
svnSource.CreateElement("quietOperation").SetText("true")
default:
return "", fmt.Errorf("unsupport source type")
}
factory := project.CreateElement("factory")
factory.CreateAttr("class", "org.jenkinsci.plugins.workflow.multibranch.WorkflowBranchProjectFactory")
factoryOwner := factory.CreateElement("owner")
factoryOwner.CreateAttr("class", "org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject")
factoryOwner.CreateAttr("reference", "../..")
factory.CreateElement("scriptPath").SetText(pipeline.ScriptPath)
doc.Indent(2)
stringXml, err := doc.WriteToString()
return replaceXmlVersion(stringXml, "1.0", "1.1"), err
}
func parseMultiBranchPipelineConfigXml(config string) (*MultiBranchPipeline, error) {
pipeline := &MultiBranchPipeline{}
config = replaceXmlVersion(config, "1.1", "1.0")
doc := etree.NewDocument()
err := doc.ReadFromString(config)
if err != nil {
return nil, err
}
project := doc.SelectElement("org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject")
if project == nil {
return nil, fmt.Errorf("can not parse mutibranch pipeline config")
}
pipeline.Description = project.SelectElement("description").Text()
if discarder := project.SelectElement("orphanedItemStrategy"); discarder != nil {
pipeline.Discarder = &DiscarderProperty{
DaysToKeep: discarder.SelectElement("daysToKeep").Text(),
NumToKeep: discarder.SelectElement("numToKeep").Text(),
}
}
if triggers := project.SelectElement("triggers"); triggers != nil {
if timerTrigger := triggers.SelectElement(
"com.cloudbees.hudson.plugins.folder.computed.PeriodicFolderTrigger"); timerTrigger != nil {
pipeline.TimerTrigger = &TimerTrigger{
Interval: timerTrigger.SelectElement("interval").Text(),
}
}
}
if sources := project.SelectElement("sources"); sources != nil {
if sourcesData := sources.SelectElement("data"); sourcesData != nil {
if branchSource := sourcesData.SelectElement("jenkins.branch.BranchSource"); branchSource != nil {
source := branchSource.SelectElement("source")
switch source.SelectAttr("class").Value {
case "org.jenkinsci.plugins.github_branch_source.GitHubSCMSource":
githubSource := &GithubSource{}
if credential := source.SelectElement("credentialsId"); credential != nil {
githubSource.CredentialId = credential.Text()
}
if repoOwner := source.SelectElement("repoOwner"); repoOwner != nil {
githubSource.Owner = repoOwner.Text()
}
if repository := source.SelectElement("repository"); repository != nil {
githubSource.Repo = repository.Text()
}
if apiUri := source.SelectElement("apiUri"); apiUri != nil {
githubSource.ApiUri = apiUri.Text()
}
traits := source.SelectElement("traits")
if branchDiscoverTrait := traits.SelectElement(
"org.jenkinsci.plugins.github__branch__source.BranchDiscoveryTrait"); branchDiscoverTrait != nil {
strategyId, err := strconv.Atoi(branchDiscoverTrait.SelectElement("strategyId").Text())
if err != nil {
return nil, err
}
githubSource.DiscoverBranches = strategyId
}
if originPRDiscoverTrait := traits.SelectElement(
"org.jenkinsci.plugins.github__branch__source.OriginPullRequestDiscoveryTrait"); originPRDiscoverTrait != nil {
strategyId, err := strconv.Atoi(originPRDiscoverTrait.SelectElement("strategyId").Text())
if err != nil {
return nil, err
}
githubSource.DiscoverPRFromOrigin = strategyId
}
if forkPRDiscoverTrait := traits.SelectElement(
"org.jenkinsci.plugins.github__branch__source.ForkPullRequestDiscoveryTrait"); forkPRDiscoverTrait != nil {
strategyId, err := strconv.Atoi(forkPRDiscoverTrait.SelectElement("strategyId").Text())
if err != nil {
return nil, err
}
trustClass := forkPRDiscoverTrait.SelectElement("trust").SelectAttr("class").Value
trust := strings.Split(trustClass, "$")
switch trust[1] {
case "TrustContributors":
githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{
Strategy: strategyId,
Trust: 1,
}
case "TrustEveryone":
githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{
Strategy: strategyId,
Trust: 2,
}
case "TrustPermission":
githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{
Strategy: strategyId,
Trust: 3,
}
case "TrustNobody":
githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{
Strategy: strategyId,
Trust: 4,
}
}
if cloneTrait := traits.SelectElement(
"jenkins.plugins.git.traits.CloneOptionTrait"); cloneTrait != nil {
if cloneExtension := cloneTrait.SelectElement(
"extension"); cloneExtension != nil {
githubSource.CloneOption = &GitCloneOption{}
if value, err := strconv.ParseBool(cloneExtension.SelectElement("shallow").Text()); err == nil {
githubSource.CloneOption.Shallow = value
}
if value, err := strconv.ParseInt(cloneExtension.SelectElement("timeout").Text(), 10, 32); err == nil {
githubSource.CloneOption.Timeout = int(value)
}
if value, err := strconv.ParseInt(cloneExtension.SelectElement("depth").Text(), 10, 32); err == nil {
githubSource.CloneOption.Depth = int(value)
}
}
}
if regexTrait := traits.SelectElement(
"jenkins.scm.impl.trait.RegexSCMHeadFilterTrait"); regexTrait != nil {
if regex := regexTrait.SelectElement("regex"); regex != nil {
githubSource.RegexFilter = regex.Text()
}
}
}
pipeline.GitHubSource = githubSource
pipeline.SourceType = "github"
case "jenkins.plugins.git.GitSCMSource":
gitSource := &GitSource{}
if credential := source.SelectElement("credentialsId"); credential != nil {
gitSource.CredentialId = credential.Text()
}
if remote := source.SelectElement("remote"); remote != nil {
gitSource.Url = remote.Text()
}
traits := source.SelectElement("traits")
if branchDiscoverTrait := traits.SelectElement(
"jenkins.plugins.git.traits.BranchDiscoveryTrait"); branchDiscoverTrait != nil {
gitSource.DiscoverBranches = true
}
if cloneTrait := traits.SelectElement(
"jenkins.plugins.git.traits.CloneOptionTrait"); cloneTrait != nil {
if cloneExtension := cloneTrait.SelectElement(
"extension"); cloneExtension != nil {
gitSource.CloneOption = &GitCloneOption{}
if value, err := strconv.ParseBool(cloneExtension.SelectElement("shallow").Text()); err == nil {
gitSource.CloneOption.Shallow = value
}
if value, err := strconv.ParseInt(cloneExtension.SelectElement("timeout").Text(), 10, 32); err == nil {
gitSource.CloneOption.Timeout = int(value)
}
if value, err := strconv.ParseInt(cloneExtension.SelectElement("depth").Text(), 10, 32); err == nil {
gitSource.CloneOption.Depth = int(value)
}
}
}
if regexTrait := traits.SelectElement(
"jenkins.scm.impl.trait.RegexSCMHeadFilterTrait"); regexTrait != nil {
if regex := regexTrait.SelectElement("regex"); regex != nil {
gitSource.RegexFilter = regex.Text()
}
}
pipeline.SourceType = "git"
pipeline.GitSource = gitSource
case "jenkins.scm.impl.SingleSCMSource":
singleSvnSource := &SingleSvnSource{}
if scm := source.SelectElement("scm"); scm != nil {
if locations := scm.SelectElement("locations"); locations != nil {
if moduleLocations := locations.SelectElement("hudson.scm.SubversionSCM_-ModuleLocation"); moduleLocations != nil {
if remote := moduleLocations.SelectElement("remote"); remote != nil {
singleSvnSource.Remote = remote.Text()
}
if credentialId := moduleLocations.SelectElement("credentialsId"); credentialId != nil {
singleSvnSource.CredentialId = credentialId.Text()
}
}
}
}
pipeline.SourceType = "single_svn"
pipeline.SingleSvnSource = singleSvnSource
case "jenkins.scm.impl.subversion.SubversionSCMSource":
svnSource := &SvnSource{}
if remote := source.SelectElement("remoteBase"); remote != nil {
svnSource.Remote = remote.Text()
}
if credentialsId := source.SelectElement("credentialsId"); credentialsId != nil {
svnSource.CredentialId = credentialsId.Text()
}
if includes := source.SelectElement("includes"); includes != nil {
svnSource.Includes = includes.Text()
}
if excludes := source.SelectElement("excludes"); excludes != nil {
svnSource.Excludes = excludes.Text()
}
pipeline.SourceType = "svn"
pipeline.SvnSource = svnSource
}
}
}
}
pipeline.ScriptPath = project.SelectElement("factory").SelectElement("scriptPath").Text()
return pipeline, nil
}
func toCrontab(millis int64) string {
if millis*time.Millisecond.Nanoseconds() <= 5*time.Minute.Nanoseconds() {
return "* * * * *"
}
if millis*time.Millisecond.Nanoseconds() <= 30*time.Minute.Nanoseconds() {
return "H/5 * * * *"
}
if millis*time.Millisecond.Nanoseconds() <= 1*time.Hour.Nanoseconds() {
return "H/15 * * * *"
}
if millis*time.Millisecond.Nanoseconds() <= 8*time.Hour.Nanoseconds() {
return "H/30 * * * *"
}
if millis*time.Millisecond.Nanoseconds() <= 24*time.Hour.Nanoseconds() {
return "H H/4 * * *"
}
if millis*time.Millisecond.Nanoseconds() <= 48*time.Hour.Nanoseconds() {
return "H H/12 * * *"
}
return "H H * * *"
}
func getBuildSonarResults(build *gojenkins.Build) ([]*SonarStatus, error) {
sonarClient := sonarqube.Client()
actions := build.GetActions()
sonarStatuses := make([]*SonarStatus, 0)
for _, action := range actions {
if action.ClassName == SonarAnalysisActionClass {
sonarStatus := &SonarStatus{}
taskOptions := &sonargo.CeTaskOption{
Id: action.SonarTaskId,
}
ceTask, _, err := sonarClient.Ce.Task(taskOptions)
if err != nil {
glog.Errorf("get sonar task error [%+v]", err)
continue
}
sonarStatus.Task = ceTask
measuresComponentOption := &sonargo.MeasuresComponentOption{
Component: ceTask.Task.ComponentKey,
AdditionalFields: SonarAdditionalFields,
MetricKeys: SonarMetricKeys,
}
measures, _, err := sonarClient.Measures.Component(measuresComponentOption)
if err != nil {
glog.Errorf("get sonar task error [%+v]", err)
continue
}
sonarStatus.Measures = measures
issuesSearchOption := &sonargo.IssuesSearchOption{
AdditionalFields: "_all",
ComponentKeys: ceTask.Task.ComponentKey,
Resolved: "false",
Ps: "10",
S: "FILE_LINE",
Facets: "severities,types",
}
issuesSearch, _, err := sonarClient.Issues.Search(issuesSearchOption)
sonarStatus.Issues = issuesSearch
jenkinsAction := action
sonarStatus.JenkinsAction = &jenkinsAction
sonarStatuses = append(sonarStatuses, sonarStatus)
}
}
return sonarStatuses, nil
}

View File

@@ -0,0 +1,270 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"fmt"
"github.com/emicklei/go-restful"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/gojenkins/utils"
"kubesphere.io/kubesphere/pkg/simple/client/admin_jenkins"
"net/http"
)
func CreateProjectPipeline(projectId string, pipeline *ProjectPipeline) (string, error) {
jenkinsClient := admin_jenkins.Client()
switch pipeline.Type {
case NoScmPipelineType:
config, err := createPipelineConfigXml(pipeline.Pipeline)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
job, err := jenkinsClient.GetJob(pipeline.Pipeline.Name, projectId)
if job != nil {
err := fmt.Errorf("job name [%s] has been used", job.GetName())
glog.Warning(err.Error())
return "", restful.NewError(http.StatusConflict, err.Error())
}
if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
_, err = jenkinsClient.CreateJobInFolder(config, pipeline.Pipeline.Name, projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return pipeline.Pipeline.Name, nil
case MultiBranchPipelineType:
config, err := createMultiBranchPipelineConfigXml(projectId, pipeline.MultiBranchPipeline)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
job, err := jenkinsClient.GetJob(pipeline.MultiBranchPipeline.Name, projectId)
if job != nil {
err := fmt.Errorf("job name [%s] has been used", job.GetName())
glog.Warning(err.Error())
return "", restful.NewError(http.StatusConflict, err.Error())
}
if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
_, err = jenkinsClient.CreateJobInFolder(config, pipeline.MultiBranchPipeline.Name, projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return pipeline.MultiBranchPipeline.Name, nil
default:
err := fmt.Errorf("error unsupport job type")
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
}
func DeleteProjectPipeline(projectId string, pipelineId string) (string, error) {
jenkinsClient := admin_jenkins.Client()
_, err := jenkinsClient.DeleteJob(pipelineId, projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return pipelineId, nil
}
func UpdateProjectPipeline(projectId, pipelineId string, pipeline *ProjectPipeline) (string, error) {
jenkinsClient := admin_jenkins.Client()
switch pipeline.Type {
case NoScmPipelineType:
config, err := createPipelineConfigXml(pipeline.Pipeline)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
job, err := jenkinsClient.GetJob(pipelineId, projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = job.UpdateConfig(config)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return pipeline.Pipeline.Name, nil
case MultiBranchPipelineType:
config, err := createMultiBranchPipelineConfigXml(projectId, pipeline.MultiBranchPipeline)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusInternalServerError, err.Error())
}
job, err := jenkinsClient.GetJob(pipelineId, projectId)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = job.UpdateConfig(config)
if err != nil {
glog.Errorf("%+v", err)
return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return pipeline.MultiBranchPipeline.Name, nil
default:
err := fmt.Errorf("error unsupport job type")
glog.Errorf("%+v", err)
return "", restful.NewError(http.StatusBadRequest, err.Error())
}
}
func GetProjectPipeline(projectId, pipelineId string) (*ProjectPipeline, error) {
jenkinsClient := admin_jenkins.Client()
job, err := jenkinsClient.GetJob(pipelineId, projectId)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
switch job.Raw.Class {
case "org.jenkinsci.plugins.workflow.job.WorkflowJob":
config, err := job.GetConfig()
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
pipeline, err := parsePipelineConfigXml(config)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
pipeline.Name = pipelineId
return &ProjectPipeline{
Type: NoScmPipelineType,
Pipeline: pipeline,
}, nil
case "org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject":
config, err := job.GetConfig()
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
pipeline, err := parseMultiBranchPipelineConfigXml(config)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
pipeline.Name = pipelineId
return &ProjectPipeline{
Type: MultiBranchPipelineType,
MultiBranchPipeline: pipeline,
}, nil
default:
err := fmt.Errorf("error unsupport job type")
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusBadRequest, err.Error())
}
}
func GetPipelineSonar(projectId, pipelineId string) ([]*SonarStatus, error) {
jenkinsClient := admin_jenkins.Client()
job, err := jenkinsClient.GetJob(pipelineId, projectId)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
build, err := job.GetLastBuild()
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
sonarStatus, err := getBuildSonarResults(build)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusBadRequest, err.Error())
}
if len(sonarStatus) == 0 {
build, err := job.GetLastCompletedBuild()
if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
sonarStatus, err = getBuildSonarResults(build)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusBadRequest, err.Error())
}
}
return sonarStatus, nil
}
func GetMultiBranchPipelineSonar(projectId, pipelineId, branchId string) ([]*SonarStatus, error) {
jenkinsClient := admin_jenkins.Client()
job, err := jenkinsClient.GetJob(branchId, projectId, pipelineId)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
build, err := job.GetLastBuild()
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
sonarStatus, err := getBuildSonarResults(build)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusBadRequest, err.Error())
}
if len(sonarStatus) == 0 {
build, err := job.GetLastCompletedBuild()
if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound {
glog.Errorf("%+v", err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
sonarStatus, err = getBuildSonarResults(build)
if err != nil {
glog.Errorf("%+v", err)
return nil, restful.NewError(http.StatusBadRequest, err.Error())
}
}
return sonarStatus, nil
}

View File

@@ -0,0 +1,523 @@
/*
Copyright 2019 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package devops
import (
"reflect"
"testing"
)
func Test_NoScmPipelineConfig(t *testing.T) {
inputs := []*NoScmPipeline{
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
},
&NoScmPipeline{
Name: "",
Description: "",
Jenkinsfile: "node{echo 'hello'}",
},
&NoScmPipeline{
Name: "",
Description: "",
Jenkinsfile: "node{echo 'hello'}",
DisableConcurrent: true,
},
}
for _, input := range inputs {
outputString, err := createPipelineConfigXml(input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parsePipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_NoScmPipelineConfig_Discarder(t *testing.T) {
inputs := []*NoScmPipeline{
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
Discarder: &DiscarderProperty{
"3", "5",
},
},
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
Discarder: &DiscarderProperty{
"3", "",
},
},
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
Discarder: &DiscarderProperty{
"", "21321",
},
},
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
Discarder: &DiscarderProperty{
"", "",
},
},
}
for _, input := range inputs {
outputString, err := createPipelineConfigXml(input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parsePipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_NoScmPipelineConfig_Param(t *testing.T) {
inputs := []*NoScmPipeline{
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
Parameters: []*Parameter{
&Parameter{
Name: "d",
DefaultValue: "a\nb",
Type: "choice",
Description: "fortest",
},
},
},
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
Parameters: []*Parameter{
&Parameter{
Name: "a",
DefaultValue: "abc",
Type: "string",
Description: "fortest",
},
&Parameter{
Name: "b",
DefaultValue: "false",
Type: "boolean",
Description: "fortest",
},
&Parameter{
Name: "c",
DefaultValue: "password \n aaa",
Type: "text",
Description: "fortest",
},
&Parameter{
Name: "d",
DefaultValue: "a\nb",
Type: "choice",
Description: "fortest",
},
},
},
}
for _, input := range inputs {
outputString, err := createPipelineConfigXml(input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parsePipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_NoScmPipelineConfig_Trigger(t *testing.T) {
inputs := []*NoScmPipeline{
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
TimerTrigger: &TimerTrigger{
Cron: "1 1 1 * * *",
},
},
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
RemoteTrigger: &RemoteTrigger{
Token: "abc",
},
},
&NoScmPipeline{
Name: "",
Description: "for test",
Jenkinsfile: "node{echo 'hello'}",
TimerTrigger: &TimerTrigger{
Cron: "1 1 1 * * *",
},
RemoteTrigger: &RemoteTrigger{
Token: "abc",
},
},
}
for _, input := range inputs {
outputString, err := createPipelineConfigXml(input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parsePipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_MultiBranchPipelineConfig(t *testing.T) {
inputs := []*MultiBranchPipeline{
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "git",
GitSource: &GitSource{},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "github",
GitHubSource: &GithubSource{},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "single_svn",
SingleSvnSource: &SingleSvnSource{},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "svn",
SvnSource: &SvnSource{},
},
}
for _, input := range inputs {
outputString, err := createMultiBranchPipelineConfigXml("", input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parseMultiBranchPipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_MultiBranchPipelineConfig_Discarder(t *testing.T) {
inputs := []*MultiBranchPipeline{
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "git",
Discarder: &DiscarderProperty{
DaysToKeep: "1",
NumToKeep: "2",
},
GitSource: &GitSource{},
},
}
for _, input := range inputs {
outputString, err := createMultiBranchPipelineConfigXml("", input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parseMultiBranchPipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_MultiBranchPipelineConfig_TimerTrigger(t *testing.T) {
inputs := []*MultiBranchPipeline{
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "git",
TimerTrigger: &TimerTrigger{
Interval: "12345566",
},
GitSource: &GitSource{},
},
}
for _, input := range inputs {
outputString, err := createMultiBranchPipelineConfigXml("", input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parseMultiBranchPipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_MultiBranchPipelineConfig_Source(t *testing.T) {
inputs := []*MultiBranchPipeline{
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "git",
TimerTrigger: &TimerTrigger{
Interval: "12345566",
},
GitSource: &GitSource{
Url: "https://github.com/kubesphere/devops",
CredentialId: "git",
DiscoverBranches: true,
},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "github",
TimerTrigger: &TimerTrigger{
Interval: "12345566",
},
GitHubSource: &GithubSource{
Owner: "kubesphere",
Repo: "devops",
CredentialId: "github",
ApiUri: "https://api.github.com",
DiscoverBranches: 1,
DiscoverPRFromOrigin: 2,
DiscoverPRFromForks: &GithubDiscoverPRFromForks{
Strategy: 1,
Trust: 1,
},
},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "svn",
TimerTrigger: &TimerTrigger{
Interval: "12345566",
},
SvnSource: &SvnSource{
Remote: "https://api.svn.com/bcd",
CredentialId: "svn",
Excludes: "truck",
Includes: "tag/*",
},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "single_svn",
TimerTrigger: &TimerTrigger{
Interval: "12345566",
},
SingleSvnSource: &SingleSvnSource{
Remote: "https://api.svn.com/bcd",
CredentialId: "svn",
},
},
}
for _, input := range inputs {
outputString, err := createMultiBranchPipelineConfigXml("", input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parseMultiBranchPipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_MultiBranchPipelineCloneConfig(t *testing.T) {
inputs := []*MultiBranchPipeline{
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "git",
GitSource: &GitSource{
Url: "https://github.com/kubesphere/devops",
CredentialId: "git",
DiscoverBranches: true,
CloneOption: &GitCloneOption{
Shallow: false,
Depth: 3,
Timeout: 20,
},
},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "github",
GitHubSource: &GithubSource{
Owner: "kubesphere",
Repo: "devops",
CredentialId: "github",
ApiUri: "https://api.github.com",
DiscoverBranches: 1,
DiscoverPRFromOrigin: 2,
DiscoverPRFromForks: &GithubDiscoverPRFromForks{
Strategy: 1,
Trust: 1,
},
CloneOption: &GitCloneOption{
Shallow: false,
Depth: 3,
Timeout: 20,
},
},
},
}
for _, input := range inputs {
outputString, err := createMultiBranchPipelineConfigXml("", input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parseMultiBranchPipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}
func Test_MultiBranchPipelineRegexFilter(t *testing.T) {
inputs := []*MultiBranchPipeline{
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "git",
GitSource: &GitSource{
Url: "https://github.com/kubesphere/devops",
CredentialId: "git",
DiscoverBranches: true,
RegexFilter: ".*",
},
},
&MultiBranchPipeline{
Name: "",
Description: "for test",
ScriptPath: "Jenkinsfile",
SourceType: "github",
GitHubSource: &GithubSource{
Owner: "kubesphere",
Repo: "devops",
CredentialId: "github",
ApiUri: "https://api.github.com",
DiscoverBranches: 1,
DiscoverPRFromOrigin: 2,
DiscoverPRFromForks: &GithubDiscoverPRFromForks{
Strategy: 1,
Trust: 1,
},
RegexFilter: ".*",
},
},
}
for _, input := range inputs {
outputString, err := createMultiBranchPipelineConfigXml("", input)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
output, err := parseMultiBranchPipelineConfigXml(outputString)
if err != nil {
t.Fatalf("should not get error %+v", err)
}
if !reflect.DeepEqual(input, output) {
t.Fatalf("input [%+v] output [%+v] should equal ", input, output)
}
}
}

View File

@@ -19,6 +19,7 @@ package tenant
import (
"fmt"
"github.com/emicklei/go-restful"
"github.com/gocraft/dbr"
"github.com/golang/glog"
"kubesphere.io/kubesphere/pkg/db"
@@ -29,248 +30,15 @@ import (
"kubesphere.io/kubesphere/pkg/params"
"kubesphere.io/kubesphere/pkg/simple/client/admin_jenkins"
"kubesphere.io/kubesphere/pkg/simple/client/devops_mysql"
"kubesphere.io/kubesphere/pkg/utils/reflectutils"
"net/http"
"sync"
)
const (
ProjectOwner = "owner"
ProjectMaintainer = "maintainer"
ProjectDeveloper = "developer"
ProjectReporter = "reporter"
)
var AllRoleSlice = []string{ProjectDeveloper, ProjectReporter, ProjectMaintainer, ProjectOwner}
var JenkinsOwnerProjectPermissionIds = &gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
}
var JenkinsProjectPermissionMap = map[string]gojenkins.ProjectPermissionIds{
ProjectOwner: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectMaintainer: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: false,
ItemCreate: true,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectDeveloper: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: false,
},
ProjectReporter: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: false,
ItemCancel: false,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: false,
RunDelete: false,
RunReplay: false,
RunUpdate: false,
SCMTag: false,
},
}
var JenkinsPipelinePermissionMap = map[string]gojenkins.ProjectPermissionIds{
ProjectOwner: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectMaintainer: gojenkins.ProjectPermissionIds{
CredentialCreate: true,
CredentialDelete: true,
CredentialManageDomains: true,
CredentialUpdate: true,
CredentialView: true,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: true,
ItemCreate: true,
ItemDelete: true,
ItemDiscover: true,
ItemMove: true,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: true,
},
ProjectDeveloper: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: true,
ItemCancel: true,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: true,
RunDelete: true,
RunReplay: true,
RunUpdate: true,
SCMTag: false,
},
ProjectReporter: gojenkins.ProjectPermissionIds{
CredentialCreate: false,
CredentialDelete: false,
CredentialManageDomains: false,
CredentialUpdate: false,
CredentialView: false,
ItemBuild: false,
ItemCancel: false,
ItemConfigure: false,
ItemCreate: false,
ItemDelete: false,
ItemDiscover: true,
ItemMove: false,
ItemRead: true,
ItemWorkspace: false,
RunDelete: false,
RunReplay: false,
RunUpdate: false,
SCMTag: false,
},
}
func GetProjectRoleName(projectId, role string) string {
return fmt.Sprintf("%s-%s-project", projectId, role)
}
func GetPipelineRoleName(projectId, role string) string {
return fmt.Sprintf("%s-%s-pipeline", projectId, role)
}
func GetProjectRolePattern(projectId string) string {
return fmt.Sprintf("^%s$", projectId)
}
func GetPipelineRolePattern(projectId string) string {
return fmt.Sprintf("^%s/.*", projectId)
}
type DevOpsProjectRoleResponse struct {
ProjectRole *gojenkins.ProjectRole
Err error
}
func CheckProjectUserInRole(username, projectId string, roles []string) error {
if username == devops.KS_ADMIN {
return nil
}
dbconn := devops_mysql.OpenDatabase()
membership := &devops.DevOpsProjectMembership{}
err := dbconn.Select(devops.DevOpsProjectMembershipColumns...).
From(devops.DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(devops.DevOpsProjectMembershipUsernameColumn, username),
db.Eq(devops.DevOpsProjectMembershipProjectIdColumn, projectId))).LoadOne(membership)
if err != nil {
return err
}
if !reflectutils.In(membership.Role, roles) {
return fmt.Errorf("user [%s] in project [%s] role is not in %s", username, projectId, roles)
}
return nil
}
func ListDevopsProjects(workspace, username string, conditions *params.Conditions, orderBy string, reverse bool, limit int, offset int) (*models.PageableResponse, error) {
dbconn := devops_mysql.OpenDatabase()
@@ -321,12 +89,12 @@ func ListDevopsProjects(workspace, username string, conditions *params.Condition
_, err := query.Load(&projects)
if err != nil {
glog.Errorf("%+v", err)
return nil, err
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
count, err := query.Count()
if err != nil {
glog.Errorf("%+v", err)
return nil, err
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
result := make([]interface{}, 0)
@@ -337,11 +105,11 @@ func ListDevopsProjects(workspace, username string, conditions *params.Condition
return &models.PageableResponse{Items: result, TotalCount: int(count)}, nil
}
func DeleteDevOpsProject(projectId, username string) (error, int) {
err := CheckProjectUserInRole(username, projectId, []string{ProjectOwner})
func DeleteDevOpsProject(projectId, username string) error {
err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner})
if err != nil {
glog.Errorf("%+v", err)
return err, http.StatusForbidden
return restful.NewError(http.StatusForbidden, err.Error())
}
gojenkins := admin_jenkins.Client()
devopsdb := devops_mysql.OpenDatabase()
@@ -349,31 +117,31 @@ func DeleteDevOpsProject(projectId, username string) (error, int) {
if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound {
glog.Errorf("%+v", err)
return err, utils.GetJenkinsStatusCode(err)
return restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
roleNames := make([]string, 0)
for role := range JenkinsProjectPermissionMap {
roleNames = append(roleNames, GetProjectRoleName(projectId, role))
roleNames = append(roleNames, GetPipelineRoleName(projectId, role))
for role := range devops.JenkinsProjectPermissionMap {
roleNames = append(roleNames, devops.GetProjectRoleName(projectId, role))
roleNames = append(roleNames, devops.GetPipelineRoleName(projectId, role))
}
err = gojenkins.DeleteProjectRoles(roleNames...)
if err != nil {
glog.Errorf("%+v", err)
return err, utils.GetJenkinsStatusCode(err)
return restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
_, err = devopsdb.DeleteFrom(devops.DevOpsProjectMembershipTableName).
Where(db.Eq(devops.DevOpsProjectMembershipProjectIdColumn, projectId)).Exec()
if err != nil {
glog.Errorf("%+v", err)
return err, http.StatusInternalServerError
return restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
_, err = devopsdb.Update(devops.DevOpsProjectTableName).
Set(devops.StatusColumn, devops.StatusDeleted).
Where(db.Eq(devops.DevOpsProjectIdColumn, projectId)).Exec()
if err != nil {
glog.Errorf("%+v", err)
return err, http.StatusInternalServerError
return restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
project := &devops.DevOpsProject{}
err = devopsdb.Select(devops.DevOpsProjectColumns...).
@@ -382,12 +150,12 @@ func DeleteDevOpsProject(projectId, username string) (error, int) {
LoadOne(project)
if err != nil {
glog.Errorf("%+v", err)
return err, http.StatusInternalServerError
return restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
return nil, http.StatusOK
return nil
}
func CreateDevopsProject(username string, workspace string, req *devops.DevOpsProject) (*devops.DevOpsProject, error, int) {
func CreateDevopsProject(username string, workspace string, req *devops.DevOpsProject) (*devops.DevOpsProject, error) {
jenkinsClient := admin_jenkins.Client()
devopsdb := devops_mysql.OpenDatabase()
@@ -395,25 +163,25 @@ func CreateDevopsProject(username string, workspace string, req *devops.DevOpsPr
_, err := jenkinsClient.CreateFolder(project.ProjectId, project.Description)
if err != nil {
glog.Errorf("%+v", err)
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
var addRoleCh = make(chan *DevOpsProjectRoleResponse, 8)
var addRoleWg sync.WaitGroup
for role, permission := range JenkinsProjectPermissionMap {
for role, permission := range devops.JenkinsProjectPermissionMap {
addRoleWg.Add(1)
go func(role string, permission gojenkins.ProjectPermissionIds) {
_, err := jenkinsClient.AddProjectRole(GetProjectRoleName(project.ProjectId, role),
GetProjectRolePattern(project.ProjectId), permission, true)
_, err := jenkinsClient.AddProjectRole(devops.GetProjectRoleName(project.ProjectId, role),
devops.GetProjectRolePattern(project.ProjectId), permission, true)
addRoleCh <- &DevOpsProjectRoleResponse{nil, err}
addRoleWg.Done()
}(role, permission)
}
for role, permission := range JenkinsPipelinePermissionMap {
for role, permission := range devops.JenkinsPipelinePermissionMap {
addRoleWg.Add(1)
go func(role string, permission gojenkins.ProjectPermissionIds) {
_, err := jenkinsClient.AddProjectRole(GetPipelineRoleName(project.ProjectId, role),
GetPipelineRolePattern(project.ProjectId), permission, true)
_, err := jenkinsClient.AddProjectRole(devops.GetPipelineRoleName(project.ProjectId, role),
devops.GetPipelineRolePattern(project.ProjectId), permission, true)
addRoleCh <- &DevOpsProjectRoleResponse{nil, err}
addRoleWg.Done()
}(role, permission)
@@ -423,14 +191,14 @@ func CreateDevopsProject(username string, workspace string, req *devops.DevOpsPr
for addRoleResponse := range addRoleCh {
if addRoleResponse.Err != nil {
glog.Errorf("%+v", addRoleResponse.Err)
return nil, addRoleResponse.Err, utils.GetJenkinsStatusCode(addRoleResponse.Err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(addRoleResponse.Err), addRoleResponse.Err.Error())
}
}
globalRole, err := jenkinsClient.GetGlobalRole(devops.JenkinsAllUserRoleName)
if err != nil {
glog.Errorf("%+v", err)
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
if globalRole == nil {
_, err := jenkinsClient.AddGlobalRole(devops.JenkinsAllUserRoleName, gojenkins.GlobalPermissionIds{
@@ -438,74 +206,60 @@ func CreateDevopsProject(username string, workspace string, req *devops.DevOpsPr
}, true)
if err != nil {
glog.Error("failed to create jenkins global role")
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
}
err = globalRole.AssignRole(username)
if err != nil {
glog.Errorf("%+v", err)
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
projectRole, err := jenkinsClient.GetProjectRole(GetProjectRoleName(project.ProjectId, ProjectOwner))
projectRole, err := jenkinsClient.GetProjectRole(devops.GetProjectRoleName(project.ProjectId, devops.ProjectOwner))
if err != nil {
glog.Errorf("%+v", err)
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = projectRole.AssignRole(username)
if err != nil {
glog.Errorf("%+v", err)
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
pipelineRole, err := jenkinsClient.GetProjectRole(GetPipelineRoleName(project.ProjectId, ProjectOwner))
pipelineRole, err := jenkinsClient.GetProjectRole(devops.GetPipelineRoleName(project.ProjectId, devops.ProjectOwner))
if err != nil {
glog.Errorf("%+v", err)
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
err = pipelineRole.AssignRole(username)
if err != nil {
glog.Errorf("%+v", err)
return nil, err, utils.GetJenkinsStatusCode(err)
return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error())
}
_, err = devopsdb.InsertInto(devops.DevOpsProjectTableName).
Columns(devops.DevOpsProjectColumns...).Record(project).Exec()
if err != nil {
glog.Errorf("%+v", err)
return nil, err, http.StatusInternalServerError
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
projectMembership := devops.NewDevOpsProjectMemberShip(username, project.ProjectId, ProjectOwner, username)
projectMembership := devops.NewDevOpsProjectMemberShip(username, project.ProjectId, devops.ProjectOwner, username)
_, err = devopsdb.InsertInto(devops.DevOpsProjectMembershipTableName).
Columns(devops.DevOpsProjectMembershipColumns...).Record(projectMembership).Exec()
if err != nil {
glog.Errorf("%+v", err)
return nil, err, http.StatusInternalServerError
return nil, restful.NewError(http.StatusInternalServerError, err.Error())
}
return project, nil, http.StatusOK
return project, nil
}
func GetUserDevopsSimpleRules(username, projectId string) ([]models.SimpleRule, error, int) {
err := CheckProjectUserInRole(username, projectId, AllRoleSlice)
func GetUserDevopsSimpleRules(username, projectId string) ([]models.SimpleRule, error) {
role, err := devops.GetProjectUserRole(username, projectId)
if err != nil {
glog.Errorf("%+v", err)
return nil, err, http.StatusForbidden
return nil, restful.NewError(http.StatusForbidden, err.Error())
}
dbconn := devops_mysql.OpenDatabase()
memberships := &devops.DevOpsProjectMembership{}
err = dbconn.Select(devops.DevOpsProjectMembershipColumns...).
From(devops.DevOpsProjectMembershipTableName).
Where(db.And(
db.Eq(devops.DevOpsProjectMembershipProjectIdColumn, projectId),
db.Eq(devops.DevOpsProjectMembershipUsernameColumn, username))).
LoadOne(&memberships)
if err != nil {
glog.Errorf("%+v", err)
return nil, err, http.StatusInternalServerError
}
return GetDevopsRoleSimpleRules(memberships.Role), nil, http.StatusOK
return GetDevopsRoleSimpleRules(role), nil
}
func GetDevopsRoleSimpleRules(role string) []models.SimpleRule {

View File

@@ -1,3 +1,16 @@
/*
Copyright 2018 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package admin_jenkins
import (

12
vendor/github.com/PuerkitoBio/goquery/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,12 @@
Copyright (c) 2012-2016, Martin Angers & Contributors
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the author nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

124
vendor/github.com/PuerkitoBio/goquery/array.go generated vendored Normal file
View File

@@ -0,0 +1,124 @@
package goquery
import (
"golang.org/x/net/html"
)
const (
maxUint = ^uint(0)
maxInt = int(maxUint >> 1)
// ToEnd is a special index value that can be used as end index in a call
// to Slice so that all elements are selected until the end of the Selection.
// It is equivalent to passing (*Selection).Length().
ToEnd = maxInt
)
// First reduces the set of matched elements to the first in the set.
// It returns a new Selection object, and an empty Selection object if the
// the selection is empty.
func (s *Selection) First() *Selection {
return s.Eq(0)
}
// Last reduces the set of matched elements to the last in the set.
// It returns a new Selection object, and an empty Selection object if
// the selection is empty.
func (s *Selection) Last() *Selection {
return s.Eq(-1)
}
// Eq reduces the set of matched elements to the one at the specified index.
// If a negative index is given, it counts backwards starting at the end of the
// set. It returns a new Selection object, and an empty Selection object if the
// index is invalid.
func (s *Selection) Eq(index int) *Selection {
if index < 0 {
index += len(s.Nodes)
}
if index >= len(s.Nodes) || index < 0 {
return newEmptySelection(s.document)
}
return s.Slice(index, index+1)
}
// Slice reduces the set of matched elements to a subset specified by a range
// of indices. The start index is 0-based and indicates the index of the first
// element to select. The end index is 0-based and indicates the index at which
// the elements stop being selected (the end index is not selected).
//
// The indices may be negative, in which case they represent an offset from the
// end of the selection.
//
// The special value ToEnd may be specified as end index, in which case all elements
// until the end are selected. This works both for a positive and negative start
// index.
func (s *Selection) Slice(start, end int) *Selection {
if start < 0 {
start += len(s.Nodes)
}
if end == ToEnd {
end = len(s.Nodes)
} else if end < 0 {
end += len(s.Nodes)
}
return pushStack(s, s.Nodes[start:end])
}
// Get retrieves the underlying node at the specified index.
// Get without parameter is not implemented, since the node array is available
// on the Selection object.
func (s *Selection) Get(index int) *html.Node {
if index < 0 {
index += len(s.Nodes) // Negative index gets from the end
}
return s.Nodes[index]
}
// Index returns the position of the first element within the Selection object
// relative to its sibling elements.
func (s *Selection) Index() int {
if len(s.Nodes) > 0 {
return newSingleSelection(s.Nodes[0], s.document).PrevAll().Length()
}
return -1
}
// IndexSelector returns the position of the first element within the
// Selection object relative to the elements matched by the selector, or -1 if
// not found.
func (s *Selection) IndexSelector(selector string) int {
if len(s.Nodes) > 0 {
sel := s.document.Find(selector)
return indexInSlice(sel.Nodes, s.Nodes[0])
}
return -1
}
// IndexMatcher returns the position of the first element within the
// Selection object relative to the elements matched by the matcher, or -1 if
// not found.
func (s *Selection) IndexMatcher(m Matcher) int {
if len(s.Nodes) > 0 {
sel := s.document.FindMatcher(m)
return indexInSlice(sel.Nodes, s.Nodes[0])
}
return -1
}
// IndexOfNode returns the position of the specified node within the Selection
// object, or -1 if not found.
func (s *Selection) IndexOfNode(node *html.Node) int {
return indexInSlice(s.Nodes, node)
}
// IndexOfSelection returns the position of the first node in the specified
// Selection object within this Selection object, or -1 if not found.
func (s *Selection) IndexOfSelection(sel *Selection) int {
if sel != nil && len(sel.Nodes) > 0 {
return indexInSlice(s.Nodes, sel.Nodes[0])
}
return -1
}

123
vendor/github.com/PuerkitoBio/goquery/doc.go generated vendored Normal file
View File

@@ -0,0 +1,123 @@
// Copyright (c) 2012-2016, Martin Angers & Contributors
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation and/or
// other materials provided with the distribution.
// * Neither the name of the author nor the names of its contributors may be used to
// endorse or promote products derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS
// OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
// AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
// WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/*
Package goquery implements features similar to jQuery, including the chainable
syntax, to manipulate and query an HTML document.
It brings a syntax and a set of features similar to jQuery to the Go language.
It is based on Go's net/html package and the CSS Selector library cascadia.
Since the net/html parser returns nodes, and not a full-featured DOM
tree, jQuery's stateful manipulation functions (like height(), css(), detach())
have been left off.
Also, because the net/html parser requires UTF-8 encoding, so does goquery: it is
the caller's responsibility to ensure that the source document provides UTF-8 encoded HTML.
See the repository's wiki for various options on how to do this.
Syntax-wise, it is as close as possible to jQuery, with the same method names when
possible, and that warm and fuzzy chainable interface. jQuery being the
ultra-popular library that it is, writing a similar HTML-manipulating
library was better to follow its API than to start anew (in the same spirit as
Go's fmt package), even though some of its methods are less than intuitive (looking
at you, index()...).
It is hosted on GitHub, along with additional documentation in the README.md
file: https://github.com/puerkitobio/goquery
Please note that because of the net/html dependency, goquery requires Go1.1+.
The various methods are split into files based on the category of behavior.
The three dots (...) indicate that various "overloads" are available.
* array.go : array-like positional manipulation of the selection.
- Eq()
- First()
- Get()
- Index...()
- Last()
- Slice()
* expand.go : methods that expand or augment the selection's set.
- Add...()
- AndSelf()
- Union(), which is an alias for AddSelection()
* filter.go : filtering methods, that reduce the selection's set.
- End()
- Filter...()
- Has...()
- Intersection(), which is an alias of FilterSelection()
- Not...()
* iteration.go : methods to loop over the selection's nodes.
- Each()
- EachWithBreak()
- Map()
* manipulation.go : methods for modifying the document
- After...()
- Append...()
- Before...()
- Clone()
- Empty()
- Prepend...()
- Remove...()
- ReplaceWith...()
- Unwrap()
- Wrap...()
- WrapAll...()
- WrapInner...()
* property.go : methods that inspect and get the node's properties values.
- Attr*(), RemoveAttr(), SetAttr()
- AddClass(), HasClass(), RemoveClass(), ToggleClass()
- Html()
- Length()
- Size(), which is an alias for Length()
- Text()
* query.go : methods that query, or reflect, a node's identity.
- Contains()
- Is...()
* traversal.go : methods to traverse the HTML document tree.
- Children...()
- Contents()
- Find...()
- Next...()
- Parent[s]...()
- Prev...()
- Siblings...()
* type.go : definition of the types exposed by goquery.
- Document
- Selection
- Matcher
* utilities.go : definition of helper functions (and not methods on a *Selection)
that are not part of jQuery, but are useful to goquery.
- NodeName
- OuterHtml
*/
package goquery

70
vendor/github.com/PuerkitoBio/goquery/expand.go generated vendored Normal file
View File

@@ -0,0 +1,70 @@
package goquery
import "golang.org/x/net/html"
// Add adds the selector string's matching nodes to those in the current
// selection and returns a new Selection object.
// The selector string is run in the context of the document of the current
// Selection object.
func (s *Selection) Add(selector string) *Selection {
return s.AddNodes(findWithMatcher([]*html.Node{s.document.rootNode}, compileMatcher(selector))...)
}
// AddMatcher adds the matcher's matching nodes to those in the current
// selection and returns a new Selection object.
// The matcher is run in the context of the document of the current
// Selection object.
func (s *Selection) AddMatcher(m Matcher) *Selection {
return s.AddNodes(findWithMatcher([]*html.Node{s.document.rootNode}, m)...)
}
// AddSelection adds the specified Selection object's nodes to those in the
// current selection and returns a new Selection object.
func (s *Selection) AddSelection(sel *Selection) *Selection {
if sel == nil {
return s.AddNodes()
}
return s.AddNodes(sel.Nodes...)
}
// Union is an alias for AddSelection.
func (s *Selection) Union(sel *Selection) *Selection {
return s.AddSelection(sel)
}
// AddNodes adds the specified nodes to those in the
// current selection and returns a new Selection object.
func (s *Selection) AddNodes(nodes ...*html.Node) *Selection {
return pushStack(s, appendWithoutDuplicates(s.Nodes, nodes, nil))
}
// AndSelf adds the previous set of elements on the stack to the current set.
// It returns a new Selection object containing the current Selection combined
// with the previous one.
// Deprecated: This function has been deprecated and is now an alias for AddBack().
func (s *Selection) AndSelf() *Selection {
return s.AddBack()
}
// AddBack adds the previous set of elements on the stack to the current set.
// It returns a new Selection object containing the current Selection combined
// with the previous one.
func (s *Selection) AddBack() *Selection {
return s.AddSelection(s.prevSel)
}
// AddBackFiltered reduces the previous set of elements on the stack to those that
// match the selector string, and adds them to the current set.
// It returns a new Selection object containing the current Selection combined
// with the filtered previous one
func (s *Selection) AddBackFiltered(selector string) *Selection {
return s.AddSelection(s.prevSel.Filter(selector))
}
// AddBackMatcher reduces the previous set of elements on the stack to those that match
// the mateher, and adds them to the curernt set.
// It returns a new Selection object containing the current Selection combined
// with the filtered previous one
func (s *Selection) AddBackMatcher(m Matcher) *Selection {
return s.AddSelection(s.prevSel.FilterMatcher(m))
}

163
vendor/github.com/PuerkitoBio/goquery/filter.go generated vendored Normal file
View File

@@ -0,0 +1,163 @@
package goquery
import "golang.org/x/net/html"
// Filter reduces the set of matched elements to those that match the selector string.
// It returns a new Selection object for this subset of matching elements.
func (s *Selection) Filter(selector string) *Selection {
return s.FilterMatcher(compileMatcher(selector))
}
// FilterMatcher reduces the set of matched elements to those that match
// the given matcher. It returns a new Selection object for this subset
// of matching elements.
func (s *Selection) FilterMatcher(m Matcher) *Selection {
return pushStack(s, winnow(s, m, true))
}
// Not removes elements from the Selection that match the selector string.
// It returns a new Selection object with the matching elements removed.
func (s *Selection) Not(selector string) *Selection {
return s.NotMatcher(compileMatcher(selector))
}
// NotMatcher removes elements from the Selection that match the given matcher.
// It returns a new Selection object with the matching elements removed.
func (s *Selection) NotMatcher(m Matcher) *Selection {
return pushStack(s, winnow(s, m, false))
}
// FilterFunction reduces the set of matched elements to those that pass the function's test.
// It returns a new Selection object for this subset of elements.
func (s *Selection) FilterFunction(f func(int, *Selection) bool) *Selection {
return pushStack(s, winnowFunction(s, f, true))
}
// NotFunction removes elements from the Selection that pass the function's test.
// It returns a new Selection object with the matching elements removed.
func (s *Selection) NotFunction(f func(int, *Selection) bool) *Selection {
return pushStack(s, winnowFunction(s, f, false))
}
// FilterNodes reduces the set of matched elements to those that match the specified nodes.
// It returns a new Selection object for this subset of elements.
func (s *Selection) FilterNodes(nodes ...*html.Node) *Selection {
return pushStack(s, winnowNodes(s, nodes, true))
}
// NotNodes removes elements from the Selection that match the specified nodes.
// It returns a new Selection object with the matching elements removed.
func (s *Selection) NotNodes(nodes ...*html.Node) *Selection {
return pushStack(s, winnowNodes(s, nodes, false))
}
// FilterSelection reduces the set of matched elements to those that match a
// node in the specified Selection object.
// It returns a new Selection object for this subset of elements.
func (s *Selection) FilterSelection(sel *Selection) *Selection {
if sel == nil {
return pushStack(s, winnowNodes(s, nil, true))
}
return pushStack(s, winnowNodes(s, sel.Nodes, true))
}
// NotSelection removes elements from the Selection that match a node in the specified
// Selection object. It returns a new Selection object with the matching elements removed.
func (s *Selection) NotSelection(sel *Selection) *Selection {
if sel == nil {
return pushStack(s, winnowNodes(s, nil, false))
}
return pushStack(s, winnowNodes(s, sel.Nodes, false))
}
// Intersection is an alias for FilterSelection.
func (s *Selection) Intersection(sel *Selection) *Selection {
return s.FilterSelection(sel)
}
// Has reduces the set of matched elements to those that have a descendant
// that matches the selector.
// It returns a new Selection object with the matching elements.
func (s *Selection) Has(selector string) *Selection {
return s.HasSelection(s.document.Find(selector))
}
// HasMatcher reduces the set of matched elements to those that have a descendant
// that matches the matcher.
// It returns a new Selection object with the matching elements.
func (s *Selection) HasMatcher(m Matcher) *Selection {
return s.HasSelection(s.document.FindMatcher(m))
}
// HasNodes reduces the set of matched elements to those that have a
// descendant that matches one of the nodes.
// It returns a new Selection object with the matching elements.
func (s *Selection) HasNodes(nodes ...*html.Node) *Selection {
return s.FilterFunction(func(_ int, sel *Selection) bool {
// Add all nodes that contain one of the specified nodes
for _, n := range nodes {
if sel.Contains(n) {
return true
}
}
return false
})
}
// HasSelection reduces the set of matched elements to those that have a
// descendant that matches one of the nodes of the specified Selection object.
// It returns a new Selection object with the matching elements.
func (s *Selection) HasSelection(sel *Selection) *Selection {
if sel == nil {
return s.HasNodes()
}
return s.HasNodes(sel.Nodes...)
}
// End ends the most recent filtering operation in the current chain and
// returns the set of matched elements to its previous state.
func (s *Selection) End() *Selection {
if s.prevSel != nil {
return s.prevSel
}
return newEmptySelection(s.document)
}
// Filter based on the matcher, and the indicator to keep (Filter) or
// to get rid of (Not) the matching elements.
func winnow(sel *Selection, m Matcher, keep bool) []*html.Node {
// Optimize if keep is requested
if keep {
return m.Filter(sel.Nodes)
}
// Use grep
return grep(sel, func(i int, s *Selection) bool {
return !m.Match(s.Get(0))
})
}
// Filter based on an array of nodes, and the indicator to keep (Filter) or
// to get rid of (Not) the matching elements.
func winnowNodes(sel *Selection, nodes []*html.Node, keep bool) []*html.Node {
if len(nodes)+len(sel.Nodes) < minNodesForSet {
return grep(sel, func(i int, s *Selection) bool {
return isInSlice(nodes, s.Get(0)) == keep
})
}
set := make(map[*html.Node]bool)
for _, n := range nodes {
set[n] = true
}
return grep(sel, func(i int, s *Selection) bool {
return set[s.Get(0)] == keep
})
}
// Filter based on a function test, and the indicator to keep (Filter) or
// to get rid of (Not) the matching elements.
func winnowFunction(sel *Selection, f func(int, *Selection) bool, keep bool) []*html.Node {
return grep(sel, func(i int, s *Selection) bool {
return f(i, s) == keep
})
}

39
vendor/github.com/PuerkitoBio/goquery/iteration.go generated vendored Normal file
View File

@@ -0,0 +1,39 @@
package goquery
// Each iterates over a Selection object, executing a function for each
// matched element. It returns the current Selection object. The function
// f is called for each element in the selection with the index of the
// element in that selection starting at 0, and a *Selection that contains
// only that element.
func (s *Selection) Each(f func(int, *Selection)) *Selection {
for i, n := range s.Nodes {
f(i, newSingleSelection(n, s.document))
}
return s
}
// EachWithBreak iterates over a Selection object, executing a function for each
// matched element. It is identical to Each except that it is possible to break
// out of the loop by returning false in the callback function. It returns the
// current Selection object.
func (s *Selection) EachWithBreak(f func(int, *Selection) bool) *Selection {
for i, n := range s.Nodes {
if !f(i, newSingleSelection(n, s.document)) {
return s
}
}
return s
}
// Map passes each element in the current matched set through a function,
// producing a slice of string holding the returned values. The function
// f is called for each element in the selection with the index of the
// element in that selection starting at 0, and a *Selection that contains
// only that element.
func (s *Selection) Map(f func(int, *Selection) string) (result []string) {
for i, n := range s.Nodes {
result = append(result, f(i, newSingleSelection(n, s.document)))
}
return result
}

574
vendor/github.com/PuerkitoBio/goquery/manipulation.go generated vendored Normal file
View File

@@ -0,0 +1,574 @@
package goquery
import (
"strings"
"golang.org/x/net/html"
)
// After applies the selector from the root document and inserts the matched elements
// after the elements in the set of matched elements.
//
// If one of the matched elements in the selection is not currently in the
// document, it's impossible to insert nodes after it, so it will be ignored.
//
// This follows the same rules as Selection.Append.
func (s *Selection) After(selector string) *Selection {
return s.AfterMatcher(compileMatcher(selector))
}
// AfterMatcher applies the matcher from the root document and inserts the matched elements
// after the elements in the set of matched elements.
//
// If one of the matched elements in the selection is not currently in the
// document, it's impossible to insert nodes after it, so it will be ignored.
//
// This follows the same rules as Selection.Append.
func (s *Selection) AfterMatcher(m Matcher) *Selection {
return s.AfterNodes(m.MatchAll(s.document.rootNode)...)
}
// AfterSelection inserts the elements in the selection after each element in the set of matched
// elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) AfterSelection(sel *Selection) *Selection {
return s.AfterNodes(sel.Nodes...)
}
// AfterHtml parses the html and inserts it after the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) AfterHtml(html string) *Selection {
return s.AfterNodes(parseHtml(html)...)
}
// AfterNodes inserts the nodes after each element in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) AfterNodes(ns ...*html.Node) *Selection {
return s.manipulateNodes(ns, true, func(sn *html.Node, n *html.Node) {
if sn.Parent != nil {
sn.Parent.InsertBefore(n, sn.NextSibling)
}
})
}
// Append appends the elements specified by the selector to the end of each element
// in the set of matched elements, following those rules:
//
// 1) The selector is applied to the root document.
//
// 2) Elements that are part of the document will be moved to the new location.
//
// 3) If there are multiple locations to append to, cloned nodes will be
// appended to all target locations except the last one, which will be moved
// as noted in (2).
func (s *Selection) Append(selector string) *Selection {
return s.AppendMatcher(compileMatcher(selector))
}
// AppendMatcher appends the elements specified by the matcher to the end of each element
// in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) AppendMatcher(m Matcher) *Selection {
return s.AppendNodes(m.MatchAll(s.document.rootNode)...)
}
// AppendSelection appends the elements in the selection to the end of each element
// in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) AppendSelection(sel *Selection) *Selection {
return s.AppendNodes(sel.Nodes...)
}
// AppendHtml parses the html and appends it to the set of matched elements.
func (s *Selection) AppendHtml(html string) *Selection {
return s.AppendNodes(parseHtml(html)...)
}
// AppendNodes appends the specified nodes to each node in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) AppendNodes(ns ...*html.Node) *Selection {
return s.manipulateNodes(ns, false, func(sn *html.Node, n *html.Node) {
sn.AppendChild(n)
})
}
// Before inserts the matched elements before each element in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) Before(selector string) *Selection {
return s.BeforeMatcher(compileMatcher(selector))
}
// BeforeMatcher inserts the matched elements before each element in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) BeforeMatcher(m Matcher) *Selection {
return s.BeforeNodes(m.MatchAll(s.document.rootNode)...)
}
// BeforeSelection inserts the elements in the selection before each element in the set of matched
// elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) BeforeSelection(sel *Selection) *Selection {
return s.BeforeNodes(sel.Nodes...)
}
// BeforeHtml parses the html and inserts it before the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) BeforeHtml(html string) *Selection {
return s.BeforeNodes(parseHtml(html)...)
}
// BeforeNodes inserts the nodes before each element in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) BeforeNodes(ns ...*html.Node) *Selection {
return s.manipulateNodes(ns, false, func(sn *html.Node, n *html.Node) {
if sn.Parent != nil {
sn.Parent.InsertBefore(n, sn)
}
})
}
// Clone creates a deep copy of the set of matched nodes. The new nodes will not be
// attached to the document.
func (s *Selection) Clone() *Selection {
ns := newEmptySelection(s.document)
ns.Nodes = cloneNodes(s.Nodes)
return ns
}
// Empty removes all children nodes from the set of matched elements.
// It returns the children nodes in a new Selection.
func (s *Selection) Empty() *Selection {
var nodes []*html.Node
for _, n := range s.Nodes {
for c := n.FirstChild; c != nil; c = n.FirstChild {
n.RemoveChild(c)
nodes = append(nodes, c)
}
}
return pushStack(s, nodes)
}
// Prepend prepends the elements specified by the selector to each element in
// the set of matched elements, following the same rules as Append.
func (s *Selection) Prepend(selector string) *Selection {
return s.PrependMatcher(compileMatcher(selector))
}
// PrependMatcher prepends the elements specified by the matcher to each
// element in the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) PrependMatcher(m Matcher) *Selection {
return s.PrependNodes(m.MatchAll(s.document.rootNode)...)
}
// PrependSelection prepends the elements in the selection to each element in
// the set of matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) PrependSelection(sel *Selection) *Selection {
return s.PrependNodes(sel.Nodes...)
}
// PrependHtml parses the html and prepends it to the set of matched elements.
func (s *Selection) PrependHtml(html string) *Selection {
return s.PrependNodes(parseHtml(html)...)
}
// PrependNodes prepends the specified nodes to each node in the set of
// matched elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) PrependNodes(ns ...*html.Node) *Selection {
return s.manipulateNodes(ns, true, func(sn *html.Node, n *html.Node) {
// sn.FirstChild may be nil, in which case this functions like
// sn.AppendChild()
sn.InsertBefore(n, sn.FirstChild)
})
}
// Remove removes the set of matched elements from the document.
// It returns the same selection, now consisting of nodes not in the document.
func (s *Selection) Remove() *Selection {
for _, n := range s.Nodes {
if n.Parent != nil {
n.Parent.RemoveChild(n)
}
}
return s
}
// RemoveFiltered removes the set of matched elements by selector.
// It returns the Selection of removed nodes.
func (s *Selection) RemoveFiltered(selector string) *Selection {
return s.RemoveMatcher(compileMatcher(selector))
}
// RemoveMatcher removes the set of matched elements.
// It returns the Selection of removed nodes.
func (s *Selection) RemoveMatcher(m Matcher) *Selection {
return s.FilterMatcher(m).Remove()
}
// ReplaceWith replaces each element in the set of matched elements with the
// nodes matched by the given selector.
// It returns the removed elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) ReplaceWith(selector string) *Selection {
return s.ReplaceWithMatcher(compileMatcher(selector))
}
// ReplaceWithMatcher replaces each element in the set of matched elements with
// the nodes matched by the given Matcher.
// It returns the removed elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) ReplaceWithMatcher(m Matcher) *Selection {
return s.ReplaceWithNodes(m.MatchAll(s.document.rootNode)...)
}
// ReplaceWithSelection replaces each element in the set of matched elements with
// the nodes from the given Selection.
// It returns the removed elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) ReplaceWithSelection(sel *Selection) *Selection {
return s.ReplaceWithNodes(sel.Nodes...)
}
// ReplaceWithHtml replaces each element in the set of matched elements with
// the parsed HTML.
// It returns the removed elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) ReplaceWithHtml(html string) *Selection {
return s.ReplaceWithNodes(parseHtml(html)...)
}
// ReplaceWithNodes replaces each element in the set of matched elements with
// the given nodes.
// It returns the removed elements.
//
// This follows the same rules as Selection.Append.
func (s *Selection) ReplaceWithNodes(ns ...*html.Node) *Selection {
s.AfterNodes(ns...)
return s.Remove()
}
// SetHtml sets the html content of each element in the selection to
// specified html string.
func (s *Selection) SetHtml(html string) *Selection {
return setHtmlNodes(s, parseHtml(html)...)
}
// SetText sets the content of each element in the selection to specified content.
// The provided text string is escaped.
func (s *Selection) SetText(text string) *Selection {
return s.SetHtml(html.EscapeString(text))
}
// Unwrap removes the parents of the set of matched elements, leaving the matched
// elements (and their siblings, if any) in their place.
// It returns the original selection.
func (s *Selection) Unwrap() *Selection {
s.Parent().Each(func(i int, ss *Selection) {
// For some reason, jquery allows unwrap to remove the <head> element, so
// allowing it here too. Same for <html>. Why it allows those elements to
// be unwrapped while not allowing body is a mystery to me.
if ss.Nodes[0].Data != "body" {
ss.ReplaceWithSelection(ss.Contents())
}
})
return s
}
// Wrap wraps each element in the set of matched elements inside the first
// element matched by the given selector. The matched child is cloned before
// being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) Wrap(selector string) *Selection {
return s.WrapMatcher(compileMatcher(selector))
}
// WrapMatcher wraps each element in the set of matched elements inside the
// first element matched by the given matcher. The matched child is cloned
// before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapMatcher(m Matcher) *Selection {
return s.wrapNodes(m.MatchAll(s.document.rootNode)...)
}
// WrapSelection wraps each element in the set of matched elements inside the
// first element in the given Selection. The element is cloned before being
// inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapSelection(sel *Selection) *Selection {
return s.wrapNodes(sel.Nodes...)
}
// WrapHtml wraps each element in the set of matched elements inside the inner-
// most child of the given HTML.
//
// It returns the original set of elements.
func (s *Selection) WrapHtml(html string) *Selection {
return s.wrapNodes(parseHtml(html)...)
}
// WrapNode wraps each element in the set of matched elements inside the inner-
// most child of the given node. The given node is copied before being inserted
// into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapNode(n *html.Node) *Selection {
return s.wrapNodes(n)
}
func (s *Selection) wrapNodes(ns ...*html.Node) *Selection {
s.Each(func(i int, ss *Selection) {
ss.wrapAllNodes(ns...)
})
return s
}
// WrapAll wraps a single HTML structure, matched by the given selector, around
// all elements in the set of matched elements. The matched child is cloned
// before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapAll(selector string) *Selection {
return s.WrapAllMatcher(compileMatcher(selector))
}
// WrapAllMatcher wraps a single HTML structure, matched by the given Matcher,
// around all elements in the set of matched elements. The matched child is
// cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapAllMatcher(m Matcher) *Selection {
return s.wrapAllNodes(m.MatchAll(s.document.rootNode)...)
}
// WrapAllSelection wraps a single HTML structure, the first node of the given
// Selection, around all elements in the set of matched elements. The matched
// child is cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapAllSelection(sel *Selection) *Selection {
return s.wrapAllNodes(sel.Nodes...)
}
// WrapAllHtml wraps the given HTML structure around all elements in the set of
// matched elements. The matched child is cloned before being inserted into the
// document.
//
// It returns the original set of elements.
func (s *Selection) WrapAllHtml(html string) *Selection {
return s.wrapAllNodes(parseHtml(html)...)
}
func (s *Selection) wrapAllNodes(ns ...*html.Node) *Selection {
if len(ns) > 0 {
return s.WrapAllNode(ns[0])
}
return s
}
// WrapAllNode wraps the given node around the first element in the Selection,
// making all other nodes in the Selection children of the given node. The node
// is cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapAllNode(n *html.Node) *Selection {
if s.Size() == 0 {
return s
}
wrap := cloneNode(n)
first := s.Nodes[0]
if first.Parent != nil {
first.Parent.InsertBefore(wrap, first)
first.Parent.RemoveChild(first)
}
for c := getFirstChildEl(wrap); c != nil; c = getFirstChildEl(wrap) {
wrap = c
}
newSingleSelection(wrap, s.document).AppendSelection(s)
return s
}
// WrapInner wraps an HTML structure, matched by the given selector, around the
// content of element in the set of matched elements. The matched child is
// cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapInner(selector string) *Selection {
return s.WrapInnerMatcher(compileMatcher(selector))
}
// WrapInnerMatcher wraps an HTML structure, matched by the given selector,
// around the content of element in the set of matched elements. The matched
// child is cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapInnerMatcher(m Matcher) *Selection {
return s.wrapInnerNodes(m.MatchAll(s.document.rootNode)...)
}
// WrapInnerSelection wraps an HTML structure, matched by the given selector,
// around the content of element in the set of matched elements. The matched
// child is cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapInnerSelection(sel *Selection) *Selection {
return s.wrapInnerNodes(sel.Nodes...)
}
// WrapInnerHtml wraps an HTML structure, matched by the given selector, around
// the content of element in the set of matched elements. The matched child is
// cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapInnerHtml(html string) *Selection {
return s.wrapInnerNodes(parseHtml(html)...)
}
// WrapInnerNode wraps an HTML structure, matched by the given selector, around
// the content of element in the set of matched elements. The matched child is
// cloned before being inserted into the document.
//
// It returns the original set of elements.
func (s *Selection) WrapInnerNode(n *html.Node) *Selection {
return s.wrapInnerNodes(n)
}
func (s *Selection) wrapInnerNodes(ns ...*html.Node) *Selection {
if len(ns) == 0 {
return s
}
s.Each(func(i int, s *Selection) {
contents := s.Contents()
if contents.Size() > 0 {
contents.wrapAllNodes(ns...)
} else {
s.AppendNodes(cloneNode(ns[0]))
}
})
return s
}
func parseHtml(h string) []*html.Node {
// Errors are only returned when the io.Reader returns any error besides
// EOF, but strings.Reader never will
nodes, err := html.ParseFragment(strings.NewReader(h), &html.Node{Type: html.ElementNode})
if err != nil {
panic("goquery: failed to parse HTML: " + err.Error())
}
return nodes
}
func setHtmlNodes(s *Selection, ns ...*html.Node) *Selection {
for _, n := range s.Nodes {
for c := n.FirstChild; c != nil; c = n.FirstChild {
n.RemoveChild(c)
}
for _, c := range ns {
n.AppendChild(cloneNode(c))
}
}
return s
}
// Get the first child that is an ElementNode
func getFirstChildEl(n *html.Node) *html.Node {
c := n.FirstChild
for c != nil && c.Type != html.ElementNode {
c = c.NextSibling
}
return c
}
// Deep copy a slice of nodes.
func cloneNodes(ns []*html.Node) []*html.Node {
cns := make([]*html.Node, 0, len(ns))
for _, n := range ns {
cns = append(cns, cloneNode(n))
}
return cns
}
// Deep copy a node. The new node has clones of all the original node's
// children but none of its parents or siblings.
func cloneNode(n *html.Node) *html.Node {
nn := &html.Node{
Type: n.Type,
DataAtom: n.DataAtom,
Data: n.Data,
Attr: make([]html.Attribute, len(n.Attr)),
}
copy(nn.Attr, n.Attr)
for c := n.FirstChild; c != nil; c = c.NextSibling {
nn.AppendChild(cloneNode(c))
}
return nn
}
func (s *Selection) manipulateNodes(ns []*html.Node, reverse bool,
f func(sn *html.Node, n *html.Node)) *Selection {
lasti := s.Size() - 1
// net.Html doesn't provide document fragments for insertion, so to get
// things in the correct order with After() and Prepend(), the callback
// needs to be called on the reverse of the nodes.
if reverse {
for i, j := 0, len(ns)-1; i < j; i, j = i+1, j-1 {
ns[i], ns[j] = ns[j], ns[i]
}
}
for i, sn := range s.Nodes {
for _, n := range ns {
if i != lasti {
f(sn, cloneNode(n))
} else {
if n.Parent != nil {
n.Parent.RemoveChild(n)
}
f(sn, n)
}
}
}
return s
}

275
vendor/github.com/PuerkitoBio/goquery/property.go generated vendored Normal file
View File

@@ -0,0 +1,275 @@
package goquery
import (
"bytes"
"regexp"
"strings"
"golang.org/x/net/html"
)
var rxClassTrim = regexp.MustCompile("[\t\r\n]")
// Attr gets the specified attribute's value for the first element in the
// Selection. To get the value for each element individually, use a looping
// construct such as Each or Map method.
func (s *Selection) Attr(attrName string) (val string, exists bool) {
if len(s.Nodes) == 0 {
return
}
return getAttributeValue(attrName, s.Nodes[0])
}
// AttrOr works like Attr but returns default value if attribute is not present.
func (s *Selection) AttrOr(attrName, defaultValue string) string {
if len(s.Nodes) == 0 {
return defaultValue
}
val, exists := getAttributeValue(attrName, s.Nodes[0])
if !exists {
return defaultValue
}
return val
}
// RemoveAttr removes the named attribute from each element in the set of matched elements.
func (s *Selection) RemoveAttr(attrName string) *Selection {
for _, n := range s.Nodes {
removeAttr(n, attrName)
}
return s
}
// SetAttr sets the given attribute on each element in the set of matched elements.
func (s *Selection) SetAttr(attrName, val string) *Selection {
for _, n := range s.Nodes {
attr := getAttributePtr(attrName, n)
if attr == nil {
n.Attr = append(n.Attr, html.Attribute{Key: attrName, Val: val})
} else {
attr.Val = val
}
}
return s
}
// Text gets the combined text contents of each element in the set of matched
// elements, including their descendants.
func (s *Selection) Text() string {
var buf bytes.Buffer
// Slightly optimized vs calling Each: no single selection object created
var f func(*html.Node)
f = func(n *html.Node) {
if n.Type == html.TextNode {
// Keep newlines and spaces, like jQuery
buf.WriteString(n.Data)
}
if n.FirstChild != nil {
for c := n.FirstChild; c != nil; c = c.NextSibling {
f(c)
}
}
}
for _, n := range s.Nodes {
f(n)
}
return buf.String()
}
// Size is an alias for Length.
func (s *Selection) Size() int {
return s.Length()
}
// Length returns the number of elements in the Selection object.
func (s *Selection) Length() int {
return len(s.Nodes)
}
// Html gets the HTML contents of the first element in the set of matched
// elements. It includes text and comment nodes.
func (s *Selection) Html() (ret string, e error) {
// Since there is no .innerHtml, the HTML content must be re-created from
// the nodes using html.Render.
var buf bytes.Buffer
if len(s.Nodes) > 0 {
for c := s.Nodes[0].FirstChild; c != nil; c = c.NextSibling {
e = html.Render(&buf, c)
if e != nil {
return
}
}
ret = buf.String()
}
return
}
// AddClass adds the given class(es) to each element in the set of matched elements.
// Multiple class names can be specified, separated by a space or via multiple arguments.
func (s *Selection) AddClass(class ...string) *Selection {
classStr := strings.TrimSpace(strings.Join(class, " "))
if classStr == "" {
return s
}
tcls := getClassesSlice(classStr)
for _, n := range s.Nodes {
curClasses, attr := getClassesAndAttr(n, true)
for _, newClass := range tcls {
if !strings.Contains(curClasses, " "+newClass+" ") {
curClasses += newClass + " "
}
}
setClasses(n, attr, curClasses)
}
return s
}
// HasClass determines whether any of the matched elements are assigned the
// given class.
func (s *Selection) HasClass(class string) bool {
class = " " + class + " "
for _, n := range s.Nodes {
classes, _ := getClassesAndAttr(n, false)
if strings.Contains(classes, class) {
return true
}
}
return false
}
// RemoveClass removes the given class(es) from each element in the set of matched elements.
// Multiple class names can be specified, separated by a space or via multiple arguments.
// If no class name is provided, all classes are removed.
func (s *Selection) RemoveClass(class ...string) *Selection {
var rclasses []string
classStr := strings.TrimSpace(strings.Join(class, " "))
remove := classStr == ""
if !remove {
rclasses = getClassesSlice(classStr)
}
for _, n := range s.Nodes {
if remove {
removeAttr(n, "class")
} else {
classes, attr := getClassesAndAttr(n, true)
for _, rcl := range rclasses {
classes = strings.Replace(classes, " "+rcl+" ", " ", -1)
}
setClasses(n, attr, classes)
}
}
return s
}
// ToggleClass adds or removes the given class(es) for each element in the set of matched elements.
// Multiple class names can be specified, separated by a space or via multiple arguments.
func (s *Selection) ToggleClass(class ...string) *Selection {
classStr := strings.TrimSpace(strings.Join(class, " "))
if classStr == "" {
return s
}
tcls := getClassesSlice(classStr)
for _, n := range s.Nodes {
classes, attr := getClassesAndAttr(n, true)
for _, tcl := range tcls {
if strings.Contains(classes, " "+tcl+" ") {
classes = strings.Replace(classes, " "+tcl+" ", " ", -1)
} else {
classes += tcl + " "
}
}
setClasses(n, attr, classes)
}
return s
}
func getAttributePtr(attrName string, n *html.Node) *html.Attribute {
if n == nil {
return nil
}
for i, a := range n.Attr {
if a.Key == attrName {
return &n.Attr[i]
}
}
return nil
}
// Private function to get the specified attribute's value from a node.
func getAttributeValue(attrName string, n *html.Node) (val string, exists bool) {
if a := getAttributePtr(attrName, n); a != nil {
val = a.Val
exists = true
}
return
}
// Get and normalize the "class" attribute from the node.
func getClassesAndAttr(n *html.Node, create bool) (classes string, attr *html.Attribute) {
// Applies only to element nodes
if n.Type == html.ElementNode {
attr = getAttributePtr("class", n)
if attr == nil && create {
n.Attr = append(n.Attr, html.Attribute{
Key: "class",
Val: "",
})
attr = &n.Attr[len(n.Attr)-1]
}
}
if attr == nil {
classes = " "
} else {
classes = rxClassTrim.ReplaceAllString(" "+attr.Val+" ", " ")
}
return
}
func getClassesSlice(classes string) []string {
return strings.Split(rxClassTrim.ReplaceAllString(" "+classes+" ", " "), " ")
}
func removeAttr(n *html.Node, attrName string) {
for i, a := range n.Attr {
if a.Key == attrName {
n.Attr[i], n.Attr[len(n.Attr)-1], n.Attr =
n.Attr[len(n.Attr)-1], html.Attribute{}, n.Attr[:len(n.Attr)-1]
return
}
}
}
func setClasses(n *html.Node, attr *html.Attribute, classes string) {
classes = strings.TrimSpace(classes)
if classes == "" {
removeAttr(n, "class")
return
}
attr.Val = classes
}

49
vendor/github.com/PuerkitoBio/goquery/query.go generated vendored Normal file
View File

@@ -0,0 +1,49 @@
package goquery
import "golang.org/x/net/html"
// Is checks the current matched set of elements against a selector and
// returns true if at least one of these elements matches.
func (s *Selection) Is(selector string) bool {
return s.IsMatcher(compileMatcher(selector))
}
// IsMatcher checks the current matched set of elements against a matcher and
// returns true if at least one of these elements matches.
func (s *Selection) IsMatcher(m Matcher) bool {
if len(s.Nodes) > 0 {
if len(s.Nodes) == 1 {
return m.Match(s.Nodes[0])
}
return len(m.Filter(s.Nodes)) > 0
}
return false
}
// IsFunction checks the current matched set of elements against a predicate and
// returns true if at least one of these elements matches.
func (s *Selection) IsFunction(f func(int, *Selection) bool) bool {
return s.FilterFunction(f).Length() > 0
}
// IsSelection checks the current matched set of elements against a Selection object
// and returns true if at least one of these elements matches.
func (s *Selection) IsSelection(sel *Selection) bool {
return s.FilterSelection(sel).Length() > 0
}
// IsNodes checks the current matched set of elements against the specified nodes
// and returns true if at least one of these elements matches.
func (s *Selection) IsNodes(nodes ...*html.Node) bool {
return s.FilterNodes(nodes...).Length() > 0
}
// Contains returns true if the specified Node is within,
// at any depth, one of the nodes in the Selection object.
// It is NOT inclusive, to behave like jQuery's implementation, and
// unlike Javascript's .contains, so if the contained
// node is itself in the selection, it returns false.
func (s *Selection) Contains(n *html.Node) bool {
return sliceContains(s.Nodes, n)
}

698
vendor/github.com/PuerkitoBio/goquery/traversal.go generated vendored Normal file
View File

@@ -0,0 +1,698 @@
package goquery
import "golang.org/x/net/html"
type siblingType int
// Sibling type, used internally when iterating over children at the same
// level (siblings) to specify which nodes are requested.
const (
siblingPrevUntil siblingType = iota - 3
siblingPrevAll
siblingPrev
siblingAll
siblingNext
siblingNextAll
siblingNextUntil
siblingAllIncludingNonElements
)
// Find gets the descendants of each element in the current set of matched
// elements, filtered by a selector. It returns a new Selection object
// containing these matched elements.
func (s *Selection) Find(selector string) *Selection {
return pushStack(s, findWithMatcher(s.Nodes, compileMatcher(selector)))
}
// FindMatcher gets the descendants of each element in the current set of matched
// elements, filtered by the matcher. It returns a new Selection object
// containing these matched elements.
func (s *Selection) FindMatcher(m Matcher) *Selection {
return pushStack(s, findWithMatcher(s.Nodes, m))
}
// FindSelection gets the descendants of each element in the current
// Selection, filtered by a Selection. It returns a new Selection object
// containing these matched elements.
func (s *Selection) FindSelection(sel *Selection) *Selection {
if sel == nil {
return pushStack(s, nil)
}
return s.FindNodes(sel.Nodes...)
}
// FindNodes gets the descendants of each element in the current
// Selection, filtered by some nodes. It returns a new Selection object
// containing these matched elements.
func (s *Selection) FindNodes(nodes ...*html.Node) *Selection {
return pushStack(s, mapNodes(nodes, func(i int, n *html.Node) []*html.Node {
if sliceContains(s.Nodes, n) {
return []*html.Node{n}
}
return nil
}))
}
// Contents gets the children of each element in the Selection,
// including text and comment nodes. It returns a new Selection object
// containing these elements.
func (s *Selection) Contents() *Selection {
return pushStack(s, getChildrenNodes(s.Nodes, siblingAllIncludingNonElements))
}
// ContentsFiltered gets the children of each element in the Selection,
// filtered by the specified selector. It returns a new Selection
// object containing these elements. Since selectors only act on Element nodes,
// this function is an alias to ChildrenFiltered unless the selector is empty,
// in which case it is an alias to Contents.
func (s *Selection) ContentsFiltered(selector string) *Selection {
if selector != "" {
return s.ChildrenFiltered(selector)
}
return s.Contents()
}
// ContentsMatcher gets the children of each element in the Selection,
// filtered by the specified matcher. It returns a new Selection
// object containing these elements. Since matchers only act on Element nodes,
// this function is an alias to ChildrenMatcher.
func (s *Selection) ContentsMatcher(m Matcher) *Selection {
return s.ChildrenMatcher(m)
}
// Children gets the child elements of each element in the Selection.
// It returns a new Selection object containing these elements.
func (s *Selection) Children() *Selection {
return pushStack(s, getChildrenNodes(s.Nodes, siblingAll))
}
// ChildrenFiltered gets the child elements of each element in the Selection,
// filtered by the specified selector. It returns a new
// Selection object containing these elements.
func (s *Selection) ChildrenFiltered(selector string) *Selection {
return filterAndPush(s, getChildrenNodes(s.Nodes, siblingAll), compileMatcher(selector))
}
// ChildrenMatcher gets the child elements of each element in the Selection,
// filtered by the specified matcher. It returns a new
// Selection object containing these elements.
func (s *Selection) ChildrenMatcher(m Matcher) *Selection {
return filterAndPush(s, getChildrenNodes(s.Nodes, siblingAll), m)
}
// Parent gets the parent of each element in the Selection. It returns a
// new Selection object containing the matched elements.
func (s *Selection) Parent() *Selection {
return pushStack(s, getParentNodes(s.Nodes))
}
// ParentFiltered gets the parent of each element in the Selection filtered by a
// selector. It returns a new Selection object containing the matched elements.
func (s *Selection) ParentFiltered(selector string) *Selection {
return filterAndPush(s, getParentNodes(s.Nodes), compileMatcher(selector))
}
// ParentMatcher gets the parent of each element in the Selection filtered by a
// matcher. It returns a new Selection object containing the matched elements.
func (s *Selection) ParentMatcher(m Matcher) *Selection {
return filterAndPush(s, getParentNodes(s.Nodes), m)
}
// Closest gets the first element that matches the selector by testing the
// element itself and traversing up through its ancestors in the DOM tree.
func (s *Selection) Closest(selector string) *Selection {
cs := compileMatcher(selector)
return s.ClosestMatcher(cs)
}
// ClosestMatcher gets the first element that matches the matcher by testing the
// element itself and traversing up through its ancestors in the DOM tree.
func (s *Selection) ClosestMatcher(m Matcher) *Selection {
return pushStack(s, mapNodes(s.Nodes, func(i int, n *html.Node) []*html.Node {
// For each node in the selection, test the node itself, then each parent
// until a match is found.
for ; n != nil; n = n.Parent {
if m.Match(n) {
return []*html.Node{n}
}
}
return nil
}))
}
// ClosestNodes gets the first element that matches one of the nodes by testing the
// element itself and traversing up through its ancestors in the DOM tree.
func (s *Selection) ClosestNodes(nodes ...*html.Node) *Selection {
set := make(map[*html.Node]bool)
for _, n := range nodes {
set[n] = true
}
return pushStack(s, mapNodes(s.Nodes, func(i int, n *html.Node) []*html.Node {
// For each node in the selection, test the node itself, then each parent
// until a match is found.
for ; n != nil; n = n.Parent {
if set[n] {
return []*html.Node{n}
}
}
return nil
}))
}
// ClosestSelection gets the first element that matches one of the nodes in the
// Selection by testing the element itself and traversing up through its ancestors
// in the DOM tree.
func (s *Selection) ClosestSelection(sel *Selection) *Selection {
if sel == nil {
return pushStack(s, nil)
}
return s.ClosestNodes(sel.Nodes...)
}
// Parents gets the ancestors of each element in the current Selection. It
// returns a new Selection object with the matched elements.
func (s *Selection) Parents() *Selection {
return pushStack(s, getParentsNodes(s.Nodes, nil, nil))
}
// ParentsFiltered gets the ancestors of each element in the current
// Selection. It returns a new Selection object with the matched elements.
func (s *Selection) ParentsFiltered(selector string) *Selection {
return filterAndPush(s, getParentsNodes(s.Nodes, nil, nil), compileMatcher(selector))
}
// ParentsMatcher gets the ancestors of each element in the current
// Selection. It returns a new Selection object with the matched elements.
func (s *Selection) ParentsMatcher(m Matcher) *Selection {
return filterAndPush(s, getParentsNodes(s.Nodes, nil, nil), m)
}
// ParentsUntil gets the ancestors of each element in the Selection, up to but
// not including the element matched by the selector. It returns a new Selection
// object containing the matched elements.
func (s *Selection) ParentsUntil(selector string) *Selection {
return pushStack(s, getParentsNodes(s.Nodes, compileMatcher(selector), nil))
}
// ParentsUntilMatcher gets the ancestors of each element in the Selection, up to but
// not including the element matched by the matcher. It returns a new Selection
// object containing the matched elements.
func (s *Selection) ParentsUntilMatcher(m Matcher) *Selection {
return pushStack(s, getParentsNodes(s.Nodes, m, nil))
}
// ParentsUntilSelection gets the ancestors of each element in the Selection,
// up to but not including the elements in the specified Selection. It returns a
// new Selection object containing the matched elements.
func (s *Selection) ParentsUntilSelection(sel *Selection) *Selection {
if sel == nil {
return s.Parents()
}
return s.ParentsUntilNodes(sel.Nodes...)
}
// ParentsUntilNodes gets the ancestors of each element in the Selection,
// up to but not including the specified nodes. It returns a
// new Selection object containing the matched elements.
func (s *Selection) ParentsUntilNodes(nodes ...*html.Node) *Selection {
return pushStack(s, getParentsNodes(s.Nodes, nil, nodes))
}
// ParentsFilteredUntil is like ParentsUntil, with the option to filter the
// results based on a selector string. It returns a new Selection
// object containing the matched elements.
func (s *Selection) ParentsFilteredUntil(filterSelector, untilSelector string) *Selection {
return filterAndPush(s, getParentsNodes(s.Nodes, compileMatcher(untilSelector), nil), compileMatcher(filterSelector))
}
// ParentsFilteredUntilMatcher is like ParentsUntilMatcher, with the option to filter the
// results based on a matcher. It returns a new Selection object containing the matched elements.
func (s *Selection) ParentsFilteredUntilMatcher(filter, until Matcher) *Selection {
return filterAndPush(s, getParentsNodes(s.Nodes, until, nil), filter)
}
// ParentsFilteredUntilSelection is like ParentsUntilSelection, with the
// option to filter the results based on a selector string. It returns a new
// Selection object containing the matched elements.
func (s *Selection) ParentsFilteredUntilSelection(filterSelector string, sel *Selection) *Selection {
return s.ParentsMatcherUntilSelection(compileMatcher(filterSelector), sel)
}
// ParentsMatcherUntilSelection is like ParentsUntilSelection, with the
// option to filter the results based on a matcher. It returns a new
// Selection object containing the matched elements.
func (s *Selection) ParentsMatcherUntilSelection(filter Matcher, sel *Selection) *Selection {
if sel == nil {
return s.ParentsMatcher(filter)
}
return s.ParentsMatcherUntilNodes(filter, sel.Nodes...)
}
// ParentsFilteredUntilNodes is like ParentsUntilNodes, with the
// option to filter the results based on a selector string. It returns a new
// Selection object containing the matched elements.
func (s *Selection) ParentsFilteredUntilNodes(filterSelector string, nodes ...*html.Node) *Selection {
return filterAndPush(s, getParentsNodes(s.Nodes, nil, nodes), compileMatcher(filterSelector))
}
// ParentsMatcherUntilNodes is like ParentsUntilNodes, with the
// option to filter the results based on a matcher. It returns a new
// Selection object containing the matched elements.
func (s *Selection) ParentsMatcherUntilNodes(filter Matcher, nodes ...*html.Node) *Selection {
return filterAndPush(s, getParentsNodes(s.Nodes, nil, nodes), filter)
}
// Siblings gets the siblings of each element in the Selection. It returns
// a new Selection object containing the matched elements.
func (s *Selection) Siblings() *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingAll, nil, nil))
}
// SiblingsFiltered gets the siblings of each element in the Selection
// filtered by a selector. It returns a new Selection object containing the
// matched elements.
func (s *Selection) SiblingsFiltered(selector string) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingAll, nil, nil), compileMatcher(selector))
}
// SiblingsMatcher gets the siblings of each element in the Selection
// filtered by a matcher. It returns a new Selection object containing the
// matched elements.
func (s *Selection) SiblingsMatcher(m Matcher) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingAll, nil, nil), m)
}
// Next gets the immediately following sibling of each element in the
// Selection. It returns a new Selection object containing the matched elements.
func (s *Selection) Next() *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingNext, nil, nil))
}
// NextFiltered gets the immediately following sibling of each element in the
// Selection filtered by a selector. It returns a new Selection object
// containing the matched elements.
func (s *Selection) NextFiltered(selector string) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNext, nil, nil), compileMatcher(selector))
}
// NextMatcher gets the immediately following sibling of each element in the
// Selection filtered by a matcher. It returns a new Selection object
// containing the matched elements.
func (s *Selection) NextMatcher(m Matcher) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNext, nil, nil), m)
}
// NextAll gets all the following siblings of each element in the
// Selection. It returns a new Selection object containing the matched elements.
func (s *Selection) NextAll() *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingNextAll, nil, nil))
}
// NextAllFiltered gets all the following siblings of each element in the
// Selection filtered by a selector. It returns a new Selection object
// containing the matched elements.
func (s *Selection) NextAllFiltered(selector string) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNextAll, nil, nil), compileMatcher(selector))
}
// NextAllMatcher gets all the following siblings of each element in the
// Selection filtered by a matcher. It returns a new Selection object
// containing the matched elements.
func (s *Selection) NextAllMatcher(m Matcher) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNextAll, nil, nil), m)
}
// Prev gets the immediately preceding sibling of each element in the
// Selection. It returns a new Selection object containing the matched elements.
func (s *Selection) Prev() *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingPrev, nil, nil))
}
// PrevFiltered gets the immediately preceding sibling of each element in the
// Selection filtered by a selector. It returns a new Selection object
// containing the matched elements.
func (s *Selection) PrevFiltered(selector string) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrev, nil, nil), compileMatcher(selector))
}
// PrevMatcher gets the immediately preceding sibling of each element in the
// Selection filtered by a matcher. It returns a new Selection object
// containing the matched elements.
func (s *Selection) PrevMatcher(m Matcher) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrev, nil, nil), m)
}
// PrevAll gets all the preceding siblings of each element in the
// Selection. It returns a new Selection object containing the matched elements.
func (s *Selection) PrevAll() *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingPrevAll, nil, nil))
}
// PrevAllFiltered gets all the preceding siblings of each element in the
// Selection filtered by a selector. It returns a new Selection object
// containing the matched elements.
func (s *Selection) PrevAllFiltered(selector string) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrevAll, nil, nil), compileMatcher(selector))
}
// PrevAllMatcher gets all the preceding siblings of each element in the
// Selection filtered by a matcher. It returns a new Selection object
// containing the matched elements.
func (s *Selection) PrevAllMatcher(m Matcher) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrevAll, nil, nil), m)
}
// NextUntil gets all following siblings of each element up to but not
// including the element matched by the selector. It returns a new Selection
// object containing the matched elements.
func (s *Selection) NextUntil(selector string) *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingNextUntil,
compileMatcher(selector), nil))
}
// NextUntilMatcher gets all following siblings of each element up to but not
// including the element matched by the matcher. It returns a new Selection
// object containing the matched elements.
func (s *Selection) NextUntilMatcher(m Matcher) *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingNextUntil,
m, nil))
}
// NextUntilSelection gets all following siblings of each element up to but not
// including the element matched by the Selection. It returns a new Selection
// object containing the matched elements.
func (s *Selection) NextUntilSelection(sel *Selection) *Selection {
if sel == nil {
return s.NextAll()
}
return s.NextUntilNodes(sel.Nodes...)
}
// NextUntilNodes gets all following siblings of each element up to but not
// including the element matched by the nodes. It returns a new Selection
// object containing the matched elements.
func (s *Selection) NextUntilNodes(nodes ...*html.Node) *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingNextUntil,
nil, nodes))
}
// PrevUntil gets all preceding siblings of each element up to but not
// including the element matched by the selector. It returns a new Selection
// object containing the matched elements.
func (s *Selection) PrevUntil(selector string) *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingPrevUntil,
compileMatcher(selector), nil))
}
// PrevUntilMatcher gets all preceding siblings of each element up to but not
// including the element matched by the matcher. It returns a new Selection
// object containing the matched elements.
func (s *Selection) PrevUntilMatcher(m Matcher) *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingPrevUntil,
m, nil))
}
// PrevUntilSelection gets all preceding siblings of each element up to but not
// including the element matched by the Selection. It returns a new Selection
// object containing the matched elements.
func (s *Selection) PrevUntilSelection(sel *Selection) *Selection {
if sel == nil {
return s.PrevAll()
}
return s.PrevUntilNodes(sel.Nodes...)
}
// PrevUntilNodes gets all preceding siblings of each element up to but not
// including the element matched by the nodes. It returns a new Selection
// object containing the matched elements.
func (s *Selection) PrevUntilNodes(nodes ...*html.Node) *Selection {
return pushStack(s, getSiblingNodes(s.Nodes, siblingPrevUntil,
nil, nodes))
}
// NextFilteredUntil is like NextUntil, with the option to filter
// the results based on a selector string.
// It returns a new Selection object containing the matched elements.
func (s *Selection) NextFilteredUntil(filterSelector, untilSelector string) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNextUntil,
compileMatcher(untilSelector), nil), compileMatcher(filterSelector))
}
// NextFilteredUntilMatcher is like NextUntilMatcher, with the option to filter
// the results based on a matcher.
// It returns a new Selection object containing the matched elements.
func (s *Selection) NextFilteredUntilMatcher(filter, until Matcher) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNextUntil,
until, nil), filter)
}
// NextFilteredUntilSelection is like NextUntilSelection, with the
// option to filter the results based on a selector string. It returns a new
// Selection object containing the matched elements.
func (s *Selection) NextFilteredUntilSelection(filterSelector string, sel *Selection) *Selection {
return s.NextMatcherUntilSelection(compileMatcher(filterSelector), sel)
}
// NextMatcherUntilSelection is like NextUntilSelection, with the
// option to filter the results based on a matcher. It returns a new
// Selection object containing the matched elements.
func (s *Selection) NextMatcherUntilSelection(filter Matcher, sel *Selection) *Selection {
if sel == nil {
return s.NextMatcher(filter)
}
return s.NextMatcherUntilNodes(filter, sel.Nodes...)
}
// NextFilteredUntilNodes is like NextUntilNodes, with the
// option to filter the results based on a selector string. It returns a new
// Selection object containing the matched elements.
func (s *Selection) NextFilteredUntilNodes(filterSelector string, nodes ...*html.Node) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNextUntil,
nil, nodes), compileMatcher(filterSelector))
}
// NextMatcherUntilNodes is like NextUntilNodes, with the
// option to filter the results based on a matcher. It returns a new
// Selection object containing the matched elements.
func (s *Selection) NextMatcherUntilNodes(filter Matcher, nodes ...*html.Node) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingNextUntil,
nil, nodes), filter)
}
// PrevFilteredUntil is like PrevUntil, with the option to filter
// the results based on a selector string.
// It returns a new Selection object containing the matched elements.
func (s *Selection) PrevFilteredUntil(filterSelector, untilSelector string) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrevUntil,
compileMatcher(untilSelector), nil), compileMatcher(filterSelector))
}
// PrevFilteredUntilMatcher is like PrevUntilMatcher, with the option to filter
// the results based on a matcher.
// It returns a new Selection object containing the matched elements.
func (s *Selection) PrevFilteredUntilMatcher(filter, until Matcher) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrevUntil,
until, nil), filter)
}
// PrevFilteredUntilSelection is like PrevUntilSelection, with the
// option to filter the results based on a selector string. It returns a new
// Selection object containing the matched elements.
func (s *Selection) PrevFilteredUntilSelection(filterSelector string, sel *Selection) *Selection {
return s.PrevMatcherUntilSelection(compileMatcher(filterSelector), sel)
}
// PrevMatcherUntilSelection is like PrevUntilSelection, with the
// option to filter the results based on a matcher. It returns a new
// Selection object containing the matched elements.
func (s *Selection) PrevMatcherUntilSelection(filter Matcher, sel *Selection) *Selection {
if sel == nil {
return s.PrevMatcher(filter)
}
return s.PrevMatcherUntilNodes(filter, sel.Nodes...)
}
// PrevFilteredUntilNodes is like PrevUntilNodes, with the
// option to filter the results based on a selector string. It returns a new
// Selection object containing the matched elements.
func (s *Selection) PrevFilteredUntilNodes(filterSelector string, nodes ...*html.Node) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrevUntil,
nil, nodes), compileMatcher(filterSelector))
}
// PrevMatcherUntilNodes is like PrevUntilNodes, with the
// option to filter the results based on a matcher. It returns a new
// Selection object containing the matched elements.
func (s *Selection) PrevMatcherUntilNodes(filter Matcher, nodes ...*html.Node) *Selection {
return filterAndPush(s, getSiblingNodes(s.Nodes, siblingPrevUntil,
nil, nodes), filter)
}
// Filter and push filters the nodes based on a matcher, and pushes the results
// on the stack, with the srcSel as previous selection.
func filterAndPush(srcSel *Selection, nodes []*html.Node, m Matcher) *Selection {
// Create a temporary Selection with the specified nodes to filter using winnow
sel := &Selection{nodes, srcSel.document, nil}
// Filter based on matcher and push on stack
return pushStack(srcSel, winnow(sel, m, true))
}
// Internal implementation of Find that return raw nodes.
func findWithMatcher(nodes []*html.Node, m Matcher) []*html.Node {
// Map nodes to find the matches within the children of each node
return mapNodes(nodes, func(i int, n *html.Node) (result []*html.Node) {
// Go down one level, becausejQuery's Find selects only within descendants
for c := n.FirstChild; c != nil; c = c.NextSibling {
if c.Type == html.ElementNode {
result = append(result, m.MatchAll(c)...)
}
}
return
})
}
// Internal implementation to get all parent nodes, stopping at the specified
// node (or nil if no stop).
func getParentsNodes(nodes []*html.Node, stopm Matcher, stopNodes []*html.Node) []*html.Node {
return mapNodes(nodes, func(i int, n *html.Node) (result []*html.Node) {
for p := n.Parent; p != nil; p = p.Parent {
sel := newSingleSelection(p, nil)
if stopm != nil {
if sel.IsMatcher(stopm) {
break
}
} else if len(stopNodes) > 0 {
if sel.IsNodes(stopNodes...) {
break
}
}
if p.Type == html.ElementNode {
result = append(result, p)
}
}
return
})
}
// Internal implementation of sibling nodes that return a raw slice of matches.
func getSiblingNodes(nodes []*html.Node, st siblingType, untilm Matcher, untilNodes []*html.Node) []*html.Node {
var f func(*html.Node) bool
// If the requested siblings are ...Until, create the test function to
// determine if the until condition is reached (returns true if it is)
if st == siblingNextUntil || st == siblingPrevUntil {
f = func(n *html.Node) bool {
if untilm != nil {
// Matcher-based condition
sel := newSingleSelection(n, nil)
return sel.IsMatcher(untilm)
} else if len(untilNodes) > 0 {
// Nodes-based condition
sel := newSingleSelection(n, nil)
return sel.IsNodes(untilNodes...)
}
return false
}
}
return mapNodes(nodes, func(i int, n *html.Node) []*html.Node {
return getChildrenWithSiblingType(n.Parent, st, n, f)
})
}
// Gets the children nodes of each node in the specified slice of nodes,
// based on the sibling type request.
func getChildrenNodes(nodes []*html.Node, st siblingType) []*html.Node {
return mapNodes(nodes, func(i int, n *html.Node) []*html.Node {
return getChildrenWithSiblingType(n, st, nil, nil)
})
}
// Gets the children of the specified parent, based on the requested sibling
// type, skipping a specified node if required.
func getChildrenWithSiblingType(parent *html.Node, st siblingType, skipNode *html.Node,
untilFunc func(*html.Node) bool) (result []*html.Node) {
// Create the iterator function
var iter = func(cur *html.Node) (ret *html.Node) {
// Based on the sibling type requested, iterate the right way
for {
switch st {
case siblingAll, siblingAllIncludingNonElements:
if cur == nil {
// First iteration, start with first child of parent
// Skip node if required
if ret = parent.FirstChild; ret == skipNode && skipNode != nil {
ret = skipNode.NextSibling
}
} else {
// Skip node if required
if ret = cur.NextSibling; ret == skipNode && skipNode != nil {
ret = skipNode.NextSibling
}
}
case siblingPrev, siblingPrevAll, siblingPrevUntil:
if cur == nil {
// Start with previous sibling of the skip node
ret = skipNode.PrevSibling
} else {
ret = cur.PrevSibling
}
case siblingNext, siblingNextAll, siblingNextUntil:
if cur == nil {
// Start with next sibling of the skip node
ret = skipNode.NextSibling
} else {
ret = cur.NextSibling
}
default:
panic("Invalid sibling type.")
}
if ret == nil || ret.Type == html.ElementNode || st == siblingAllIncludingNonElements {
return
}
// Not a valid node, try again from this one
cur = ret
}
}
for c := iter(nil); c != nil; c = iter(c) {
// If this is an ...Until case, test before append (returns true
// if the until condition is reached)
if st == siblingNextUntil || st == siblingPrevUntil {
if untilFunc(c) {
return
}
}
result = append(result, c)
if st == siblingNext || st == siblingPrev {
// Only one node was requested (immediate next or previous), so exit
return
}
}
return
}
// Internal implementation of parent nodes that return a raw slice of Nodes.
func getParentNodes(nodes []*html.Node) []*html.Node {
return mapNodes(nodes, func(i int, n *html.Node) []*html.Node {
if n.Parent != nil && n.Parent.Type == html.ElementNode {
return []*html.Node{n.Parent}
}
return nil
})
}
// Internal map function used by many traversing methods. Takes the source nodes
// to iterate on and the mapping function that returns an array of nodes.
// Returns an array of nodes mapped by calling the callback function once for
// each node in the source nodes.
func mapNodes(nodes []*html.Node, f func(int, *html.Node) []*html.Node) (result []*html.Node) {
set := make(map[*html.Node]bool)
for i, n := range nodes {
if vals := f(i, n); len(vals) > 0 {
result = appendWithoutDuplicates(result, vals, set)
}
}
return result
}

141
vendor/github.com/PuerkitoBio/goquery/type.go generated vendored Normal file
View File

@@ -0,0 +1,141 @@
package goquery
import (
"errors"
"io"
"net/http"
"net/url"
"github.com/andybalholm/cascadia"
"golang.org/x/net/html"
)
// Document represents an HTML document to be manipulated. Unlike jQuery, which
// is loaded as part of a DOM document, and thus acts upon its containing
// document, GoQuery doesn't know which HTML document to act upon. So it needs
// to be told, and that's what the Document class is for. It holds the root
// document node to manipulate, and can make selections on this document.
type Document struct {
*Selection
Url *url.URL
rootNode *html.Node
}
// NewDocumentFromNode is a Document constructor that takes a root html Node
// as argument.
func NewDocumentFromNode(root *html.Node) *Document {
return newDocument(root, nil)
}
// NewDocument is a Document constructor that takes a string URL as argument.
// It loads the specified document, parses it, and stores the root Document
// node, ready to be manipulated.
//
// Deprecated: Use the net/http standard library package to make the request
// and validate the response before calling goquery.NewDocumentFromReader
// with the response's body.
func NewDocument(url string) (*Document, error) {
// Load the URL
res, e := http.Get(url)
if e != nil {
return nil, e
}
return NewDocumentFromResponse(res)
}
// NewDocumentFromReader returns a Document from an io.Reader.
// It returns an error as second value if the reader's data cannot be parsed
// as html. It does not check if the reader is also an io.Closer, the
// provided reader is never closed by this call. It is the responsibility
// of the caller to close it if required.
func NewDocumentFromReader(r io.Reader) (*Document, error) {
root, e := html.Parse(r)
if e != nil {
return nil, e
}
return newDocument(root, nil), nil
}
// NewDocumentFromResponse is another Document constructor that takes an http response as argument.
// It loads the specified response's document, parses it, and stores the root Document
// node, ready to be manipulated. The response's body is closed on return.
//
// Deprecated: Use goquery.NewDocumentFromReader with the response's body.
func NewDocumentFromResponse(res *http.Response) (*Document, error) {
if res == nil {
return nil, errors.New("Response is nil")
}
defer res.Body.Close()
if res.Request == nil {
return nil, errors.New("Response.Request is nil")
}
// Parse the HTML into nodes
root, e := html.Parse(res.Body)
if e != nil {
return nil, e
}
// Create and fill the document
return newDocument(root, res.Request.URL), nil
}
// CloneDocument creates a deep-clone of a document.
func CloneDocument(doc *Document) *Document {
return newDocument(cloneNode(doc.rootNode), doc.Url)
}
// Private constructor, make sure all fields are correctly filled.
func newDocument(root *html.Node, url *url.URL) *Document {
// Create and fill the document
d := &Document{nil, url, root}
d.Selection = newSingleSelection(root, d)
return d
}
// Selection represents a collection of nodes matching some criteria. The
// initial Selection can be created by using Document.Find, and then
// manipulated using the jQuery-like chainable syntax and methods.
type Selection struct {
Nodes []*html.Node
document *Document
prevSel *Selection
}
// Helper constructor to create an empty selection
func newEmptySelection(doc *Document) *Selection {
return &Selection{nil, doc, nil}
}
// Helper constructor to create a selection of only one node
func newSingleSelection(node *html.Node, doc *Document) *Selection {
return &Selection{[]*html.Node{node}, doc, nil}
}
// Matcher is an interface that defines the methods to match
// HTML nodes against a compiled selector string. Cascadia's
// Selector implements this interface.
type Matcher interface {
Match(*html.Node) bool
MatchAll(*html.Node) []*html.Node
Filter([]*html.Node) []*html.Node
}
// compileMatcher compiles the selector string s and returns
// the corresponding Matcher. If s is an invalid selector string,
// it returns a Matcher that fails all matches.
func compileMatcher(s string) Matcher {
cs, err := cascadia.Compile(s)
if err != nil {
return invalidMatcher{}
}
return cs
}
// invalidMatcher is a Matcher that always fails to match.
type invalidMatcher struct{}
func (invalidMatcher) Match(n *html.Node) bool { return false }
func (invalidMatcher) MatchAll(n *html.Node) []*html.Node { return nil }
func (invalidMatcher) Filter(ns []*html.Node) []*html.Node { return nil }

161
vendor/github.com/PuerkitoBio/goquery/utilities.go generated vendored Normal file
View File

@@ -0,0 +1,161 @@
package goquery
import (
"bytes"
"golang.org/x/net/html"
)
// used to determine if a set (map[*html.Node]bool) should be used
// instead of iterating over a slice. The set uses more memory and
// is slower than slice iteration for small N.
const minNodesForSet = 1000
var nodeNames = []string{
html.ErrorNode: "#error",
html.TextNode: "#text",
html.DocumentNode: "#document",
html.CommentNode: "#comment",
}
// NodeName returns the node name of the first element in the selection.
// It tries to behave in a similar way as the DOM's nodeName property
// (https://developer.mozilla.org/en-US/docs/Web/API/Node/nodeName).
//
// Go's net/html package defines the following node types, listed with
// the corresponding returned value from this function:
//
// ErrorNode : #error
// TextNode : #text
// DocumentNode : #document
// ElementNode : the element's tag name
// CommentNode : #comment
// DoctypeNode : the name of the document type
//
func NodeName(s *Selection) string {
if s.Length() == 0 {
return ""
}
switch n := s.Get(0); n.Type {
case html.ElementNode, html.DoctypeNode:
return n.Data
default:
if n.Type >= 0 && int(n.Type) < len(nodeNames) {
return nodeNames[n.Type]
}
return ""
}
}
// OuterHtml returns the outer HTML rendering of the first item in
// the selection - that is, the HTML including the first element's
// tag and attributes.
//
// Unlike InnerHtml, this is a function and not a method on the Selection,
// because this is not a jQuery method (in javascript-land, this is
// a property provided by the DOM).
func OuterHtml(s *Selection) (string, error) {
var buf bytes.Buffer
if s.Length() == 0 {
return "", nil
}
n := s.Get(0)
if err := html.Render(&buf, n); err != nil {
return "", err
}
return buf.String(), nil
}
// Loop through all container nodes to search for the target node.
func sliceContains(container []*html.Node, contained *html.Node) bool {
for _, n := range container {
if nodeContains(n, contained) {
return true
}
}
return false
}
// Checks if the contained node is within the container node.
func nodeContains(container *html.Node, contained *html.Node) bool {
// Check if the parent of the contained node is the container node, traversing
// upward until the top is reached, or the container is found.
for contained = contained.Parent; contained != nil; contained = contained.Parent {
if container == contained {
return true
}
}
return false
}
// Checks if the target node is in the slice of nodes.
func isInSlice(slice []*html.Node, node *html.Node) bool {
return indexInSlice(slice, node) > -1
}
// Returns the index of the target node in the slice, or -1.
func indexInSlice(slice []*html.Node, node *html.Node) int {
if node != nil {
for i, n := range slice {
if n == node {
return i
}
}
}
return -1
}
// Appends the new nodes to the target slice, making sure no duplicate is added.
// There is no check to the original state of the target slice, so it may still
// contain duplicates. The target slice is returned because append() may create
// a new underlying array. If targetSet is nil, a local set is created with the
// target if len(target) + len(nodes) is greater than minNodesForSet.
func appendWithoutDuplicates(target []*html.Node, nodes []*html.Node, targetSet map[*html.Node]bool) []*html.Node {
// if there are not that many nodes, don't use the map, faster to just use nested loops
// (unless a non-nil targetSet is passed, in which case the caller knows better).
if targetSet == nil && len(target)+len(nodes) < minNodesForSet {
for _, n := range nodes {
if !isInSlice(target, n) {
target = append(target, n)
}
}
return target
}
// if a targetSet is passed, then assume it is reliable, otherwise create one
// and initialize it with the current target contents.
if targetSet == nil {
targetSet = make(map[*html.Node]bool, len(target))
for _, n := range target {
targetSet[n] = true
}
}
for _, n := range nodes {
if !targetSet[n] {
target = append(target, n)
targetSet[n] = true
}
}
return target
}
// Loop through a selection, returning only those nodes that pass the predicate
// function.
func grep(sel *Selection, predicate func(i int, s *Selection) bool) (result []*html.Node) {
for i, n := range sel.Nodes {
if predicate(i, newSingleSelection(n, sel.document)) {
result = append(result, n)
}
}
return result
}
// Creates a new Selection object based on the specified nodes, and keeps the
// source Selection object on the stack (linked list).
func pushStack(fromSel *Selection, nodes []*html.Node) *Selection {
result := &Selection{nodes, fromSel.document, fromSel}
return result
}

24
vendor/github.com/andybalholm/cascadia/LICENSE generated vendored Executable file
View File

@@ -0,0 +1,24 @@
Copyright (c) 2011 Andy Balholm. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

835
vendor/github.com/andybalholm/cascadia/parser.go generated vendored Normal file
View File

@@ -0,0 +1,835 @@
// Package cascadia is an implementation of CSS selectors.
package cascadia
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"golang.org/x/net/html"
)
// a parser for CSS selectors
type parser struct {
s string // the source text
i int // the current position
}
// parseEscape parses a backslash escape.
func (p *parser) parseEscape() (result string, err error) {
if len(p.s) < p.i+2 || p.s[p.i] != '\\' {
return "", errors.New("invalid escape sequence")
}
start := p.i + 1
c := p.s[start]
switch {
case c == '\r' || c == '\n' || c == '\f':
return "", errors.New("escaped line ending outside string")
case hexDigit(c):
// unicode escape (hex)
var i int
for i = start; i < p.i+6 && i < len(p.s) && hexDigit(p.s[i]); i++ {
// empty
}
v, _ := strconv.ParseUint(p.s[start:i], 16, 21)
if len(p.s) > i {
switch p.s[i] {
case '\r':
i++
if len(p.s) > i && p.s[i] == '\n' {
i++
}
case ' ', '\t', '\n', '\f':
i++
}
}
p.i = i
return string(rune(v)), nil
}
// Return the literal character after the backslash.
result = p.s[start : start+1]
p.i += 2
return result, nil
}
func hexDigit(c byte) bool {
return '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F'
}
// nameStart returns whether c can be the first character of an identifier
// (not counting an initial hyphen, or an escape sequence).
func nameStart(c byte) bool {
return 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_' || c > 127
}
// nameChar returns whether c can be a character within an identifier
// (not counting an escape sequence).
func nameChar(c byte) bool {
return 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_' || c > 127 ||
c == '-' || '0' <= c && c <= '9'
}
// parseIdentifier parses an identifier.
func (p *parser) parseIdentifier() (result string, err error) {
startingDash := false
if len(p.s) > p.i && p.s[p.i] == '-' {
startingDash = true
p.i++
}
if len(p.s) <= p.i {
return "", errors.New("expected identifier, found EOF instead")
}
if c := p.s[p.i]; !(nameStart(c) || c == '\\') {
return "", fmt.Errorf("expected identifier, found %c instead", c)
}
result, err = p.parseName()
if startingDash && err == nil {
result = "-" + result
}
return
}
// parseName parses a name (which is like an identifier, but doesn't have
// extra restrictions on the first character).
func (p *parser) parseName() (result string, err error) {
i := p.i
loop:
for i < len(p.s) {
c := p.s[i]
switch {
case nameChar(c):
start := i
for i < len(p.s) && nameChar(p.s[i]) {
i++
}
result += p.s[start:i]
case c == '\\':
p.i = i
val, err := p.parseEscape()
if err != nil {
return "", err
}
i = p.i
result += val
default:
break loop
}
}
if result == "" {
return "", errors.New("expected name, found EOF instead")
}
p.i = i
return result, nil
}
// parseString parses a single- or double-quoted string.
func (p *parser) parseString() (result string, err error) {
i := p.i
if len(p.s) < i+2 {
return "", errors.New("expected string, found EOF instead")
}
quote := p.s[i]
i++
loop:
for i < len(p.s) {
switch p.s[i] {
case '\\':
if len(p.s) > i+1 {
switch c := p.s[i+1]; c {
case '\r':
if len(p.s) > i+2 && p.s[i+2] == '\n' {
i += 3
continue loop
}
fallthrough
case '\n', '\f':
i += 2
continue loop
}
}
p.i = i
val, err := p.parseEscape()
if err != nil {
return "", err
}
i = p.i
result += val
case quote:
break loop
case '\r', '\n', '\f':
return "", errors.New("unexpected end of line in string")
default:
start := i
for i < len(p.s) {
if c := p.s[i]; c == quote || c == '\\' || c == '\r' || c == '\n' || c == '\f' {
break
}
i++
}
result += p.s[start:i]
}
}
if i >= len(p.s) {
return "", errors.New("EOF in string")
}
// Consume the final quote.
i++
p.i = i
return result, nil
}
// parseRegex parses a regular expression; the end is defined by encountering an
// unmatched closing ')' or ']' which is not consumed
func (p *parser) parseRegex() (rx *regexp.Regexp, err error) {
i := p.i
if len(p.s) < i+2 {
return nil, errors.New("expected regular expression, found EOF instead")
}
// number of open parens or brackets;
// when it becomes negative, finished parsing regex
open := 0
loop:
for i < len(p.s) {
switch p.s[i] {
case '(', '[':
open++
case ')', ']':
open--
if open < 0 {
break loop
}
}
i++
}
if i >= len(p.s) {
return nil, errors.New("EOF in regular expression")
}
rx, err = regexp.Compile(p.s[p.i:i])
p.i = i
return rx, err
}
// skipWhitespace consumes whitespace characters and comments.
// It returns true if there was actually anything to skip.
func (p *parser) skipWhitespace() bool {
i := p.i
for i < len(p.s) {
switch p.s[i] {
case ' ', '\t', '\r', '\n', '\f':
i++
continue
case '/':
if strings.HasPrefix(p.s[i:], "/*") {
end := strings.Index(p.s[i+len("/*"):], "*/")
if end != -1 {
i += end + len("/**/")
continue
}
}
}
break
}
if i > p.i {
p.i = i
return true
}
return false
}
// consumeParenthesis consumes an opening parenthesis and any following
// whitespace. It returns true if there was actually a parenthesis to skip.
func (p *parser) consumeParenthesis() bool {
if p.i < len(p.s) && p.s[p.i] == '(' {
p.i++
p.skipWhitespace()
return true
}
return false
}
// consumeClosingParenthesis consumes a closing parenthesis and any preceding
// whitespace. It returns true if there was actually a parenthesis to skip.
func (p *parser) consumeClosingParenthesis() bool {
i := p.i
p.skipWhitespace()
if p.i < len(p.s) && p.s[p.i] == ')' {
p.i++
return true
}
p.i = i
return false
}
// parseTypeSelector parses a type selector (one that matches by tag name).
func (p *parser) parseTypeSelector() (result Selector, err error) {
tag, err := p.parseIdentifier()
if err != nil {
return nil, err
}
return typeSelector(tag), nil
}
// parseIDSelector parses a selector that matches by id attribute.
func (p *parser) parseIDSelector() (Selector, error) {
if p.i >= len(p.s) {
return nil, fmt.Errorf("expected id selector (#id), found EOF instead")
}
if p.s[p.i] != '#' {
return nil, fmt.Errorf("expected id selector (#id), found '%c' instead", p.s[p.i])
}
p.i++
id, err := p.parseName()
if err != nil {
return nil, err
}
return attributeEqualsSelector("id", id), nil
}
// parseClassSelector parses a selector that matches by class attribute.
func (p *parser) parseClassSelector() (Selector, error) {
if p.i >= len(p.s) {
return nil, fmt.Errorf("expected class selector (.class), found EOF instead")
}
if p.s[p.i] != '.' {
return nil, fmt.Errorf("expected class selector (.class), found '%c' instead", p.s[p.i])
}
p.i++
class, err := p.parseIdentifier()
if err != nil {
return nil, err
}
return attributeIncludesSelector("class", class), nil
}
// parseAttributeSelector parses a selector that matches by attribute value.
func (p *parser) parseAttributeSelector() (Selector, error) {
if p.i >= len(p.s) {
return nil, fmt.Errorf("expected attribute selector ([attribute]), found EOF instead")
}
if p.s[p.i] != '[' {
return nil, fmt.Errorf("expected attribute selector ([attribute]), found '%c' instead", p.s[p.i])
}
p.i++
p.skipWhitespace()
key, err := p.parseIdentifier()
if err != nil {
return nil, err
}
p.skipWhitespace()
if p.i >= len(p.s) {
return nil, errors.New("unexpected EOF in attribute selector")
}
if p.s[p.i] == ']' {
p.i++
return attributeExistsSelector(key), nil
}
if p.i+2 >= len(p.s) {
return nil, errors.New("unexpected EOF in attribute selector")
}
op := p.s[p.i : p.i+2]
if op[0] == '=' {
op = "="
} else if op[1] != '=' {
return nil, fmt.Errorf(`expected equality operator, found "%s" instead`, op)
}
p.i += len(op)
p.skipWhitespace()
if p.i >= len(p.s) {
return nil, errors.New("unexpected EOF in attribute selector")
}
var val string
var rx *regexp.Regexp
if op == "#=" {
rx, err = p.parseRegex()
} else {
switch p.s[p.i] {
case '\'', '"':
val, err = p.parseString()
default:
val, err = p.parseIdentifier()
}
}
if err != nil {
return nil, err
}
p.skipWhitespace()
if p.i >= len(p.s) {
return nil, errors.New("unexpected EOF in attribute selector")
}
if p.s[p.i] != ']' {
return nil, fmt.Errorf("expected ']', found '%c' instead", p.s[p.i])
}
p.i++
switch op {
case "=":
return attributeEqualsSelector(key, val), nil
case "!=":
return attributeNotEqualSelector(key, val), nil
case "~=":
return attributeIncludesSelector(key, val), nil
case "|=":
return attributeDashmatchSelector(key, val), nil
case "^=":
return attributePrefixSelector(key, val), nil
case "$=":
return attributeSuffixSelector(key, val), nil
case "*=":
return attributeSubstringSelector(key, val), nil
case "#=":
return attributeRegexSelector(key, rx), nil
}
return nil, fmt.Errorf("attribute operator %q is not supported", op)
}
var errExpectedParenthesis = errors.New("expected '(' but didn't find it")
var errExpectedClosingParenthesis = errors.New("expected ')' but didn't find it")
var errUnmatchedParenthesis = errors.New("unmatched '('")
// parsePseudoclassSelector parses a pseudoclass selector like :not(p).
func (p *parser) parsePseudoclassSelector() (Selector, error) {
if p.i >= len(p.s) {
return nil, fmt.Errorf("expected pseudoclass selector (:pseudoclass), found EOF instead")
}
if p.s[p.i] != ':' {
return nil, fmt.Errorf("expected attribute selector (:pseudoclass), found '%c' instead", p.s[p.i])
}
p.i++
name, err := p.parseIdentifier()
if err != nil {
return nil, err
}
name = toLowerASCII(name)
switch name {
case "not", "has", "haschild":
if !p.consumeParenthesis() {
return nil, errExpectedParenthesis
}
sel, parseErr := p.parseSelectorGroup()
if parseErr != nil {
return nil, parseErr
}
if !p.consumeClosingParenthesis() {
return nil, errExpectedClosingParenthesis
}
switch name {
case "not":
return negatedSelector(sel), nil
case "has":
return hasDescendantSelector(sel), nil
case "haschild":
return hasChildSelector(sel), nil
}
case "contains", "containsown":
if !p.consumeParenthesis() {
return nil, errExpectedParenthesis
}
if p.i == len(p.s) {
return nil, errUnmatchedParenthesis
}
var val string
switch p.s[p.i] {
case '\'', '"':
val, err = p.parseString()
default:
val, err = p.parseIdentifier()
}
if err != nil {
return nil, err
}
val = strings.ToLower(val)
p.skipWhitespace()
if p.i >= len(p.s) {
return nil, errors.New("unexpected EOF in pseudo selector")
}
if !p.consumeClosingParenthesis() {
return nil, errExpectedClosingParenthesis
}
switch name {
case "contains":
return textSubstrSelector(val), nil
case "containsown":
return ownTextSubstrSelector(val), nil
}
case "matches", "matchesown":
if !p.consumeParenthesis() {
return nil, errExpectedParenthesis
}
rx, err := p.parseRegex()
if err != nil {
return nil, err
}
if p.i >= len(p.s) {
return nil, errors.New("unexpected EOF in pseudo selector")
}
if !p.consumeClosingParenthesis() {
return nil, errExpectedClosingParenthesis
}
switch name {
case "matches":
return textRegexSelector(rx), nil
case "matchesown":
return ownTextRegexSelector(rx), nil
}
case "nth-child", "nth-last-child", "nth-of-type", "nth-last-of-type":
if !p.consumeParenthesis() {
return nil, errExpectedParenthesis
}
a, b, err := p.parseNth()
if err != nil {
return nil, err
}
if !p.consumeClosingParenthesis() {
return nil, errExpectedClosingParenthesis
}
if a == 0 {
switch name {
case "nth-child":
return simpleNthChildSelector(b, false), nil
case "nth-of-type":
return simpleNthChildSelector(b, true), nil
case "nth-last-child":
return simpleNthLastChildSelector(b, false), nil
case "nth-last-of-type":
return simpleNthLastChildSelector(b, true), nil
}
}
return nthChildSelector(a, b,
name == "nth-last-child" || name == "nth-last-of-type",
name == "nth-of-type" || name == "nth-last-of-type"),
nil
case "first-child":
return simpleNthChildSelector(1, false), nil
case "last-child":
return simpleNthLastChildSelector(1, false), nil
case "first-of-type":
return simpleNthChildSelector(1, true), nil
case "last-of-type":
return simpleNthLastChildSelector(1, true), nil
case "only-child":
return onlyChildSelector(false), nil
case "only-of-type":
return onlyChildSelector(true), nil
case "input":
return inputSelector, nil
case "empty":
return emptyElementSelector, nil
case "root":
return rootSelector, nil
}
return nil, fmt.Errorf("unknown pseudoclass :%s", name)
}
// parseInteger parses a decimal integer.
func (p *parser) parseInteger() (int, error) {
i := p.i
start := i
for i < len(p.s) && '0' <= p.s[i] && p.s[i] <= '9' {
i++
}
if i == start {
return 0, errors.New("expected integer, but didn't find it")
}
p.i = i
val, err := strconv.Atoi(p.s[start:i])
if err != nil {
return 0, err
}
return val, nil
}
// parseNth parses the argument for :nth-child (normally of the form an+b).
func (p *parser) parseNth() (a, b int, err error) {
// initial state
if p.i >= len(p.s) {
goto eof
}
switch p.s[p.i] {
case '-':
p.i++
goto negativeA
case '+':
p.i++
goto positiveA
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
goto positiveA
case 'n', 'N':
a = 1
p.i++
goto readN
case 'o', 'O', 'e', 'E':
id, nameErr := p.parseName()
if nameErr != nil {
return 0, 0, nameErr
}
id = toLowerASCII(id)
if id == "odd" {
return 2, 1, nil
}
if id == "even" {
return 2, 0, nil
}
return 0, 0, fmt.Errorf("expected 'odd' or 'even', but found '%s' instead", id)
default:
goto invalid
}
positiveA:
if p.i >= len(p.s) {
goto eof
}
switch p.s[p.i] {
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
a, err = p.parseInteger()
if err != nil {
return 0, 0, err
}
goto readA
case 'n', 'N':
a = 1
p.i++
goto readN
default:
goto invalid
}
negativeA:
if p.i >= len(p.s) {
goto eof
}
switch p.s[p.i] {
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
a, err = p.parseInteger()
if err != nil {
return 0, 0, err
}
a = -a
goto readA
case 'n', 'N':
a = -1
p.i++
goto readN
default:
goto invalid
}
readA:
if p.i >= len(p.s) {
goto eof
}
switch p.s[p.i] {
case 'n', 'N':
p.i++
goto readN
default:
// The number we read as a is actually b.
return 0, a, nil
}
readN:
p.skipWhitespace()
if p.i >= len(p.s) {
goto eof
}
switch p.s[p.i] {
case '+':
p.i++
p.skipWhitespace()
b, err = p.parseInteger()
if err != nil {
return 0, 0, err
}
return a, b, nil
case '-':
p.i++
p.skipWhitespace()
b, err = p.parseInteger()
if err != nil {
return 0, 0, err
}
return a, -b, nil
default:
return a, 0, nil
}
eof:
return 0, 0, errors.New("unexpected EOF while attempting to parse expression of form an+b")
invalid:
return 0, 0, errors.New("unexpected character while attempting to parse expression of form an+b")
}
// parseSimpleSelectorSequence parses a selector sequence that applies to
// a single element.
func (p *parser) parseSimpleSelectorSequence() (Selector, error) {
var result Selector
if p.i >= len(p.s) {
return nil, errors.New("expected selector, found EOF instead")
}
switch p.s[p.i] {
case '*':
// It's the universal selector. Just skip over it, since it doesn't affect the meaning.
p.i++
case '#', '.', '[', ':':
// There's no type selector. Wait to process the other till the main loop.
default:
r, err := p.parseTypeSelector()
if err != nil {
return nil, err
}
result = r
}
loop:
for p.i < len(p.s) {
var ns Selector
var err error
switch p.s[p.i] {
case '#':
ns, err = p.parseIDSelector()
case '.':
ns, err = p.parseClassSelector()
case '[':
ns, err = p.parseAttributeSelector()
case ':':
ns, err = p.parsePseudoclassSelector()
default:
break loop
}
if err != nil {
return nil, err
}
if result == nil {
result = ns
} else {
result = intersectionSelector(result, ns)
}
}
if result == nil {
result = func(n *html.Node) bool {
return n.Type == html.ElementNode
}
}
return result, nil
}
// parseSelector parses a selector that may include combinators.
func (p *parser) parseSelector() (result Selector, err error) {
p.skipWhitespace()
result, err = p.parseSimpleSelectorSequence()
if err != nil {
return
}
for {
var combinator byte
if p.skipWhitespace() {
combinator = ' '
}
if p.i >= len(p.s) {
return
}
switch p.s[p.i] {
case '+', '>', '~':
combinator = p.s[p.i]
p.i++
p.skipWhitespace()
case ',', ')':
// These characters can't begin a selector, but they can legally occur after one.
return
}
if combinator == 0 {
return
}
c, err := p.parseSimpleSelectorSequence()
if err != nil {
return nil, err
}
switch combinator {
case ' ':
result = descendantSelector(result, c)
case '>':
result = childSelector(result, c)
case '+':
result = siblingSelector(result, c, true)
case '~':
result = siblingSelector(result, c, false)
}
}
panic("unreachable")
}
// parseSelectorGroup parses a group of selectors, separated by commas.
func (p *parser) parseSelectorGroup() (result Selector, err error) {
result, err = p.parseSelector()
if err != nil {
return
}
for p.i < len(p.s) {
if p.s[p.i] != ',' {
return result, nil
}
p.i++
c, err := p.parseSelector()
if err != nil {
return nil, err
}
result = unionSelector(result, c)
}
return
}

622
vendor/github.com/andybalholm/cascadia/selector.go generated vendored Normal file
View File

@@ -0,0 +1,622 @@
package cascadia
import (
"bytes"
"fmt"
"regexp"
"strings"
"golang.org/x/net/html"
)
// the Selector type, and functions for creating them
// A Selector is a function which tells whether a node matches or not.
type Selector func(*html.Node) bool
// hasChildMatch returns whether n has any child that matches a.
func hasChildMatch(n *html.Node, a Selector) bool {
for c := n.FirstChild; c != nil; c = c.NextSibling {
if a(c) {
return true
}
}
return false
}
// hasDescendantMatch performs a depth-first search of n's descendants,
// testing whether any of them match a. It returns true as soon as a match is
// found, or false if no match is found.
func hasDescendantMatch(n *html.Node, a Selector) bool {
for c := n.FirstChild; c != nil; c = c.NextSibling {
if a(c) || (c.Type == html.ElementNode && hasDescendantMatch(c, a)) {
return true
}
}
return false
}
// Compile parses a selector and returns, if successful, a Selector object
// that can be used to match against html.Node objects.
func Compile(sel string) (Selector, error) {
p := &parser{s: sel}
compiled, err := p.parseSelectorGroup()
if err != nil {
return nil, err
}
if p.i < len(sel) {
return nil, fmt.Errorf("parsing %q: %d bytes left over", sel, len(sel)-p.i)
}
return compiled, nil
}
// MustCompile is like Compile, but panics instead of returning an error.
func MustCompile(sel string) Selector {
compiled, err := Compile(sel)
if err != nil {
panic(err)
}
return compiled
}
// MatchAll returns a slice of the nodes that match the selector,
// from n and its children.
func (s Selector) MatchAll(n *html.Node) []*html.Node {
return s.matchAllInto(n, nil)
}
func (s Selector) matchAllInto(n *html.Node, storage []*html.Node) []*html.Node {
if s(n) {
storage = append(storage, n)
}
for child := n.FirstChild; child != nil; child = child.NextSibling {
storage = s.matchAllInto(child, storage)
}
return storage
}
// Match returns true if the node matches the selector.
func (s Selector) Match(n *html.Node) bool {
return s(n)
}
// MatchFirst returns the first node that matches s, from n and its children.
func (s Selector) MatchFirst(n *html.Node) *html.Node {
if s.Match(n) {
return n
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
m := s.MatchFirst(c)
if m != nil {
return m
}
}
return nil
}
// Filter returns the nodes in nodes that match the selector.
func (s Selector) Filter(nodes []*html.Node) (result []*html.Node) {
for _, n := range nodes {
if s(n) {
result = append(result, n)
}
}
return result
}
// typeSelector returns a Selector that matches elements with a given tag name.
func typeSelector(tag string) Selector {
tag = toLowerASCII(tag)
return func(n *html.Node) bool {
return n.Type == html.ElementNode && n.Data == tag
}
}
// toLowerASCII returns s with all ASCII capital letters lowercased.
func toLowerASCII(s string) string {
var b []byte
for i := 0; i < len(s); i++ {
if c := s[i]; 'A' <= c && c <= 'Z' {
if b == nil {
b = make([]byte, len(s))
copy(b, s)
}
b[i] = s[i] + ('a' - 'A')
}
}
if b == nil {
return s
}
return string(b)
}
// attributeSelector returns a Selector that matches elements
// where the attribute named key satisifes the function f.
func attributeSelector(key string, f func(string) bool) Selector {
key = toLowerASCII(key)
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
for _, a := range n.Attr {
if a.Key == key && f(a.Val) {
return true
}
}
return false
}
}
// attributeExistsSelector returns a Selector that matches elements that have
// an attribute named key.
func attributeExistsSelector(key string) Selector {
return attributeSelector(key, func(string) bool { return true })
}
// attributeEqualsSelector returns a Selector that matches elements where
// the attribute named key has the value val.
func attributeEqualsSelector(key, val string) Selector {
return attributeSelector(key,
func(s string) bool {
return s == val
})
}
// attributeNotEqualSelector returns a Selector that matches elements where
// the attribute named key does not have the value val.
func attributeNotEqualSelector(key, val string) Selector {
key = toLowerASCII(key)
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
for _, a := range n.Attr {
if a.Key == key && a.Val == val {
return false
}
}
return true
}
}
// attributeIncludesSelector returns a Selector that matches elements where
// the attribute named key is a whitespace-separated list that includes val.
func attributeIncludesSelector(key, val string) Selector {
return attributeSelector(key,
func(s string) bool {
for s != "" {
i := strings.IndexAny(s, " \t\r\n\f")
if i == -1 {
return s == val
}
if s[:i] == val {
return true
}
s = s[i+1:]
}
return false
})
}
// attributeDashmatchSelector returns a Selector that matches elements where
// the attribute named key equals val or starts with val plus a hyphen.
func attributeDashmatchSelector(key, val string) Selector {
return attributeSelector(key,
func(s string) bool {
if s == val {
return true
}
if len(s) <= len(val) {
return false
}
if s[:len(val)] == val && s[len(val)] == '-' {
return true
}
return false
})
}
// attributePrefixSelector returns a Selector that matches elements where
// the attribute named key starts with val.
func attributePrefixSelector(key, val string) Selector {
return attributeSelector(key,
func(s string) bool {
if strings.TrimSpace(s) == "" {
return false
}
return strings.HasPrefix(s, val)
})
}
// attributeSuffixSelector returns a Selector that matches elements where
// the attribute named key ends with val.
func attributeSuffixSelector(key, val string) Selector {
return attributeSelector(key,
func(s string) bool {
if strings.TrimSpace(s) == "" {
return false
}
return strings.HasSuffix(s, val)
})
}
// attributeSubstringSelector returns a Selector that matches nodes where
// the attribute named key contains val.
func attributeSubstringSelector(key, val string) Selector {
return attributeSelector(key,
func(s string) bool {
if strings.TrimSpace(s) == "" {
return false
}
return strings.Contains(s, val)
})
}
// attributeRegexSelector returns a Selector that matches nodes where
// the attribute named key matches the regular expression rx
func attributeRegexSelector(key string, rx *regexp.Regexp) Selector {
return attributeSelector(key,
func(s string) bool {
return rx.MatchString(s)
})
}
// intersectionSelector returns a selector that matches nodes that match
// both a and b.
func intersectionSelector(a, b Selector) Selector {
return func(n *html.Node) bool {
return a(n) && b(n)
}
}
// unionSelector returns a selector that matches elements that match
// either a or b.
func unionSelector(a, b Selector) Selector {
return func(n *html.Node) bool {
return a(n) || b(n)
}
}
// negatedSelector returns a selector that matches elements that do not match a.
func negatedSelector(a Selector) Selector {
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
return !a(n)
}
}
// writeNodeText writes the text contained in n and its descendants to b.
func writeNodeText(n *html.Node, b *bytes.Buffer) {
switch n.Type {
case html.TextNode:
b.WriteString(n.Data)
case html.ElementNode:
for c := n.FirstChild; c != nil; c = c.NextSibling {
writeNodeText(c, b)
}
}
}
// nodeText returns the text contained in n and its descendants.
func nodeText(n *html.Node) string {
var b bytes.Buffer
writeNodeText(n, &b)
return b.String()
}
// nodeOwnText returns the contents of the text nodes that are direct
// children of n.
func nodeOwnText(n *html.Node) string {
var b bytes.Buffer
for c := n.FirstChild; c != nil; c = c.NextSibling {
if c.Type == html.TextNode {
b.WriteString(c.Data)
}
}
return b.String()
}
// textSubstrSelector returns a selector that matches nodes that
// contain the given text.
func textSubstrSelector(val string) Selector {
return func(n *html.Node) bool {
text := strings.ToLower(nodeText(n))
return strings.Contains(text, val)
}
}
// ownTextSubstrSelector returns a selector that matches nodes that
// directly contain the given text
func ownTextSubstrSelector(val string) Selector {
return func(n *html.Node) bool {
text := strings.ToLower(nodeOwnText(n))
return strings.Contains(text, val)
}
}
// textRegexSelector returns a selector that matches nodes whose text matches
// the specified regular expression
func textRegexSelector(rx *regexp.Regexp) Selector {
return func(n *html.Node) bool {
return rx.MatchString(nodeText(n))
}
}
// ownTextRegexSelector returns a selector that matches nodes whose text
// directly matches the specified regular expression
func ownTextRegexSelector(rx *regexp.Regexp) Selector {
return func(n *html.Node) bool {
return rx.MatchString(nodeOwnText(n))
}
}
// hasChildSelector returns a selector that matches elements
// with a child that matches a.
func hasChildSelector(a Selector) Selector {
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
return hasChildMatch(n, a)
}
}
// hasDescendantSelector returns a selector that matches elements
// with any descendant that matches a.
func hasDescendantSelector(a Selector) Selector {
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
return hasDescendantMatch(n, a)
}
}
// nthChildSelector returns a selector that implements :nth-child(an+b).
// If last is true, implements :nth-last-child instead.
// If ofType is true, implements :nth-of-type instead.
func nthChildSelector(a, b int, last, ofType bool) Selector {
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
parent := n.Parent
if parent == nil {
return false
}
if parent.Type == html.DocumentNode {
return false
}
i := -1
count := 0
for c := parent.FirstChild; c != nil; c = c.NextSibling {
if (c.Type != html.ElementNode) || (ofType && c.Data != n.Data) {
continue
}
count++
if c == n {
i = count
if !last {
break
}
}
}
if i == -1 {
// This shouldn't happen, since n should always be one of its parent's children.
return false
}
if last {
i = count - i + 1
}
i -= b
if a == 0 {
return i == 0
}
return i%a == 0 && i/a >= 0
}
}
// simpleNthChildSelector returns a selector that implements :nth-child(b).
// If ofType is true, implements :nth-of-type instead.
func simpleNthChildSelector(b int, ofType bool) Selector {
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
parent := n.Parent
if parent == nil {
return false
}
if parent.Type == html.DocumentNode {
return false
}
count := 0
for c := parent.FirstChild; c != nil; c = c.NextSibling {
if c.Type != html.ElementNode || (ofType && c.Data != n.Data) {
continue
}
count++
if c == n {
return count == b
}
if count >= b {
return false
}
}
return false
}
}
// simpleNthLastChildSelector returns a selector that implements
// :nth-last-child(b). If ofType is true, implements :nth-last-of-type
// instead.
func simpleNthLastChildSelector(b int, ofType bool) Selector {
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
parent := n.Parent
if parent == nil {
return false
}
if parent.Type == html.DocumentNode {
return false
}
count := 0
for c := parent.LastChild; c != nil; c = c.PrevSibling {
if c.Type != html.ElementNode || (ofType && c.Data != n.Data) {
continue
}
count++
if c == n {
return count == b
}
if count >= b {
return false
}
}
return false
}
}
// onlyChildSelector returns a selector that implements :only-child.
// If ofType is true, it implements :only-of-type instead.
func onlyChildSelector(ofType bool) Selector {
return func(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
parent := n.Parent
if parent == nil {
return false
}
if parent.Type == html.DocumentNode {
return false
}
count := 0
for c := parent.FirstChild; c != nil; c = c.NextSibling {
if (c.Type != html.ElementNode) || (ofType && c.Data != n.Data) {
continue
}
count++
if count > 1 {
return false
}
}
return count == 1
}
}
// inputSelector is a Selector that matches input, select, textarea and button elements.
func inputSelector(n *html.Node) bool {
return n.Type == html.ElementNode && (n.Data == "input" || n.Data == "select" || n.Data == "textarea" || n.Data == "button")
}
// emptyElementSelector is a Selector that matches empty elements.
func emptyElementSelector(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
switch c.Type {
case html.ElementNode, html.TextNode:
return false
}
}
return true
}
// descendantSelector returns a Selector that matches an element if
// it matches d and has an ancestor that matches a.
func descendantSelector(a, d Selector) Selector {
return func(n *html.Node) bool {
if !d(n) {
return false
}
for p := n.Parent; p != nil; p = p.Parent {
if a(p) {
return true
}
}
return false
}
}
// childSelector returns a Selector that matches an element if
// it matches d and its parent matches a.
func childSelector(a, d Selector) Selector {
return func(n *html.Node) bool {
return d(n) && n.Parent != nil && a(n.Parent)
}
}
// siblingSelector returns a Selector that matches an element
// if it matches s2 and in is preceded by an element that matches s1.
// If adjacent is true, the sibling must be immediately before the element.
func siblingSelector(s1, s2 Selector, adjacent bool) Selector {
return func(n *html.Node) bool {
if !s2(n) {
return false
}
if adjacent {
for n = n.PrevSibling; n != nil; n = n.PrevSibling {
if n.Type == html.TextNode || n.Type == html.CommentNode {
continue
}
return s1(n)
}
return false
}
// Walk backwards looking for element that matches s1
for c := n.PrevSibling; c != nil; c = c.PrevSibling {
if s1(c) {
return true
}
}
return false
}
}
// rootSelector implements :root
func rootSelector(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
if n.Parent == nil {
return false
}
return n.Parent.Type == html.DocumentNode
}

10
vendor/github.com/beevik/etree/CONTRIBUTORS generated vendored Normal file
View File

@@ -0,0 +1,10 @@
Brett Vickers (beevik)
Felix Geisendörfer (felixge)
Kamil Kisiel (kisielk)
Graham King (grahamking)
Matt Smith (ma314smith)
Michal Jemala (michaljemala)
Nicolas Piganeau (npiganeau)
Chris Brown (ccbrown)
Earncef Sequeira (earncef)
Gabriel de Labachelerie (wuzuf)

24
vendor/github.com/beevik/etree/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,24 @@
Copyright 2015-2019 Brett Vickers. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDER ``AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT HOLDER OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

1453
vendor/github.com/beevik/etree/etree.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

276
vendor/github.com/beevik/etree/helpers.go generated vendored Normal file
View File

@@ -0,0 +1,276 @@
// Copyright 2015-2019 Brett Vickers.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package etree
import (
"bufio"
"io"
"strings"
"unicode/utf8"
)
// A simple stack
type stack struct {
data []interface{}
}
func (s *stack) empty() bool {
return len(s.data) == 0
}
func (s *stack) push(value interface{}) {
s.data = append(s.data, value)
}
func (s *stack) pop() interface{} {
value := s.data[len(s.data)-1]
s.data[len(s.data)-1] = nil
s.data = s.data[:len(s.data)-1]
return value
}
func (s *stack) peek() interface{} {
return s.data[len(s.data)-1]
}
// A fifo is a simple first-in-first-out queue.
type fifo struct {
data []interface{}
head, tail int
}
func (f *fifo) add(value interface{}) {
if f.len()+1 >= len(f.data) {
f.grow()
}
f.data[f.tail] = value
if f.tail++; f.tail == len(f.data) {
f.tail = 0
}
}
func (f *fifo) remove() interface{} {
value := f.data[f.head]
f.data[f.head] = nil
if f.head++; f.head == len(f.data) {
f.head = 0
}
return value
}
func (f *fifo) len() int {
if f.tail >= f.head {
return f.tail - f.head
}
return len(f.data) - f.head + f.tail
}
func (f *fifo) grow() {
c := len(f.data) * 2
if c == 0 {
c = 4
}
buf, count := make([]interface{}, c), f.len()
if f.tail >= f.head {
copy(buf[0:count], f.data[f.head:f.tail])
} else {
hindex := len(f.data) - f.head
copy(buf[0:hindex], f.data[f.head:])
copy(buf[hindex:count], f.data[:f.tail])
}
f.data, f.head, f.tail = buf, 0, count
}
// countReader implements a proxy reader that counts the number of
// bytes read from its encapsulated reader.
type countReader struct {
r io.Reader
bytes int64
}
func newCountReader(r io.Reader) *countReader {
return &countReader{r: r}
}
func (cr *countReader) Read(p []byte) (n int, err error) {
b, err := cr.r.Read(p)
cr.bytes += int64(b)
return b, err
}
// countWriter implements a proxy writer that counts the number of
// bytes written by its encapsulated writer.
type countWriter struct {
w io.Writer
bytes int64
}
func newCountWriter(w io.Writer) *countWriter {
return &countWriter{w: w}
}
func (cw *countWriter) Write(p []byte) (n int, err error) {
b, err := cw.w.Write(p)
cw.bytes += int64(b)
return b, err
}
// isWhitespace returns true if the byte slice contains only
// whitespace characters.
func isWhitespace(s string) bool {
for i := 0; i < len(s); i++ {
if c := s[i]; c != ' ' && c != '\t' && c != '\n' && c != '\r' {
return false
}
}
return true
}
// spaceMatch returns true if namespace a is the empty string
// or if namespace a equals namespace b.
func spaceMatch(a, b string) bool {
switch {
case a == "":
return true
default:
return a == b
}
}
// spaceDecompose breaks a namespace:tag identifier at the ':'
// and returns the two parts.
func spaceDecompose(str string) (space, key string) {
colon := strings.IndexByte(str, ':')
if colon == -1 {
return "", str
}
return str[:colon], str[colon+1:]
}
// Strings used by indentCRLF and indentLF
const (
indentSpaces = "\r\n "
indentTabs = "\r\n\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"
)
// indentCRLF returns a CRLF newline followed by n copies of the first
// non-CRLF character in the source string.
func indentCRLF(n int, source string) string {
switch {
case n < 0:
return source[:2]
case n < len(source)-1:
return source[:n+2]
default:
return source + strings.Repeat(source[2:3], n-len(source)+2)
}
}
// indentLF returns a LF newline followed by n copies of the first non-LF
// character in the source string.
func indentLF(n int, source string) string {
switch {
case n < 0:
return source[1:2]
case n < len(source)-1:
return source[1 : n+2]
default:
return source[1:] + strings.Repeat(source[2:3], n-len(source)+2)
}
}
// nextIndex returns the index of the next occurrence of sep in s,
// starting from offset. It returns -1 if the sep string is not found.
func nextIndex(s, sep string, offset int) int {
switch i := strings.Index(s[offset:], sep); i {
case -1:
return -1
default:
return offset + i
}
}
// isInteger returns true if the string s contains an integer.
func isInteger(s string) bool {
for i := 0; i < len(s); i++ {
if (s[i] < '0' || s[i] > '9') && !(i == 0 && s[i] == '-') {
return false
}
}
return true
}
type escapeMode byte
const (
escapeNormal escapeMode = iota
escapeCanonicalText
escapeCanonicalAttr
)
// escapeString writes an escaped version of a string to the writer.
func escapeString(w *bufio.Writer, s string, m escapeMode) {
var esc []byte
last := 0
for i := 0; i < len(s); {
r, width := utf8.DecodeRuneInString(s[i:])
i += width
switch r {
case '&':
esc = []byte("&amp;")
case '<':
esc = []byte("&lt;")
case '>':
if m == escapeCanonicalAttr {
continue
}
esc = []byte("&gt;")
case '\'':
if m != escapeNormal {
continue
}
esc = []byte("&apos;")
case '"':
if m == escapeCanonicalText {
continue
}
esc = []byte("&quot;")
case '\t':
if m != escapeCanonicalAttr {
continue
}
esc = []byte("&#x9;")
case '\n':
if m != escapeCanonicalAttr {
continue
}
esc = []byte("&#xA;")
case '\r':
if m == escapeNormal {
continue
}
esc = []byte("&#xD;")
default:
if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) {
esc = []byte("\uFFFD")
break
}
continue
}
w.WriteString(s[last : i-width])
w.Write(esc)
last = i
}
w.WriteString(s[last:])
}
func isInCharacterRange(r rune) bool {
return r == 0x09 ||
r == 0x0A ||
r == 0x0D ||
r >= 0x20 && r <= 0xD7FF ||
r >= 0xE000 && r <= 0xFFFD ||
r >= 0x10000 && r <= 0x10FFFF
}

582
vendor/github.com/beevik/etree/path.go generated vendored Normal file
View File

@@ -0,0 +1,582 @@
// Copyright 2015-2019 Brett Vickers.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package etree
import (
"strconv"
"strings"
)
/*
A Path is a string that represents a search path through an etree starting
from the document root or an arbitrary element. Paths are used with the
Element object's Find* methods to locate and return desired elements.
A Path consists of a series of slash-separated "selectors", each of which may
be modified by one or more bracket-enclosed "filters". Selectors are used to
traverse the etree from element to element, while filters are used to narrow
the list of candidate elements at each node.
Although etree Path strings are similar to XPath strings
(https://www.w3.org/TR/1999/REC-xpath-19991116/), they have a more limited set
of selectors and filtering options.
The following selectors are supported by etree Path strings:
. Select the current element.
.. Select the parent of the current element.
* Select all child elements of the current element.
/ Select the root element when used at the start of a path.
// Select all descendants of the current element.
tag Select all child elements with a name matching the tag.
The following basic filters are supported by etree Path strings:
[@attrib] Keep elements with an attribute named attrib.
[@attrib='val'] Keep elements with an attribute named attrib and value matching val.
[tag] Keep elements with a child element named tag.
[tag='val'] Keep elements with a child element named tag and text matching val.
[n] Keep the n-th element, where n is a numeric index starting from 1.
The following function filters are also supported:
[text()] Keep elements with non-empty text.
[text()='val'] Keep elements whose text matches val.
[local-name()='val'] Keep elements whose un-prefixed tag matches val.
[name()='val'] Keep elements whose full tag exactly matches val.
[namespace-prefix()='val'] Keep elements whose namespace prefix matches val.
[namespace-uri()='val'] Keep elements whose namespace URI matches val.
Here are some examples of Path strings:
- Select the bookstore child element of the root element:
/bookstore
- Beginning from the root element, select the title elements of all
descendant book elements having a 'category' attribute of 'WEB':
//book[@category='WEB']/title
- Beginning from the current element, select the first descendant
book element with a title child element containing the text 'Great
Expectations':
.//book[title='Great Expectations'][1]
- Beginning from the current element, select all child elements of
book elements with an attribute 'language' set to 'english':
./book/*[@language='english']
- Beginning from the current element, select all child elements of
book elements containing the text 'special':
./book/*[text()='special']
- Beginning from the current element, select all descendant book
elements whose title child element has a 'language' attribute of 'french':
.//book/title[@language='french']/..
- Beginning from the current element, select all book elements
belonging to the http://www.w3.org/TR/html4/ namespace:
.//book[namespace-uri()='http://www.w3.org/TR/html4/']
*/
type Path struct {
segments []segment
}
// ErrPath is returned by path functions when an invalid etree path is provided.
type ErrPath string
// Error returns the string describing a path error.
func (err ErrPath) Error() string {
return "etree: " + string(err)
}
// CompilePath creates an optimized version of an XPath-like string that
// can be used to query elements in an element tree.
func CompilePath(path string) (Path, error) {
var comp compiler
segments := comp.parsePath(path)
if comp.err != ErrPath("") {
return Path{nil}, comp.err
}
return Path{segments}, nil
}
// MustCompilePath creates an optimized version of an XPath-like string that
// can be used to query elements in an element tree. Panics if an error
// occurs. Use this function to create Paths when you know the path is
// valid (i.e., if it's hard-coded).
func MustCompilePath(path string) Path {
p, err := CompilePath(path)
if err != nil {
panic(err)
}
return p
}
// A segment is a portion of a path between "/" characters.
// It contains one selector and zero or more [filters].
type segment struct {
sel selector
filters []filter
}
func (seg *segment) apply(e *Element, p *pather) {
seg.sel.apply(e, p)
for _, f := range seg.filters {
f.apply(p)
}
}
// A selector selects XML elements for consideration by the
// path traversal.
type selector interface {
apply(e *Element, p *pather)
}
// A filter pares down a list of candidate XML elements based
// on a path filter in [brackets].
type filter interface {
apply(p *pather)
}
// A pather is helper object that traverses an element tree using
// a Path object. It collects and deduplicates all elements matching
// the path query.
type pather struct {
queue fifo
results []*Element
inResults map[*Element]bool
candidates []*Element
scratch []*Element // used by filters
}
// A node represents an element and the remaining path segments that
// should be applied against it by the pather.
type node struct {
e *Element
segments []segment
}
func newPather() *pather {
return &pather{
results: make([]*Element, 0),
inResults: make(map[*Element]bool),
candidates: make([]*Element, 0),
scratch: make([]*Element, 0),
}
}
// traverse follows the path from the element e, collecting
// and then returning all elements that match the path's selectors
// and filters.
func (p *pather) traverse(e *Element, path Path) []*Element {
for p.queue.add(node{e, path.segments}); p.queue.len() > 0; {
p.eval(p.queue.remove().(node))
}
return p.results
}
// eval evalutes the current path node by applying the remaining
// path's selector rules against the node's element.
func (p *pather) eval(n node) {
p.candidates = p.candidates[0:0]
seg, remain := n.segments[0], n.segments[1:]
seg.apply(n.e, p)
if len(remain) == 0 {
for _, c := range p.candidates {
if in := p.inResults[c]; !in {
p.inResults[c] = true
p.results = append(p.results, c)
}
}
} else {
for _, c := range p.candidates {
p.queue.add(node{c, remain})
}
}
}
// A compiler generates a compiled path from a path string.
type compiler struct {
err ErrPath
}
// parsePath parses an XPath-like string describing a path
// through an element tree and returns a slice of segment
// descriptors.
func (c *compiler) parsePath(path string) []segment {
// If path ends with //, fix it
if strings.HasSuffix(path, "//") {
path = path + "*"
}
var segments []segment
// Check for an absolute path
if strings.HasPrefix(path, "/") {
segments = append(segments, segment{new(selectRoot), []filter{}})
path = path[1:]
}
// Split path into segments
for _, s := range splitPath(path) {
segments = append(segments, c.parseSegment(s))
if c.err != ErrPath("") {
break
}
}
return segments
}
func splitPath(path string) []string {
pieces := make([]string, 0)
start := 0
inquote := false
for i := 0; i+1 <= len(path); i++ {
if path[i] == '\'' {
inquote = !inquote
} else if path[i] == '/' && !inquote {
pieces = append(pieces, path[start:i])
start = i + 1
}
}
return append(pieces, path[start:])
}
// parseSegment parses a path segment between / characters.
func (c *compiler) parseSegment(path string) segment {
pieces := strings.Split(path, "[")
seg := segment{
sel: c.parseSelector(pieces[0]),
filters: []filter{},
}
for i := 1; i < len(pieces); i++ {
fpath := pieces[i]
if fpath[len(fpath)-1] != ']' {
c.err = ErrPath("path has invalid filter [brackets].")
break
}
seg.filters = append(seg.filters, c.parseFilter(fpath[:len(fpath)-1]))
}
return seg
}
// parseSelector parses a selector at the start of a path segment.
func (c *compiler) parseSelector(path string) selector {
switch path {
case ".":
return new(selectSelf)
case "..":
return new(selectParent)
case "*":
return new(selectChildren)
case "":
return new(selectDescendants)
default:
return newSelectChildrenByTag(path)
}
}
var fnTable = map[string]struct {
hasFn func(e *Element) bool
getValFn func(e *Element) string
}{
"local-name": {nil, (*Element).name},
"name": {nil, (*Element).FullTag},
"namespace-prefix": {nil, (*Element).namespacePrefix},
"namespace-uri": {nil, (*Element).NamespaceURI},
"text": {(*Element).hasText, (*Element).Text},
}
// parseFilter parses a path filter contained within [brackets].
func (c *compiler) parseFilter(path string) filter {
if len(path) == 0 {
c.err = ErrPath("path contains an empty filter expression.")
return nil
}
// Filter contains [@attr='val'], [fn()='val'], or [tag='val']?
eqindex := strings.Index(path, "='")
if eqindex >= 0 {
rindex := nextIndex(path, "'", eqindex+2)
if rindex != len(path)-1 {
c.err = ErrPath("path has mismatched filter quotes.")
return nil
}
key := path[:eqindex]
value := path[eqindex+2 : rindex]
switch {
case key[0] == '@':
return newFilterAttrVal(key[1:], value)
case strings.HasSuffix(key, "()"):
fn := key[:len(key)-2]
if t, ok := fnTable[fn]; ok && t.getValFn != nil {
return newFilterFuncVal(t.getValFn, value)
}
c.err = ErrPath("path has unknown function " + fn)
return nil
default:
return newFilterChildText(key, value)
}
}
// Filter contains [@attr], [N], [tag] or [fn()]
switch {
case path[0] == '@':
return newFilterAttr(path[1:])
case strings.HasSuffix(path, "()"):
fn := path[:len(path)-2]
if t, ok := fnTable[fn]; ok && t.hasFn != nil {
return newFilterFunc(t.hasFn)
}
c.err = ErrPath("path has unknown function " + fn)
return nil
case isInteger(path):
pos, _ := strconv.Atoi(path)
switch {
case pos > 0:
return newFilterPos(pos - 1)
default:
return newFilterPos(pos)
}
default:
return newFilterChild(path)
}
}
// selectSelf selects the current element into the candidate list.
type selectSelf struct{}
func (s *selectSelf) apply(e *Element, p *pather) {
p.candidates = append(p.candidates, e)
}
// selectRoot selects the element's root node.
type selectRoot struct{}
func (s *selectRoot) apply(e *Element, p *pather) {
root := e
for root.parent != nil {
root = root.parent
}
p.candidates = append(p.candidates, root)
}
// selectParent selects the element's parent into the candidate list.
type selectParent struct{}
func (s *selectParent) apply(e *Element, p *pather) {
if e.parent != nil {
p.candidates = append(p.candidates, e.parent)
}
}
// selectChildren selects the element's child elements into the
// candidate list.
type selectChildren struct{}
func (s *selectChildren) apply(e *Element, p *pather) {
for _, c := range e.Child {
if c, ok := c.(*Element); ok {
p.candidates = append(p.candidates, c)
}
}
}
// selectDescendants selects all descendant child elements
// of the element into the candidate list.
type selectDescendants struct{}
func (s *selectDescendants) apply(e *Element, p *pather) {
var queue fifo
for queue.add(e); queue.len() > 0; {
e := queue.remove().(*Element)
p.candidates = append(p.candidates, e)
for _, c := range e.Child {
if c, ok := c.(*Element); ok {
queue.add(c)
}
}
}
}
// selectChildrenByTag selects into the candidate list all child
// elements of the element having the specified tag.
type selectChildrenByTag struct {
space, tag string
}
func newSelectChildrenByTag(path string) *selectChildrenByTag {
s, l := spaceDecompose(path)
return &selectChildrenByTag{s, l}
}
func (s *selectChildrenByTag) apply(e *Element, p *pather) {
for _, c := range e.Child {
if c, ok := c.(*Element); ok && spaceMatch(s.space, c.Space) && s.tag == c.Tag {
p.candidates = append(p.candidates, c)
}
}
}
// filterPos filters the candidate list, keeping only the
// candidate at the specified index.
type filterPos struct {
index int
}
func newFilterPos(pos int) *filterPos {
return &filterPos{pos}
}
func (f *filterPos) apply(p *pather) {
if f.index >= 0 {
if f.index < len(p.candidates) {
p.scratch = append(p.scratch, p.candidates[f.index])
}
} else {
if -f.index <= len(p.candidates) {
p.scratch = append(p.scratch, p.candidates[len(p.candidates)+f.index])
}
}
p.candidates, p.scratch = p.scratch, p.candidates[0:0]
}
// filterAttr filters the candidate list for elements having
// the specified attribute.
type filterAttr struct {
space, key string
}
func newFilterAttr(str string) *filterAttr {
s, l := spaceDecompose(str)
return &filterAttr{s, l}
}
func (f *filterAttr) apply(p *pather) {
for _, c := range p.candidates {
for _, a := range c.Attr {
if spaceMatch(f.space, a.Space) && f.key == a.Key {
p.scratch = append(p.scratch, c)
break
}
}
}
p.candidates, p.scratch = p.scratch, p.candidates[0:0]
}
// filterAttrVal filters the candidate list for elements having
// the specified attribute with the specified value.
type filterAttrVal struct {
space, key, val string
}
func newFilterAttrVal(str, value string) *filterAttrVal {
s, l := spaceDecompose(str)
return &filterAttrVal{s, l, value}
}
func (f *filterAttrVal) apply(p *pather) {
for _, c := range p.candidates {
for _, a := range c.Attr {
if spaceMatch(f.space, a.Space) && f.key == a.Key && f.val == a.Value {
p.scratch = append(p.scratch, c)
break
}
}
}
p.candidates, p.scratch = p.scratch, p.candidates[0:0]
}
// filterFunc filters the candidate list for elements satisfying a custom
// boolean function.
type filterFunc struct {
fn func(e *Element) bool
}
func newFilterFunc(fn func(e *Element) bool) *filterFunc {
return &filterFunc{fn}
}
func (f *filterFunc) apply(p *pather) {
for _, c := range p.candidates {
if f.fn(c) {
p.scratch = append(p.scratch, c)
}
}
p.candidates, p.scratch = p.scratch, p.candidates[0:0]
}
// filterFuncVal filters the candidate list for elements containing a value
// matching the result of a custom function.
type filterFuncVal struct {
fn func(e *Element) string
val string
}
func newFilterFuncVal(fn func(e *Element) string, value string) *filterFuncVal {
return &filterFuncVal{fn, value}
}
func (f *filterFuncVal) apply(p *pather) {
for _, c := range p.candidates {
if f.fn(c) == f.val {
p.scratch = append(p.scratch, c)
}
}
p.candidates, p.scratch = p.scratch, p.candidates[0:0]
}
// filterChild filters the candidate list for elements having
// a child element with the specified tag.
type filterChild struct {
space, tag string
}
func newFilterChild(str string) *filterChild {
s, l := spaceDecompose(str)
return &filterChild{s, l}
}
func (f *filterChild) apply(p *pather) {
for _, c := range p.candidates {
for _, cc := range c.Child {
if cc, ok := cc.(*Element); ok &&
spaceMatch(f.space, cc.Space) &&
f.tag == cc.Tag {
p.scratch = append(p.scratch, c)
}
}
}
p.candidates, p.scratch = p.scratch, p.candidates[0:0]
}
// filterChildText filters the candidate list for elements having
// a child element with the specified tag and text.
type filterChildText struct {
space, tag, text string
}
func newFilterChildText(str, text string) *filterChildText {
s, l := spaceDecompose(str)
return &filterChildText{s, l, text}
}
func (f *filterChildText) apply(p *pather) {
for _, c := range p.candidates {
for _, cc := range c.Child {
if cc, ok := cc.(*Element); ok &&
spaceMatch(f.space, cc.Space) &&
f.tag == cc.Tag &&
f.text == cc.Text() {
p.scratch = append(p.scratch, c)
}
}
}
p.candidates, p.scratch = p.scratch, p.candidates[0:0]
}

View File

@@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@@ -1,377 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
)
// Association Mode contains some helper methods to handle relationship things easily.
type Association struct {
Error error
scope *Scope
column string
field *Field
}
// Find find out all related associations
func (association *Association) Find(value interface{}) *Association {
association.scope.related(value, association.column)
return association.setErr(association.scope.db.Error)
}
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
func (association *Association) Append(values ...interface{}) *Association {
if association.Error != nil {
return association
}
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
return association.Replace(values...)
}
return association.saveAssociations(values...)
}
// Replace replace current associations with new one
func (association *Association) Replace(values ...interface{}) *Association {
if association.Error != nil {
return association
}
var (
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
// Append new values
association.field.Set(reflect.Zero(association.field.Field.Type()))
association.saveAssociations(values...)
// Belongs To
if relationship.Kind == "belongs_to" {
// Set foreign key to be null when clearing value (length equals 0)
if len(values) == 0 {
// Set foreign key to be nil
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
}
} else {
// Polymorphic Relations
if relationship.PolymorphicDBName != "" {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
}
// Delete Relations except new created
if len(values) > 0 {
var associationForeignFieldNames, associationForeignDBNames []string
if relationship.Kind == "many_to_many" {
// if many to many relations, get association fields name from association foreign keys
associationScope := scope.New(reflect.New(field.Type()).Interface())
for idx, dbName := range relationship.AssociationForeignFieldNames {
if field, ok := associationScope.FieldByName(dbName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx])
}
}
} else {
// If has one/many relations, use primary keys
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
associationForeignDBNames = append(associationForeignDBNames, field.DBName)
}
}
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
if len(newPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
}
}
if relationship.Kind == "many_to_many" {
// if many to many relations, delete related relations from join table
var sourceForeignFieldNames []string
for _, dbName := range relationship.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
}
}
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
var foreignKeyMap = map[string]interface{}{}
for idx, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
return association
}
// Delete remove relationship between source & passed arguments, but won't delete those arguments
func (association *Association) Delete(values ...interface{}) *Association {
if association.Error != nil {
return association
}
var (
relationship = association.field.Relationship
scope = association.scope
field = association.field.Field
newDB = scope.NewDB()
)
if len(values) == 0 {
return association
}
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
}
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
if relationship.Kind == "many_to_many" {
// source value's foreign keys
for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
// get association's foreign fields name
var associationScope = scope.New(reflect.New(field.Type()).Interface())
var associationForeignFieldNames []string
for _, associationDBName := range relationship.AssociationForeignFieldNames {
if field, ok := associationScope.FieldByName(associationDBName); ok {
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
}
}
// association value's foreign keys
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB))
} else {
var foreignKeyMap = map[string]interface{}{}
for _, foreignKey := range relationship.ForeignDBNames {
foreignKeyMap[foreignKey] = nil
}
if relationship.Kind == "belongs_to" {
// find with deleting relation's foreign keys
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// set foreign key to be null if there are some records affected
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
if results.RowsAffected > 0 {
scope.updatedAttrsWithValues(foreignKeyMap)
}
} else {
association.setErr(results.Error)
}
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
// find all relations
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
// only include those deleting relations
newDB = newDB.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
toQueryValues(deletingPrimaryKeys)...,
)
// set matched relation's foreign key to be null
fieldValue := reflect.New(association.field.Field.Type()).Interface()
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
}
}
// Remove deleted records from source's field
if association.Error == nil {
if field.Kind() == reflect.Slice {
leftValues := reflect.Zero(field.Type())
for i := 0; i < field.Len(); i++ {
reflectValue := field.Index(i)
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
var isDeleted = false
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
isDeleted = true
break
}
}
if !isDeleted {
leftValues = reflect.Append(leftValues, reflectValue)
}
}
association.field.Set(leftValues)
} else if field.Kind() == reflect.Struct {
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
for _, pk := range deletingPrimaryKeys {
if equalAsString(primaryKey, pk) {
association.field.Set(reflect.Zero(field.Type()))
break
}
}
}
}
return association
}
// Clear remove relationship between source & current associations, won't delete those associations
func (association *Association) Clear() *Association {
return association.Replace()
}
// Count return the count of current associations
func (association *Association) Count() int {
var (
count = 0
relationship = association.field.Relationship
scope = association.scope
fieldValue = association.field.Field.Interface()
query = scope.DB()
)
switch relationship.Kind {
case "many_to_many":
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
case "has_many", "has_one":
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
case "belongs_to":
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)...,
)
}
if relationship.PolymorphicType != "" {
query = query.Where(
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
relationship.PolymorphicValue,
)
}
if err := query.Model(fieldValue).Count(&count).Error; err != nil {
association.Error = err
}
return count
}
// saveAssociations save passed values as associations
func (association *Association) saveAssociations(values ...interface{}) *Association {
var (
scope = association.scope
field = association.field
relationship = field.Relationship
)
saveAssociation := func(reflectValue reflect.Value) {
// value has to been pointer
if reflectValue.Kind() != reflect.Ptr {
reflectPtr := reflect.New(reflectValue.Type())
reflectPtr.Elem().Set(reflectValue)
reflectValue = reflectPtr
}
// value has to been saved for many2many
if relationship.Kind == "many_to_many" {
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
}
}
// Assign Fields
var fieldType = field.Field.Type()
var setFieldBackToValue, setSliceFieldBackToValue bool
if reflectValue.Type().AssignableTo(fieldType) {
field.Set(reflectValue)
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
// if field's type is struct, then need to set value back to argument after save
setFieldBackToValue = true
field.Set(reflectValue.Elem())
} else if fieldType.Kind() == reflect.Slice {
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
field.Set(reflect.Append(field.Field, reflectValue))
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
// if field's type is slice of struct, then need to set value back to argument after save
setSliceFieldBackToValue = true
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
}
}
if relationship.Kind == "many_to_many" {
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
} else {
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
if setFieldBackToValue {
reflectValue.Elem().Set(field.Field)
} else if setSliceFieldBackToValue {
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
}
}
}
for _, value := range values {
reflectValue := reflect.ValueOf(value)
indirectReflectValue := reflect.Indirect(reflectValue)
if indirectReflectValue.Kind() == reflect.Struct {
saveAssociation(reflectValue)
} else if indirectReflectValue.Kind() == reflect.Slice {
for i := 0; i < indirectReflectValue.Len(); i++ {
saveAssociation(indirectReflectValue.Index(i))
}
} else {
association.setErr(errors.New("invalid value type"))
}
}
return association
}
// setErr set error when the error is not nil. And return Association.
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
}
return association
}

View File

@@ -1,242 +0,0 @@
package gorm
import "log"
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
// Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object
// Field `updates` contains callbacks will be call when updating object
// Field `deletes` contains callbacks will be call when deleting object
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*CallbackProcessor
}
// CallbackProcessor contains callback informations
type CallbackProcessor struct {
name string // current callback's name
before string // register current callback before a callback
after string // register current callback after a callback
replace bool // replace callbacks with same name
remove bool // delete callbacks with same name
kind string // callback type: create, update, delete, query, row_query
processor *func(scope *Scope) // callback handler
parent *Callback
}
func (c *Callback) clone() *Callback {
return &Callback{
creates: c.creates,
updates: c.updates,
deletes: c.deletes,
queries: c.queries,
rowQueries: c.rowQueries,
processors: c.processors,
}
}
// Create could be used to register callbacks for creating object
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
// // business logic
// ...
//
// // set error if some thing wrong happened, will rollback the creating
// scope.Err(errors.New("error"))
// })
func (c *Callback) Create() *CallbackProcessor {
return &CallbackProcessor{kind: "create", parent: c}
}
// Update could be used to register callbacks for updating object, refer `Create` for usage
func (c *Callback) Update() *CallbackProcessor {
return &CallbackProcessor{kind: "update", parent: c}
}
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
func (c *Callback) Delete() *CallbackProcessor {
return &CallbackProcessor{kind: "delete", parent: c}
}
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
// Refer `Create` for usage
func (c *Callback) Query() *CallbackProcessor {
return &CallbackProcessor{kind: "query", parent: c}
}
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
func (c *Callback) RowQuery() *CallbackProcessor {
return &CallbackProcessor{kind: "row_query", parent: c}
}
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
cp.after = callbackName
return cp
}
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
cp.before = callbackName
return cp
}
// Register a new callback, refer `Callbacks.Create`
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
cp.before = "gorm:row_query"
}
}
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.remove = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now)
// scope.SetColumn("Updated", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
cp.name = callbackName
cp.processor = &callback
cp.replace = true
cp.parent.processors = append(cp.parent.processors, cp)
cp.parent.reorder()
}
// Get registered callback
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
return *p.processor
}
}
return nil
}
// getRIndex get right index from string slice
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
// sortProcessors sort callback processors based on its before, after, remove, replace
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
var (
allNames, sortedNames []string
sortCallbackProcessor func(c *CallbackProcessor)
)
for _, cp := range cps {
// show warning message the callback name already exists
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
allNames = append(allNames, cp.name)
}
sortCallbackProcessor = func(c *CallbackProcessor) {
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
if c.before != "" { // if defined before callback
if index := getRIndex(sortedNames, c.before); index != -1 {
// if before callback already sorted, append current callback just after it
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(allNames, c.before); index != -1 {
// if before callback exists but haven't sorted, append current callback to last
sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index])
}
}
if c.after != "" { // if defined after callback
if index := getRIndex(sortedNames, c.after); index != -1 {
// if after callback already sorted, append current callback just before it
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(allNames, c.after); index != -1 {
// if after callback exists but haven't sorted
cp := cps[index]
// set after callback's before callback to current callback
if cp.before == "" {
cp.before = c.name
}
sortCallbackProcessor(cp)
}
}
// if current callback haven't been sorted, append it to last
if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name)
}
}
}
for _, cp := range cps {
sortCallbackProcessor(cp)
}
var sortedFuncs []*func(scope *Scope)
for _, name := range sortedNames {
if index := getRIndex(allNames, name); !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)
}
}
return sortedFuncs
}
// reorder all registered processors, and reset CRUD callbacks
func (c *Callback) reorder() {
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
for _, processor := range c.processors {
if processor.name != "" {
switch processor.kind {
case "create":
creates = append(creates, processor)
case "update":
updates = append(updates, processor)
case "delete":
deletes = append(deletes, processor)
case "query":
queries = append(queries, processor)
case "row_query":
rowQueries = append(rowQueries, processor)
}
}
}
c.creates = sortProcessors(creates)
c.updates = sortProcessors(updates)
c.deletes = sortProcessors(deletes)
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
}

View File

@@ -1,164 +0,0 @@
package gorm
import (
"fmt"
"strings"
)
// Define callbacks for creating
func init() {
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
DefaultCallback.Create().Register("gorm:create", createCallback)
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
}
if !scope.HasError() {
scope.CallMethod("BeforeCreate")
}
}
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
if createdAtField.IsBlank {
createdAtField.Set(now)
}
}
if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
if updatedAtField.IsBlank {
updatedAtField.Set(now)
}
}
}
}
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(NowFunc())
var (
columns, placeholders []string
blankColumnsWithDefaultValue []string
)
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if field.IsNormal && !field.IsIgnored {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else if !field.IsPrimaryKey || !field.IsBlank {
columns = append(columns, scope.Quote(field.DBName))
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
}
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, foreignKey := range field.Relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
columns = append(columns, scope.Quote(foreignField.DBName))
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
}
}
}
}
}
var (
returningColumn = "*"
quotedTableName = scope.QuotedTableName()
primaryField = scope.PrimaryField()
extraOption string
)
if str, ok := scope.Get("gorm:insert_option"); ok {
extraOption = fmt.Sprint(str)
}
if primaryField != nil {
returningColumn = scope.Quote(primaryField.DBName)
}
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v %v%v%v",
quotedTableName,
scope.Dialect().DefaultValueStr(),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v)%v%v",
scope.QuotedTableName(),
strings.Join(columns, ","),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}
// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
} else {
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
}
}
}
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) {
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
for _, field := range scope.Fields() {
if field.IsPrimaryKey && !field.IsBlank {
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
}
}
db.Scan(scope.Value)
}
}
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
func afterCreateCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterCreate")
}
if !scope.HasError() {
scope.CallMethod("AfterSave")
}
}

View File

@@ -1,63 +0,0 @@
package gorm
import (
"errors"
"fmt"
)
// Define callbacks for deleting
func init() {
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while deleting"))
return
}
if !scope.HasError() {
scope.CallMethod("BeforeDelete")
}
}
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
func deleteCallback(scope *Scope) {
if !scope.HasError() {
var extraOption string
if str, ok := scope.Get("gorm:delete_option"); ok {
extraOption = fmt.Sprint(str)
}
deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
if !scope.Search.Unscoped && hasDeletedAtField {
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v=%v%v%v",
scope.QuotedTableName(),
scope.Quote(deletedAtField.DBName),
scope.AddToVars(NowFunc()),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
} else {
scope.Raw(fmt.Sprintf(
"DELETE FROM %v%v%v",
scope.QuotedTableName(),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
}
}
}
// afterDeleteCallback will invoke `AfterDelete` method after deleting
func afterDeleteCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterDelete")
}
}

View File

@@ -1,104 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
)
// Define callbacks for querying
func init() {
DefaultCallback.Query().Register("gorm:query", queryCallback)
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
}
// queryCallback used to query data from database
func queryCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
//we are only preloading relations, dont touch base model
if _, skip := scope.InstanceGet("gorm:only_preload"); skip {
return
}
defer scope.trace(NowFunc())
var (
isSlice, isPtr bool
resultType reflect.Type
results = scope.IndirectValue()
)
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
if primaryField := scope.PrimaryField(); primaryField != nil {
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
}
}
if value, ok := scope.Get("gorm:query_destination"); ok {
results = indirect(reflect.ValueOf(value))
}
if kind := results.Kind(); kind == reflect.Slice {
isSlice = true
resultType = results.Type().Elem()
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
if resultType.Kind() == reflect.Ptr {
isPtr = true
resultType = resultType.Elem()
}
} else if kind != reflect.Struct {
scope.Err(errors.New("unsupported destination, should be slice or struct"))
return
}
scope.prepareQuerySQL()
if !scope.HasError() {
scope.db.RowsAffected = 0
if str, ok := scope.Get("gorm:query_option"); ok {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
scope.db.RowsAffected++
elem := results
if isSlice {
elem = reflect.New(resultType).Elem()
}
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
if isSlice {
if isPtr {
results.Set(reflect.Append(results, elem.Addr()))
} else {
results.Set(reflect.Append(results, elem))
}
}
}
if err := rows.Err(); err != nil {
scope.Err(err)
} else if scope.db.RowsAffected == 0 && !isSlice {
scope.Err(ErrRecordNotFound)
}
}
}
}
// afterQueryCallback will invoke `AfterFind` method after querying
func afterQueryCallback(scope *Scope) {
if !scope.HasError() {
scope.CallMethod("AfterFind")
}
}

View File

@@ -1,404 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
)
// preloadCallback used to preload associations
func preloadCallback(scope *Scope) {
if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
return
}
if ap, ok := scope.Get("gorm:auto_preload"); ok {
// If gorm:auto_preload IS NOT a bool then auto preload.
// Else if it IS a bool, use the value
if apb, ok := ap.(bool); !ok {
autoPreload(scope)
} else if apb {
autoPreload(scope)
}
}
if scope.Search.preload == nil || scope.HasError() {
return
}
var (
preloadedMap = map[string]bool{}
fields = scope.Fields()
)
for _, preload := range scope.Search.preload {
var (
preloadFields = strings.Split(preload.schema, ".")
currentScope = scope
currentFields = fields
)
for idx, preloadField := range preloadFields {
var currentPreloadConditions []interface{}
if currentScope == nil {
continue
}
// if not preloaded
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
// assign search conditions to last preload
if idx == len(preloadFields)-1 {
currentPreloadConditions = preload.conditions
}
for _, field := range currentFields {
if field.Name != preloadField || field.Relationship == nil {
continue
}
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, currentPreloadConditions)
case "has_many":
currentScope.handleHasManyPreload(field, currentPreloadConditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
preloadedMap[preloadKey] = true
break
}
if !preloadedMap[preloadKey] {
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
return
}
}
// preload next level
if idx < len(preloadFields)-1 {
currentScope = currentScope.getColumnAsScope(preloadField)
if currentScope != nil {
currentFields = currentScope.Fields()
}
}
}
}
}
func autoPreload(scope *Scope) {
for _, field := range scope.Fields() {
if field.Relationship == nil {
continue
}
if val, ok := field.TagSettingsGet("PRELOAD"); ok {
if preload, err := strconv.ParseBool(val); err != nil {
scope.Err(errors.New("invalid preload option"))
return
} else if !preload {
continue
}
}
scope.Search.Preload(field.Name)
}
}
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
var (
preloadDB = scope.NewDB()
preloadConditions []interface{}
)
for _, condition := range conditions {
if scopes, ok := condition.(func(*DB) *DB); ok {
preloadDB = scopes(preloadDB)
} else {
preloadConditions = append(preloadConditions, condition)
}
}
return preloadDB, preloadConditions
}
// handleHasOnePreload used to preload has one associations
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// find relations
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
values := toQueryValues(primaryKeys)
if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue)
}
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
if indirectScopeValue.Kind() == reflect.Slice {
foreignValuesToResults := make(map[string]reflect.Value)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
foreignValuesToResults[foreignValues] = result
}
for j := 0; j < indirectScopeValue.Len(); j++ {
indirectValue := indirect(indirectScopeValue.Index(j))
valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
if result, found := foreignValuesToResults[valueString]; found {
indirectValue.FieldByName(field.Name).Set(result)
}
}
} else {
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
scope.Err(field.Set(result))
}
}
}
// handleHasManyPreload used to preload has many associations
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// find relations
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
values := toQueryValues(primaryKeys)
if relation.PolymorphicType != "" {
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
values = append(values, relation.PolymorphicValue)
}
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
if indirectScopeValue.Kind() == reflect.Slice {
preloadMap := make(map[string][]reflect.Value)
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
}
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
f := object.FieldByName(field.Name)
if results, ok := preloadMap[toString(objectRealValue)]; ok {
f.Set(reflect.Append(f, results...))
} else {
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
}
}
} else {
scope.Err(field.Set(resultsValue))
}
}
// handleBelongsToPreload used to preload belongs to associations
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
if len(primaryKeys) == 0 {
return
}
// find relations
results := makeSlice(field.Struct.Type)
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
// assign find results
var (
resultsValue = indirect(reflect.ValueOf(results))
indirectScopeValue = scope.IndirectValue()
)
foreignFieldToObjects := make(map[string][]*reflect.Value)
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
}
}
for i := 0; i < resultsValue.Len(); i++ {
result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
if objects, found := foreignFieldToObjects[valueString]; found {
for _, object := range objects {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.Err(field.Set(result))
}
}
}
// handleManyToManyPreload used to preload many to many associations
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
var (
relation = field.Relationship
joinTableHandler = relation.JoinTableHandler
fieldType = field.Struct.Type.Elem()
foreignKeyValue interface{}
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
linkHash = map[string][]reflect.Value{}
isPtr bool
)
if fieldType.Kind() == reflect.Ptr {
isPtr = true
fieldType = fieldType.Elem()
}
var sourceKeys = []string{}
for _, key := range joinTableHandler.SourceForeignKeys() {
sourceKeys = append(sourceKeys, key.DBName)
}
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// generate query with join table
newScope := scope.New(reflect.New(fieldType).Interface())
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
if len(preloadDB.search.selects) == 0 {
preloadDB = preloadDB.Select("*")
}
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
// preload inline conditions
if len(preloadConditions) > 0 {
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
}
rows, err := preloadDB.Rows()
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
var (
elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
)
// register foreign keys in join tables
var joinTableFields []*Field
for _, sourceKey := range sourceKeys {
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
}
scope.scan(rows, columns, append(fields, joinTableFields...))
scope.New(elem.Addr().Interface()).
InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries)
var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {
if !joinTableField.Field.IsNil() {
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
}
}
hashedSourceKeys := toString(foreignKeys)
if isPtr {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
} else {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
}
}
if err := rows.Err(); err != nil {
scope.Err(err)
}
// assign find results
var (
indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string][]reflect.Value{}
foreignFieldNames = []string{}
)
for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
key := toString(getValueFromFields(object, foreignFieldNames))
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
}
} else if indirectScopeValue.IsValid() {
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
}
for source, link := range linkHash {
for i, field := range fieldsSourceMap[source] {
//If not 0 this means Value is a pointer and we already added preloaded models to it
if fieldsSourceMap[source][i].Len() != 0 {
continue
}
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
}
}
}

View File

@@ -1,30 +0,0 @@
package gorm
import "database/sql"
// Define callbacks for row query
func init() {
DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
}
type RowQueryResult struct {
Row *sql.Row
}
type RowsQueryResult struct {
Rows *sql.Rows
Error error
}
// queryCallback used to query data from database
func rowQueryCallback(scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL()
if rowResult, ok := result.(*RowQueryResult); ok {
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
}
}
}

View File

@@ -1,170 +0,0 @@
package gorm
import (
"reflect"
"strings"
)
func beginTransactionCallback(scope *Scope) {
scope.Begin()
}
func commitOrRollbackTransactionCallback(scope *Scope) {
scope.CommitOrRollback()
}
func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) {
checkTruth := func(value interface{}) bool {
if v, ok := value.(bool); ok && !v {
return false
}
if v, ok := value.(string); ok {
v = strings.ToLower(v)
return v == "true"
}
return true
}
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if r = field.Relationship; r != nil {
autoUpdate, autoCreate, saveReference = true, true, true
if value, ok := scope.Get("gorm:save_associations"); ok {
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
saveReference = autoUpdate
} else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok {
autoUpdate = checkTruth(value)
autoCreate = autoUpdate
saveReference = autoUpdate
}
if value, ok := scope.Get("gorm:association_autoupdate"); ok {
autoUpdate = checkTruth(value)
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok {
autoUpdate = checkTruth(value)
}
if value, ok := scope.Get("gorm:association_autocreate"); ok {
autoCreate = checkTruth(value)
} else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok {
autoCreate = checkTruth(value)
}
if value, ok := scope.Get("gorm:association_save_reference"); ok {
saveReference = checkTruth(value)
} else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok {
saveReference = checkTruth(value)
}
}
}
return
}
func saveBeforeAssociationsCallback(scope *Scope) {
for _, field := range scope.Fields() {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
if relationship != nil && relationship.Kind == "belongs_to" {
fieldValue := field.Field.Addr().Interface()
newScope := scope.New(fieldValue)
if newScope.PrimaryKeyZero() {
if autoCreate {
scope.Err(scope.NewDB().Save(fieldValue).Error)
}
} else if autoUpdate {
scope.Err(scope.NewDB().Save(fieldValue).Error)
}
if saveReference {
if len(relationship.ForeignFieldNames) != 0 {
// set value's foreign key
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
}
}
}
}
}
}
}
func saveAfterAssociationsCallback(scope *Scope) {
for _, field := range scope.Fields() {
autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field)
if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := field.Field
switch value.Kind() {
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem)
if saveReference {
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
}
if newScope.PrimaryKeyZero() {
if autoCreate {
scope.Err(newDB.Save(elem).Error)
}
} else if autoUpdate {
scope.Err(newDB.Save(elem).Error)
}
if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference {
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
}
}
}
default:
elem := value.Addr().Interface()
newScope := scope.New(elem)
if saveReference {
if len(relationship.ForeignFieldNames) != 0 {
for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
}
}
if newScope.PrimaryKeyZero() {
if autoCreate {
scope.Err(scope.NewDB().Save(elem).Error)
}
} else if autoUpdate {
scope.Err(scope.NewDB().Save(elem).Error)
}
}
}
}
}

View File

@@ -1,121 +0,0 @@
package gorm
import (
"errors"
"fmt"
"sort"
"strings"
)
// Define callbacks for updating
func init() {
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
DefaultCallback.Update().Register("gorm:update", updateCallback)
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}
// assignUpdatingAttributesCallback assign updating attributes to model
func assignUpdatingAttributesCallback(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
scope.InstanceSet("gorm:update_attrs", updateMaps)
} else {
scope.SkipLeft()
}
}
}
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while updating"))
return
}
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("BeforeSave")
}
if !scope.HasError() {
scope.CallMethod("BeforeUpdate")
}
}
}
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}
// updateCallback the callback used to update data to database
func updateCallback(scope *Scope) {
if !scope.HasError() {
var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
// Sort the column names so that the generated SQL is the same every time.
updateMap := updateAttrs.(map[string]interface{})
var columns []string
for c := range updateMap {
columns = append(columns, c)
}
sort.Strings(columns)
for _, column := range columns {
value := updateMap[column]
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
}
} else {
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if !field.IsPrimaryKey && field.IsNormal {
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
for _, foreignKey := range relationship.ForeignDBNames {
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
sqls = append(sqls,
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
}
}
}
}
}
}
var extraOption string
if str, ok := scope.Get("gorm:update_option"); ok {
extraOption = fmt.Sprint(str)
}
if len(sqls) > 0 {
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v%v%v",
scope.QuotedTableName(),
strings.Join(sqls, ", "),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
}
}
}
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
func afterUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
if !scope.HasError() {
scope.CallMethod("AfterUpdate")
}
if !scope.HasError() {
scope.CallMethod("AfterSave")
}
}
}

View File

@@ -1,138 +0,0 @@
package gorm
import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
)
// Dialect interface contains behaviors that differ across SQL database
type Dialect interface {
// GetName get dialect's name
GetName() string
// SetDB set db for dialect
SetDB(db SQLCommon)
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string
// DataTypeOf return data's sql type
DataTypeOf(field *StructField) string
// HasIndex check has index or not
HasIndex(tableName string, indexName string) bool
// HasForeignKey check has foreign key or not
HasForeignKey(tableName string, foreignKeyName string) bool
// RemoveIndex remove index
RemoveIndex(tableName string, indexName string) error
// HasTable check has table or not
HasTable(tableName string) bool
// HasColumn check has column or not
HasColumn(tableName string, columnName string) bool
// ModifyColumn modify column's type
ModifyColumn(tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string
// DefaultValueStr
DefaultValueStr() string
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
BuildKeyName(kind, tableName string, fields ...string) string
// CurrentDatabase return current database name
CurrentDatabase() string
}
var dialectsMap = map[string]Dialect{}
func newDialect(name string, db SQLCommon) Dialect {
if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db)
return dialect
}
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
commontDialect := &commonDialect{}
commontDialect.SetDB(db)
return commontDialect
}
// RegisterDialect register new dialect
func RegisterDialect(name string, dialect Dialect) {
dialectsMap[name] = dialect
}
// GetDialect gets the dialect for the specified dialect name
func GetDialect(name string) (dialect Dialect, ok bool) {
dialect, ok = dialectsMap[name]
return
}
// ParseFieldStructForDialect get field's sql data type
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type
var (
reflectType = field.Struct.Type
dataType, _ = field.TagSettingsGet("TYPE")
)
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
// Get redirected field value
fieldValue = reflect.Indirect(reflect.New(reflectType))
if gormDataType, ok := fieldValue.Interface().(interface {
GormDataType(Dialect) string
}); ok {
dataType = gormDataType.GormDataType(dialect)
}
// Get scanner's real value
if dataType == "" {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
fieldValue = value
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
getScannerValue(fieldValue.Field(0))
}
}
getScannerValue(fieldValue)
}
// Default Size
if num, ok := field.TagSettingsGet("SIZE"); ok {
size, _ = strconv.Atoi(num)
} else {
size = 255
}
// Default type from tag setting
notNull, _ := field.TagSettingsGet("NOT NULL")
unique, _ := field.TagSettingsGet("UNIQUE")
additionalType = notNull + " " + unique
if value, ok := field.TagSettingsGet("DEFAULT"); ok {
additionalType = additionalType + " DEFAULT " + value
}
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
}
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}

View File

@@ -1,176 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}
type commonDialect struct {
db SQLCommon
DefaultForeignKeyNamer
}
func init() {
RegisterDialect("common", &commonDialect{})
}
func (commonDialect) GetName() string {
return "common"
}
func (s *commonDialect) SetDB(db SQLCommon) {
s.db = db
}
func (commonDialect) BindVar(i int) string {
return "$$$" // ?
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool {
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
return strings.ToLower(value) != "false"
}
return field.IsPrimaryKey
}
func (s *commonDialect) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
sqlType = "INTEGER AUTO_INCREMENT"
} else {
sqlType = "INTEGER"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
sqlType = "BIGINT AUTO_INCREMENT"
} else {
sqlType = "BIGINT"
}
case reflect.Float32, reflect.Float64:
sqlType = "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
} else {
sqlType = "VARCHAR(65532)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "TIMESTAMP"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("BINARY(%d)", size)
} else {
sqlType = "BINARY(65532)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
return
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (commonDialect) DefaultValueStr() string {
return "DEFAULT VALUES"
}
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_")
return keyName
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}

View File

@@ -1,191 +0,0 @@
package gorm
import (
"crypto/sha1"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
)
type mysql struct {
commonDialect
}
func init() {
RegisterDialect("mysql", &mysql{})
}
func (mysql) GetName() string {
return "mysql"
}
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
// Get Data Type for MySQL Dialect
func (s *mysql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
// MySQL allows only one auto increment column per table, and it must
// be a KEY column.
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey {
field.TagSettingsDelete("AUTO_INCREMENT")
}
}
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int8:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "tinyint AUTO_INCREMENT"
} else {
sqlType = "tinyint"
}
case reflect.Int, reflect.Int16, reflect.Int32:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "int AUTO_INCREMENT"
} else {
sqlType = "int"
}
case reflect.Uint8:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "tinyint unsigned AUTO_INCREMENT"
} else {
sqlType = "tinyint unsigned"
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "int unsigned AUTO_INCREMENT"
} else {
sqlType = "int unsigned"
}
case reflect.Int64:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigint AUTO_INCREMENT"
} else {
sqlType = "bigint"
}
case reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigint unsigned AUTO_INCREMENT"
} else {
sqlType = "bigint unsigned"
}
case reflect.Float32, reflect.Float64:
sqlType = "double"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "longtext"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
precision := ""
if p, ok := field.TagSettingsGet("PRECISION"); ok {
precision = fmt.Sprintf("(%s)", p)
}
if _, ok := field.TagSettingsGet("NOT NULL"); ok {
sqlType = fmt.Sprintf("timestamp%v", precision)
} else {
sqlType = fmt.Sprintf("timestamp%v NULL", precision)
}
}
default:
if IsByteArrayOrSlice(dataValue) {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
sqlType = "longblob"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mysql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
}
}
return
}
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s mysql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...)
if utf8.RuneCountInString(keyName) <= 64 {
return keyName
}
h := sha1.New()
h.Write([]byte(keyName))
bs := h.Sum(nil)
// sha1 is 40 characters, keep first 24 characters of destination
destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_"))
if len(destRunes) > 24 {
destRunes = destRunes[:24]
}
return fmt.Sprintf("%s%x", string(destRunes), bs)
}
func (mysql) DefaultValueStr() string {
return "VALUES()"
}

View File

@@ -1,143 +0,0 @@
package gorm
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
)
type postgres struct {
commonDialect
}
func init() {
RegisterDialect("postgres", &postgres{})
RegisterDialect("cloudsqlpostgres", &postgres{})
}
func (postgres) GetName() string {
return "postgres"
}
func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (s *postgres) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "serial"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint32, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "bigserial"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "numeric"
case reflect.String:
if _, ok := field.TagSettingsGet("SIZE"); !ok {
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
}
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "timestamp with time zone"
}
case reflect.Map:
if dataValue.Type().Name() == "Hstore" {
sqlType = "hstore"
}
default:
if IsByteArrayOrSlice(dataValue) {
sqlType = "bytea"
if isUUID(dataValue) {
sqlType = "uuid"
}
if isJSON(dataValue) {
sqlType = "jsonb"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s postgres) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
return count > 0
}
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
var count int
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
return count > 0
}
func (s postgres) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
return count > 0
}
func (s postgres) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
return count > 0
}
func (s postgres) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
return
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
}
func (postgres) SupportLastInsertID() bool {
return false
}
func isUUID(value reflect.Value) bool {
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
return false
}
typename := value.Type().Name()
lower := strings.ToLower(typename)
return "uuid" == lower || "guid" == lower
}
func isJSON(value reflect.Value) bool {
_, ok := value.Interface().(json.RawMessage)
return ok
}

View File

@@ -1,107 +0,0 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type sqlite3 struct {
commonDialect
}
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
func (sqlite3) GetName() string {
return "sqlite3"
}
// Get Data Type for Sqlite Dialect
func (s *sqlite3) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "real"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime"
}
default:
if IsByteArrayOrSlice(dataValue) {
sqlType = "blob"
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite3) CurrentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

View File

@@ -1,72 +0,0 @@
package gorm
import (
"errors"
"strings"
)
var (
// ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
ErrCantStartTransaction = errors.New("can't start transaction")
// ErrUnaddressable unaddressable value
ErrUnaddressable = errors.New("using unaddressable value")
)
// Errors contains all happened errors
type Errors []error
// IsRecordNotFoundError returns current error has record not found error or not
func IsRecordNotFoundError(err error) bool {
if errs, ok := err.(Errors); ok {
for _, err := range errs {
if err == ErrRecordNotFound {
return true
}
}
}
return err == ErrRecordNotFound
}
// GetErrors gets all happened errors
func (errs Errors) GetErrors() []error {
return errs
}
// Add adds an error
func (errs Errors) Add(newErrors ...error) Errors {
for _, err := range newErrors {
if err == nil {
continue
}
if errors, ok := err.(Errors); ok {
errs = errs.Add(errors...)
} else {
ok = true
for _, e := range errs {
if err == e {
ok = false
}
}
if ok {
errs = append(errs, err)
}
}
}
return errs
}
// Error format happened errors
func (errs Errors) Error() string {
var errors = []string{}
for _, e := range errs {
errors = append(errors, e.Error())
}
return strings.Join(errors, "; ")
}

View File

@@ -1,66 +0,0 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
)
// Field model field definition
type Field struct {
*StructField
IsBlank bool
Field reflect.Value
}
// Set set a value to the field
func (field *Field) Set(value interface{}) (err error) {
if !field.Field.IsValid() {
return errors.New("field value not valid")
}
if !field.Field.CanAddr() {
return ErrUnaddressable
}
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
fieldValue := field.Field
if reflectValue.IsValid() {
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else {
if fieldValue.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
}
fieldValue = fieldValue.Elem()
}
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
v := reflectValue.Interface()
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = scanner.Scan(v)
}
} else {
err = scanner.Scan(v)
}
} else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
}
}
} else {
field.Field.Set(reflect.Zero(field.Field.Type()))
}
field.IsBlank = isBlank(field.Field)
return err
}

View File

@@ -1,20 +0,0 @@
package gorm
import "database/sql"
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
Begin() (*sql.Tx, error)
}
type sqlTx interface {
Commit() error
Rollback() error
}

View File

@@ -1,211 +0,0 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
)
// JoinTableHandlerInterface is an interface for how to handle many2many relations
type JoinTableHandlerInterface interface {
// initialize join table handler
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
// Table return join table's table name
Table(db *DB) string
// Add create relationship in join table for source and destination
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
// Delete delete relationship in join table for sources
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
// JoinWith query with `Join` conditions
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
// SourceForeignKeys return source foreign keys
SourceForeignKeys() []JoinTableForeignKey
// DestinationForeignKeys return destination foreign keys
DestinationForeignKeys() []JoinTableForeignKey
}
// JoinTableForeignKey join table foreign key struct
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}
// JoinTableSource is a struct that contains model type and foreign keys
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}
// JoinTableHandler default join table handler
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}
// SourceForeignKeys return source foreign keys
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys
}
// DestinationForeignKeys return destination foreign keys
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys
}
// Setup initialize a default join table handler
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
s.Source = JoinTableSource{ModelType: source}
s.Source.ForeignKeys = []JoinTableForeignKey{}
for idx, dbName := range relationship.ForeignFieldNames {
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: relationship.ForeignDBNames[idx],
AssociationDBName: dbName,
})
}
s.Destination = JoinTableSource{ModelType: destination}
s.Destination.ForeignKeys = []JoinTableForeignKey{}
for idx, dbName := range relationship.AssociationForeignFieldNames {
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: relationship.AssociationForeignDBNames[idx],
AssociationDBName: dbName,
})
}
}
// Table return join table's table name
func (s JoinTableHandler) Table(db *DB) string {
return DefaultTableNameHandler(db, s.TableName)
}
func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
for _, source := range sources {
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
for _, joinTableSource := range joinTableSources {
if joinTableSource.ModelType == modelType {
for _, foreignKey := range joinTableSource.ForeignKeys {
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
conditionMap[foreignKey.DBName] = field.Field.Interface()
}
}
break
}
}
}
}
// Add create relationship in join table for source and destination
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
var (
scope = db.NewScope("")
conditionMap = map[string]interface{}{}
)
// Update condition map for source
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
// Update condition map for destination
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range conditionMap {
assignColumns = append(assignColumns, scope.Quote(key))
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
for _, value := range values {
values = append(values, value)
}
quotedTable := scope.Quote(handler.Table(db))
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
quotedTable,
strings.Join(assignColumns, ","),
strings.Join(binVars, ","),
scope.Dialect().SelectFromDummyTable(),
quotedTable,
strings.Join(conditions, " AND "),
)
return db.Exec(sql, values...).Error
}
// Delete delete relationship in join table for sources
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var (
scope = db.NewScope(nil)
conditions []string
values []interface{}
conditionMap = map[string]interface{}{}
)
s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
for key, value := range conditionMap {
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
// JoinWith query with `Join` conditions
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
var (
scope = db.NewScope(source)
tableName = handler.Table(db)
quotedTableName = scope.Quote(tableName)
joinConditions []string
values []interface{}
)
if s.Source.ModelType == scope.GetModelStruct().ModelType {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
for _, foreignKey := range s.Destination.ForeignKeys {
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
var foreignDBNames []string
var foreignFieldNames []string
for _, foreignKey := range s.Source.ForeignKeys {
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
var condString string
if len(foreignFieldValues) > 0 {
var quotedForeignDBNames []string
for _, dbName := range foreignDBNames {
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
}
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
values = append(values, toQueryValues(keys))
} else {
condString = fmt.Sprintf("1 <> 1")
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
Where(condString, toQueryValues(foreignFieldValues)...)
}
db.Error = errors.New("wrong source type for join table handler")
return db
}

View File

@@ -1,119 +0,0 @@
package gorm
import (
"database/sql/driver"
"fmt"
"log"
"os"
"reflect"
"regexp"
"strconv"
"time"
"unicode"
)
var (
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
sqlRegexp = regexp.MustCompile(`\?`)
numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`)
)
func isPrintable(s string) bool {
for _, r := range s {
if !unicode.IsPrint(r) {
return false
}
}
return true
}
var LogFormatter = func(values ...interface{}) (messages []interface{}) {
if len(values) > 1 {
var (
sql string
formattedValues []string
level = values[0]
currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
)
messages = []interface{}{source, currentTime}
if level == "sql" {
// duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
// sql
for _, value := range values[4].([]interface{}) {
indirectValue := reflect.Indirect(reflect.ValueOf(value))
if indirectValue.IsValid() {
value = indirectValue.Interface()
if t, ok := value.(time.Time); ok {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
} else if b, ok := value.([]byte); ok {
if str := string(b); isPrintable(str) {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
} else {
formattedValues = append(formattedValues, "'<binary>'")
}
} else if r, ok := value.(driver.Valuer); ok {
if value, err := r.Value(); err == nil && value != nil {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
} else {
formattedValues = append(formattedValues, "NULL")
}
} else {
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
}
} else {
formattedValues = append(formattedValues, "NULL")
}
}
// differentiate between $n placeholders or else treat like ?
if numericPlaceHolderRegexp.MatchString(values[3].(string)) {
sql = values[3].(string)
for index, value := range formattedValues {
placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1)
sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1")
}
} else {
formattedValuesLength := len(formattedValues)
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
sql += value
if index < formattedValuesLength {
sql += formattedValues[index]
}
}
}
messages = append(messages, sql)
messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned "))
} else {
messages = append(messages, "\033[31;1m")
messages = append(messages, values[2:]...)
messages = append(messages, "\033[0m")
}
}
return
}
type logger interface {
Print(v ...interface{})
}
// LogWriter log writer interface
type LogWriter interface {
Println(v ...interface{})
}
// Logger default logger
type Logger struct {
LogWriter
}
// Print format & print log
func (logger Logger) Print(values ...interface{}) {
logger.Println(LogFormatter(values...)...)
}

792
vendor/github.com/jinzhu/gorm/main.go generated vendored
View File

@@ -1,792 +0,0 @@
package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
// DB contains information for current db connection
type DB struct {
Value interface{}
Error error
RowsAffected int64
// single db
db SQLCommon
blockGlobalUpdate bool
logMode int
logger logger
search *search
values sync.Map
// global db
parent *DB
callbacks *Callback
dialect Dialect
singularTable bool
}
// Open initialize a new db connection, need to import driver first, e.g:
//
// import _ "github.com/go-sql-driver/mysql"
// func main() {
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
// }
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
// import _ "github.com/jinzhu/gorm/dialects/mysql"
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
func Open(dialect string, args ...interface{}) (db *DB, err error) {
if len(args) == 0 {
err = errors.New("invalid database source")
return nil, err
}
var source string
var dbSQL SQLCommon
var ownDbSQL bool
switch value := args[0].(type) {
case string:
var driver = dialect
if len(args) == 1 {
source = value
} else if len(args) >= 2 {
driver = value
source = args[1].(string)
}
dbSQL, err = sql.Open(driver, source)
ownDbSQL = true
case SQLCommon:
dbSQL = value
ownDbSQL = false
default:
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
}
db = &DB{
db: dbSQL,
logger: defaultLogger,
callbacks: DefaultCallback,
dialect: newDialect(dialect, dbSQL),
}
db.parent = db
if err != nil {
return
}
// Send a ping to make sure the database connection is alive.
if d, ok := dbSQL.(*sql.DB); ok {
if err = d.Ping(); err != nil && ownDbSQL {
d.Close()
}
}
return
}
// New clone a new db connection without search conditions
func (s *DB) New() *DB {
clone := s.clone()
clone.search = nil
clone.Value = nil
return clone
}
type closer interface {
Close() error
}
// Close close current db connection. If database connection is not an io.Closer, returns an error.
func (s *DB) Close() error {
if db, ok := s.parent.db.(closer); ok {
return db.Close()
}
return errors.New("can't close current db")
}
// DB get `*sql.DB` from current connection
// If the underlying database connection is not a *sql.DB, returns nil
func (s *DB) DB() *sql.DB {
db, _ := s.db.(*sql.DB)
return db
}
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
func (s *DB) CommonDB() SQLCommon {
return s.db
}
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.dialect
}
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
// db.Callback().Create().Register("update_created_at", updateCreated)
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
func (s *DB) Callback() *Callback {
s.parent.callbacks = s.parent.callbacks.clone()
return s.parent.callbacks
}
// SetLogger replace default logger
func (s *DB) SetLogger(log logger) {
s.logger = log
}
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
func (s *DB) LogMode(enable bool) *DB {
if enable {
s.logMode = 2
} else {
s.logMode = 1
}
return s
}
// BlockGlobalUpdate if true, generates an error on update/delete without where clause.
// This is to prevent eventual error with empty objects updates/deletions
func (s *DB) BlockGlobalUpdate(enable bool) *DB {
s.blockGlobalUpdate = enable
return s
}
// HasBlockGlobalUpdate return state of block
func (s *DB) HasBlockGlobalUpdate() bool {
return s.blockGlobalUpdate
}
// SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) {
modelStructsMap = sync.Map{}
s.parent.singularTable = enable
}
// NewScope create a scope for current operation
func (s *DB) NewScope(value interface{}) *Scope {
dbClone := s.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}
// QueryExpr returns the query as expr object
func (s *DB) QueryExpr() *expr {
scope := s.NewScope(s.Value)
scope.InstanceSet("skip_bindvar", true)
scope.prepareQuerySQL()
return Expr(scope.SQL, scope.SQLVars...)
}
// SubQuery returns the query as sub query
func (s *DB) SubQuery() *expr {
scope := s.NewScope(s.Value)
scope.InstanceSet("skip_bindvar", true)
scope.prepareQuerySQL()
return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...)
}
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db
}
// Or filter records that match before conditions or this one, similar to `Where`
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
return s.clone().search.Or(query, args...).db
}
// Not filter records that don't match current conditions, similar to `Where`
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.Not(query, args...).db
}
// Limit specify the number of records to be retrieved
func (s *DB) Limit(limit interface{}) *DB {
return s.clone().search.Limit(limit).db
}
// Offset specify the number of records to skip before starting to return the records
func (s *DB) Offset(offset interface{}) *DB {
return s.clone().search.Offset(offset).db
}
// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
// db.Order("name DESC")
// db.Order("name DESC", true) // reorder
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
func (s *DB) Order(value interface{}, reorder ...bool) *DB {
return s.clone().search.Order(value, reorder...).db
}
// Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
// When creating/updating, specify fields that you want to save to database
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.Select(query, args...).db
}
// Omit specify fields that you want to ignore when saving to database for creating, updating
func (s *DB) Omit(columns ...string) *DB {
return s.clone().search.Omit(columns...).db
}
// Group specify the group method on the find
func (s *DB) Group(query string) *DB {
return s.clone().search.Group(query).db
}
// Having specify HAVING conditions for GROUP BY
func (s *DB) Having(query interface{}, values ...interface{}) *DB {
return s.clone().search.Having(query, values...).db
}
// Joins specify Joins conditions
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (s *DB) Joins(query string, args ...interface{}) *DB {
return s.clone().search.Joins(query, args...).db
}
// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Refer https://jinzhu.github.io/gorm/crud.html#scopes
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
for _, f := range funcs {
s = f(s)
}
return s
}
// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete
func (s *DB) Unscoped() *DB {
return s.clone().search.unscoped().db
}
// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
func (s *DB) Attrs(attrs ...interface{}) *DB {
return s.clone().search.Attrs(attrs...).db
}
// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate
func (s *DB) Assign(attrs ...interface{}) *DB {
return s.clone().search.Assign(attrs...).db
}
// First find first record that match given conditions, order by primary key
func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
// Take return a record that match given conditions, the order will depend on the database implementation
func (s *DB) Take(out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
// Last find last record that match given conditions, order by primary key
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
// Find find records that match given conditions
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
}
//Preloads preloads relations, don`t touch out
func (s *DB) Preloads(out interface{}) *DB {
return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db
}
// Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB {
return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
}
// Row return `*sql.Row` with given conditions
func (s *DB) Row() *sql.Row {
return s.NewScope(s.Value).row()
}
// Rows return `*sql.Rows` with given conditions
func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows()
}
// ScanRows scan `*sql.Rows` to give struct
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
var (
scope = s.NewScope(result)
clone = scope.db
columns, err = rows.Columns()
)
if clone.AddError(err) == nil {
scope.scan(rows, columns, scope.Fields())
}
return clone.Error
}
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db
}
// Count get how many records for a model
func (s *DB) Count(value interface{}) *DB {
return s.NewScope(s.Value).count(value).db
}
// Related get related associations
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.NewScope(s.Value).related(value, foreignKeys...).db
}
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/crud.html#firstorinit
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
c := s.clone()
if result := c.First(out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
c.NewScope(out).inlineCondition(where...).initialize()
} else {
c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
}
return c
}
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
// https://jinzhu.github.io/gorm/crud.html#firstorcreate
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
c := s.clone()
if result := s.First(out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db
} else if len(c.search.assignAttrs) > 0 {
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db
}
return c
}
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true)
}
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callbacks.updates).db
}
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
return s.UpdateColumns(toSearchableMap(attrs...))
}
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (s *DB) UpdateColumns(values interface{}) *DB {
return s.NewScope(s.Value).
Set("gorm:update_column", true).
Set("gorm:save_associations", false).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callbacks.updates).db
}
// Save update value in database, if the value doesn't have primary key, will insert it
func (s *DB) Save(value interface{}) *DB {
scope := s.NewScope(value)
if !scope.PrimaryKeyZero() {
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
if newDB.Error == nil && newDB.RowsAffected == 0 {
return s.New().FirstOrCreate(value)
}
return newDB
}
return scope.callCallbacks(s.parent.callbacks.creates).db
}
// Create insert the value into database
func (s *DB) Create(value interface{}) *DB {
scope := s.NewScope(value)
return scope.callCallbacks(s.parent.callbacks.creates).db
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
}
// Raw use raw sql as conditions, won't run it unless invoked by other methods
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
func (s *DB) Raw(sql string, values ...interface{}) *DB {
return s.clone().search.Raw(true).Where(sql, values...).db
}
// Exec execute raw sql
func (s *DB) Exec(sql string, values ...interface{}) *DB {
scope := s.NewScope(nil)
generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true)
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
scope.Raw(generatedSQL)
return scope.Exec().db
}
// Model specify the model you would like to run db operations
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (s *DB) Model(value interface{}) *DB {
c := s.clone()
c.Value = value
return c
}
// Table specify the table you would like to run db operations
func (s *DB) Table(name string) *DB {
clone := s.clone()
clone.search.Table(name)
clone.Value = nil
return clone
}
// Debug start debug mode
func (s *DB) Debug() *DB {
return s.clone().LogMode(true)
}
// Begin begin a transaction
func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
}
return c
}
// Commit commit a transaction
func (s *DB) Commit() *DB {
var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
s.AddError(db.Commit())
} else {
s.AddError(ErrInvalidTransaction)
}
return s
}
// Rollback rollback a transaction
func (s *DB) Rollback() *DB {
var emptySQLTx *sql.Tx
if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx {
s.AddError(db.Rollback())
} else {
s.AddError(ErrInvalidTransaction)
}
return s
}
// NewRecord check if value's primary key is blank
func (s *DB) NewRecord(value interface{}) bool {
return s.NewScope(value).PrimaryKeyZero()
}
// RecordNotFound check if returning ErrRecordNotFound error
func (s *DB) RecordNotFound() bool {
for _, err := range s.GetErrors() {
if err == ErrRecordNotFound {
return true
}
}
return false
}
// CreateTable create table for models
func (s *DB) CreateTable(models ...interface{}) *DB {
db := s.Unscoped()
for _, model := range models {
db = db.NewScope(model).createTable().db
}
return db
}
// DropTable drop table for models
func (s *DB) DropTable(values ...interface{}) *DB {
db := s.clone()
for _, value := range values {
if tableName, ok := value.(string); ok {
db = db.Table(tableName)
}
db = db.NewScope(value).dropTable().db
}
return db
}
// DropTableIfExists drop table if it is exist
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
db := s.clone()
for _, value := range values {
if s.HasTable(value) {
db.AddError(s.DropTable(value).Error)
}
}
return db
}
// HasTable check has table or not
func (s *DB) HasTable(value interface{}) bool {
var (
scope = s.NewScope(value)
tableName string
)
if name, ok := value.(string); ok {
tableName = name
} else {
tableName = scope.TableName()
}
has := scope.Dialect().HasTable(tableName)
s.AddError(scope.db.Error)
return has
}
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
func (s *DB) AutoMigrate(values ...interface{}) *DB {
db := s.Unscoped()
for _, value := range values {
db = db.NewScope(value).autoMigrate().db
}
return db
}
// ModifyColumn modify column to type
func (s *DB) ModifyColumn(column string, typ string) *DB {
scope := s.NewScope(s.Value)
scope.modifyColumn(column, typ)
return scope.db
}
// DropColumn drop a column
func (s *DB) DropColumn(column string) *DB {
scope := s.NewScope(s.Value)
scope.dropColumn(column)
return scope.db
}
// AddIndex add index for columns with given name
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(false, indexName, columns...)
return scope.db
}
// AddUniqueIndex add unique index for columns with given name
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
scope := s.Unscoped().NewScope(s.Value)
scope.addIndex(true, indexName, columns...)
return scope.db
}
// RemoveIndex remove index with name
func (s *DB) RemoveIndex(indexName string) *DB {
scope := s.NewScope(s.Value)
scope.removeIndex(indexName)
return scope.db
}
// AddForeignKey Add foreign key to the given scope, e.g:
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
scope := s.NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db
}
// RemoveForeignKey Remove foreign key from the given scope, e.g:
// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)")
func (s *DB) RemoveForeignKey(field string, dest string) *DB {
scope := s.clone().NewScope(s.Value)
scope.removeForeignKey(field, dest)
return scope.db
}
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
func (s *DB) Association(column string) *Association {
var err error
var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value)
if primaryField := scope.PrimaryField(); primaryField.IsBlank {
err = errors.New("primary key can't be nil")
} else {
if field, ok := scope.FieldByName(column); ok {
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else {
return &Association{scope: scope, column: column, field: field}
}
} else {
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
}
}
return &Association{Error: err}
}
// Preload preload associations with given conditions
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
return s.clone().search.Preload(column, conditions...).db
}
// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
func (s *DB) Set(name string, value interface{}) *DB {
return s.clone().InstantSet(name, value)
}
// InstantSet instant set setting, will affect current db
func (s *DB) InstantSet(name string, value interface{}) *DB {
s.values.Store(name, value)
return s
}
// Get get setting by name
func (s *DB) Get(name string) (value interface{}, ok bool) {
value, ok = s.values.Load(name)
return
}
// SetJoinTableHandler set a model's join table handler for a relation
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields {
if field.Name == column || field.DBName == column {
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
source := (&Scope{Value: source}).GetModelStruct().ModelType
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
if table := handler.Table(s); scope.Dialect().HasTable(table) {
s.Table(table).AutoMigrate(handler)
}
}
}
}
}
// AddError add error to the db
func (s *DB) AddError(err error) error {
if err != nil {
if err != ErrRecordNotFound {
if s.logMode == 0 {
go s.print(fileWithLineNum(), err)
} else {
s.log(err)
}
errors := Errors(s.GetErrors())
errors = errors.Add(err)
if len(errors) > 1 {
err = errors
}
}
s.Error = err
}
return err
}
// GetErrors get happened errors from the db
func (s *DB) GetErrors() []error {
if errs, ok := s.Error.(Errors); ok {
return errs
} else if s.Error != nil {
return []error{s.Error}
}
return []error{}
}
////////////////////////////////////////////////////////////////////////////////
// Private Methods For DB
////////////////////////////////////////////////////////////////////////////////
func (s *DB) clone() *DB {
db := &DB{
db: s.db,
parent: s.parent,
logger: s.logger,
logMode: s.logMode,
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db),
}
s.values.Range(func(k, v interface{}) bool {
db.values.Store(k, v)
return true
})
if s.search == nil {
db.search = &search{limit: -1, offset: -1}
} else {
db.search = s.search.clone()
}
db.search.db = db
return db
}
func (s *DB) print(v ...interface{}) {
s.logger.Print(v...)
}
func (s *DB) log(v ...interface{}) {
if s != nil && s.logMode == 2 {
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
}
}
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
if s.logMode == 2 {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
}
}

View File

@@ -1,14 +0,0 @@
package gorm
import "time"
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `sql:"index"`
}

View File

@@ -1,640 +0,0 @@
package gorm
import (
"database/sql"
"errors"
"go/ast"
"reflect"
"strings"
"sync"
"time"
"github.com/jinzhu/inflection"
)
// DefaultTableNameHandler default table name handler
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
var modelStructsMap sync.Map
// ModelStruct model definition
type ModelStruct struct {
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
defaultTableName string
l sync.Mutex
}
// TableName returns model's table name
func (s *ModelStruct) TableName(db *DB) string {
s.l.Lock()
defer s.l.Unlock()
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
// Set default table name
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
s.defaultTableName = tabler.TableName()
} else {
tableName := ToTableName(s.ModelType.Name())
if db == nil || !db.parent.singularTable {
tableName = inflection.Plural(tableName)
}
s.defaultTableName = tableName
}
}
return DefaultTableNameHandler(db, s.defaultTableName)
}
// StructField model field's struct definition
type StructField struct {
DBName string
Name string
Names []string
IsPrimaryKey bool
IsNormal bool
IsIgnored bool
IsScanner bool
HasDefaultValue bool
Tag reflect.StructTag
TagSettings map[string]string
Struct reflect.StructField
IsForeignKey bool
Relationship *Relationship
tagSettingsLock sync.RWMutex
}
// TagSettingsSet Sets a tag in the tag settings map
func (s *StructField) TagSettingsSet(key, val string) {
s.tagSettingsLock.Lock()
defer s.tagSettingsLock.Unlock()
s.TagSettings[key] = val
}
// TagSettingsGet returns a tag from the tag settings
func (s *StructField) TagSettingsGet(key string) (string, bool) {
s.tagSettingsLock.RLock()
defer s.tagSettingsLock.RUnlock()
val, ok := s.TagSettings[key]
return val, ok
}
// TagSettingsDelete deletes a tag
func (s *StructField) TagSettingsDelete(key string) {
s.tagSettingsLock.Lock()
defer s.tagSettingsLock.Unlock()
delete(s.TagSettings, key)
}
func (structField *StructField) clone() *StructField {
clone := &StructField{
DBName: structField.DBName,
Name: structField.Name,
Names: structField.Names,
IsPrimaryKey: structField.IsPrimaryKey,
IsNormal: structField.IsNormal,
IsIgnored: structField.IsIgnored,
IsScanner: structField.IsScanner,
HasDefaultValue: structField.HasDefaultValue,
Tag: structField.Tag,
TagSettings: map[string]string{},
Struct: structField.Struct,
IsForeignKey: structField.IsForeignKey,
}
if structField.Relationship != nil {
relationship := *structField.Relationship
clone.Relationship = &relationship
}
// copy the struct field tagSettings, they should be read-locked while they are copied
structField.tagSettingsLock.Lock()
defer structField.tagSettingsLock.Unlock()
for key, value := range structField.TagSettings {
clone.TagSettings[key] = value
}
return clone
}
// Relationship described the relationship between models
type Relationship struct {
Kind string
PolymorphicType string
PolymorphicDBName string
PolymorphicValue string
ForeignFieldNames []string
ForeignDBNames []string
AssociationForeignFieldNames []string
AssociationForeignDBNames []string
JoinTableHandler JoinTableHandlerInterface
}
func getForeignField(column string, fields []*StructField) *StructField {
for _, field := range fields {
if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) {
return field
}
}
return nil
}
// GetModelStruct get value's model struct, relationships based on struct and tag definition
func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct
// Scope value can't be nil
if scope.Value == nil {
return &modelStruct
}
reflectType := reflect.ValueOf(scope.Value).Type()
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
// Scope value need to be a struct
if reflectType.Kind() != reflect.Struct {
return &modelStruct
}
// Get Cached model struct
if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
return value.(*ModelStruct)
}
modelStruct.ModelType = reflectType
// Get all fields
for i := 0; i < reflectType.NumField(); i++ {
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
field := &StructField{
Struct: fieldStruct,
Name: fieldStruct.Name,
Names: []string{fieldStruct.Name},
Tag: fieldStruct.Tag,
TagSettings: parseTagSetting(fieldStruct.Tag),
}
// is ignored field
if _, ok := field.TagSettingsGet("-"); ok {
field.IsIgnored = true
} else {
if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
if _, ok := field.TagSettingsGet("DEFAULT"); ok {
field.HasDefaultValue = true
}
if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey {
field.HasDefaultValue = true
}
indirectType := fieldStruct.Type
for indirectType.Kind() == reflect.Ptr {
indirectType = indirectType.Elem()
}
fieldValue := reflect.New(indirectType).Interface()
if _, isScanner := fieldValue.(sql.Scanner); isScanner {
// is scanner
field.IsScanner, field.IsNormal = true, true
if indirectType.Kind() == reflect.Struct {
for i := 0; i < indirectType.NumField(); i++ {
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
if _, ok := field.TagSettingsGet(key); !ok {
field.TagSettingsSet(key, value)
}
}
}
}
} else if _, isTime := fieldValue.(*time.Time); isTime {
// is time
field.IsNormal = true
} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
// is embedded struct
for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
subField = subField.clone()
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
subField.DBName = prefix + subField.DBName
}
if subField.IsPrimaryKey {
if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
} else {
subField.IsPrimaryKey = false
}
}
if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil {
if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok {
newJoinTableHandler := &JoinTableHandler{}
newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType)
subField.Relationship.JoinTableHandler = newJoinTableHandler
}
}
modelStruct.StructFields = append(modelStruct.StructFields, subField)
}
continue
} else {
// build relationships
switch indirectType.Kind() {
case reflect.Slice:
defer func(field *StructField) {
var (
relationship = &Relationship{}
toScope = scope.New(reflect.New(field.Struct.Type).Interface())
foreignKeys []string
associationForeignKeys []string
elemType = field.Struct.Type
)
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
foreignKeys = strings.Split(foreignKey, ",")
}
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
associationForeignKeys = strings.Split(foreignKey, ",")
}
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" {
relationship.Kind = "many_to_many"
{ // Foreign Keys for Source
joinTableDBNames := []string{}
if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" {
joinTableDBNames = strings.Split(foreignKey, ",")
}
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
for _, field := range modelStruct.PrimaryFields {
foreignKeys = append(foreignKeys, field.DBName)
}
}
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
// source foreign keys (db names)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
// setup join table foreign keys for source
if len(joinTableDBNames) > idx {
// if defined join table's foreign key
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx])
} else {
defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey)
}
}
}
}
{ // Foreign Keys for Association (Destination)
associationJoinTableDBNames := []string{}
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" {
associationJoinTableDBNames = strings.Split(foreignKey, ",")
}
// if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 {
for _, field := range toScope.PrimaryFields() {
associationForeignKeys = append(associationForeignKeys, field.DBName)
}
}
for idx, name := range associationForeignKeys {
if field, ok := toScope.FieldByName(name); ok {
// association foreign keys (db names)
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
// setup join table foreign keys for association
if len(associationJoinTableDBNames) > idx {
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx])
} else {
// join table foreign keys for association
joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
}
}
}
}
joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType)
relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship
} else {
// User has many comments, associationType is User, comment use UserID as foreign key
var associationType = reflectType.Name()
var toFields = toScope.GetStructFields()
relationship.Kind = "has_many"
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
// Toy use OwnerID, OwnerType ('dogs') as foreign key
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
associationType = polymorphic
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
relationship.PolymorphicValue = value
} else {
relationship.PolymorphicValue = scope.TableName()
}
polymorphicType.IsForeignKey = true
}
}
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
// if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 {
for _, field := range modelStruct.PrimaryFields {
foreignKeys = append(foreignKeys, associationType+field.Name)
associationForeignKeys = append(associationForeignKeys, field.Name)
}
} else {
// generate foreign keys from defined association foreign keys
for _, scopeFieldName := range associationForeignKeys {
if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
}
}
}
} else {
// generate association foreign keys from foreign keys
if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) {
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{scope.PrimaryKey()}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
scope.Err(errors.New("invalid foreign keys, should have same length"))
return
}
}
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
// source foreign keys
foreignField.IsForeignKey = true
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
// association foreign keys
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
}
}
}
if len(relationship.ForeignFieldNames) != 0 {
field.Relationship = relationship
}
}
} else {
field.IsNormal = true
}
}(field)
case reflect.Struct:
defer func(field *StructField) {
var (
// user has one profile, associationType is User, profile use UserID as foreign key
// user belongs to profile, associationType is Profile, user use ProfileID as foreign key
associationType = reflectType.Name()
relationship = &Relationship{}
toScope = scope.New(reflect.New(field.Struct.Type).Interface())
toFields = toScope.GetStructFields()
tagForeignKeys []string
tagAssociationForeignKeys []string
)
if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" {
tagForeignKeys = strings.Split(foreignKey, ",")
}
if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
} else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" {
tagAssociationForeignKeys = strings.Split(foreignKey, ",")
}
if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" {
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
// Toy use OwnerID, OwnerType ('cats') as foreign key
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
associationType = polymorphic
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
// if Cat has several different types of toys set name for each (instead of default 'cats')
if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok {
relationship.PolymorphicValue = value
} else {
relationship.PolymorphicValue = scope.TableName()
}
polymorphicType.IsForeignKey = true
}
}
// Has One
{
var foreignKeys = tagForeignKeys
var associationForeignKeys = tagAssociationForeignKeys
// if no foreign keys defined with tag
if len(foreignKeys) == 0 {
// if no association foreign keys defined with tag
if len(associationForeignKeys) == 0 {
for _, primaryField := range modelStruct.PrimaryFields {
foreignKeys = append(foreignKeys, associationType+primaryField.Name)
associationForeignKeys = append(associationForeignKeys, primaryField.Name)
}
} else {
// generate foreign keys form association foreign keys
for _, associationForeignKey := range tagAssociationForeignKeys {
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
}
}
}
} else {
// generate association foreign keys from foreign keys
if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, associationType) {
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{scope.PrimaryKey()}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
scope.Err(errors.New("invalid foreign keys, should have same length"))
return
}
}
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
foreignField.IsForeignKey = true
// source foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
// association foreign keys
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
}
}
}
}
if len(relationship.ForeignFieldNames) != 0 {
relationship.Kind = "has_one"
field.Relationship = relationship
} else {
var foreignKeys = tagForeignKeys
var associationForeignKeys = tagAssociationForeignKeys
if len(foreignKeys) == 0 {
// generate foreign keys & association foreign keys
if len(associationForeignKeys) == 0 {
for _, primaryField := range toScope.PrimaryFields() {
foreignKeys = append(foreignKeys, field.Name+primaryField.Name)
associationForeignKeys = append(associationForeignKeys, primaryField.Name)
}
} else {
// generate foreign keys with association foreign keys
for _, associationForeignKey := range associationForeignKeys {
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
foreignKeys = append(foreignKeys, field.Name+foreignField.Name)
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
}
}
}
} else {
// generate foreign keys & association foreign keys
if len(associationForeignKeys) == 0 {
for _, foreignKey := range foreignKeys {
if strings.HasPrefix(foreignKey, field.Name) {
associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
}
}
}
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
associationForeignKeys = []string{toScope.PrimaryKey()}
}
} else if len(foreignKeys) != len(associationForeignKeys) {
scope.Err(errors.New("invalid foreign keys, should have same length"))
return
}
}
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
foreignField.IsForeignKey = true
// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
// source foreign keys
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
}
}
}
if len(relationship.ForeignFieldNames) != 0 {
relationship.Kind = "belongs_to"
field.Relationship = relationship
}
}
}(field)
default:
field.IsNormal = true
}
}
}
// Even it is ignored, also possible to decode db value into the field
if value, ok := field.TagSettingsGet("COLUMN"); ok {
field.DBName = value
} else {
field.DBName = ToColumnName(fieldStruct.Name)
}
modelStruct.StructFields = append(modelStruct.StructFields, field)
}
}
if len(modelStruct.PrimaryFields) == 0 {
if field := getForeignField("id", modelStruct.StructFields); field != nil {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
}
modelStructsMap.Store(reflectType, &modelStruct)
return &modelStruct
}
// GetStructFields get model's field structs
func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields
}
func parseTagSetting(tags reflect.StructTag) map[string]string {
setting := map[string]string{}
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
tags := strings.Split(str, ";")
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) >= 2 {
setting[k] = strings.Join(v[1:], ":")
} else {
setting[k] = k
}
}
}
return setting
}

View File

@@ -1,124 +0,0 @@
package gorm
import (
"bytes"
"strings"
)
// Namer is a function type which is given a string and return a string
type Namer func(string) string
// NamingStrategy represents naming strategies
type NamingStrategy struct {
DB Namer
Table Namer
Column Namer
}
// TheNamingStrategy is being initialized with defaultNamingStrategy
var TheNamingStrategy = &NamingStrategy{
DB: defaultNamer,
Table: defaultNamer,
Column: defaultNamer,
}
// AddNamingStrategy sets the naming strategy
func AddNamingStrategy(ns *NamingStrategy) {
if ns.DB == nil {
ns.DB = defaultNamer
}
if ns.Table == nil {
ns.Table = defaultNamer
}
if ns.Column == nil {
ns.Column = defaultNamer
}
TheNamingStrategy = ns
}
// DBName alters the given name by DB
func (ns *NamingStrategy) DBName(name string) string {
return ns.DB(name)
}
// TableName alters the given name by Table
func (ns *NamingStrategy) TableName(name string) string {
return ns.Table(name)
}
// ColumnName alters the given name by Column
func (ns *NamingStrategy) ColumnName(name string) string {
return ns.Column(name)
}
// ToDBName convert string to db name
func ToDBName(name string) string {
return TheNamingStrategy.DBName(name)
}
// ToTableName convert string to table name
func ToTableName(name string) string {
return TheNamingStrategy.TableName(name)
}
// ToColumnName convert string to db name
func ToColumnName(name string) string {
return TheNamingStrategy.ColumnName(name)
}
var smap = newSafeMap()
func defaultNamer(name string) string {
const (
lower = false
upper = true
)
if v := smap.Get(name); v != "" {
return v
}
if name == "" {
return ""
}
var (
value = commonInitialismsReplacer.Replace(name)
buf = bytes.NewBufferString("")
lastCase, currCase, nextCase, nextNumber bool
)
for i, v := range value[:len(value)-1] {
nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z')
nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9')
if i > 0 {
if currCase == upper {
if lastCase == upper && (nextCase == upper || nextNumber == upper) {
buf.WriteRune(v)
} else {
if value[i-1] != '_' && value[i+1] != '_' {
buf.WriteRune('_')
}
buf.WriteRune(v)
}
} else {
buf.WriteRune(v)
if i == len(value)-2 && (nextCase == upper && nextNumber == lower) {
buf.WriteRune('_')
}
}
} else {
currCase = upper
buf.WriteRune(v)
}
lastCase = currCase
currCase = nextCase
}
buf.WriteByte(value[len(value)-1])
s := strings.ToLower(buf.String())
smap.Set(name, s)
return s
}

1397
vendor/github.com/jinzhu/gorm/scope.go generated vendored

File diff suppressed because it is too large Load Diff

View File

@@ -1,153 +0,0 @@
package gorm
import (
"fmt"
)
type search struct {
db *DB
whereConditions []map[string]interface{}
orConditions []map[string]interface{}
notConditions []map[string]interface{}
havingConditions []map[string]interface{}
joinConditions []map[string]interface{}
initAttrs []interface{}
assignAttrs []interface{}
selects map[string]interface{}
omits []string
orders []interface{}
preload []searchPreload
offset interface{}
limit interface{}
group string
tableName string
raw bool
Unscoped bool
ignoreOrderQuery bool
}
type searchPreload struct {
schema string
conditions []interface{}
}
func (s *search) clone() *search {
clone := *s
return &clone
}
func (s *search) Where(query interface{}, values ...interface{}) *search {
s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) Not(query interface{}, values ...interface{}) *search {
s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) Or(query interface{}, values ...interface{}) *search {
s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) Attrs(attrs ...interface{}) *search {
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
return s
}
func (s *search) Assign(attrs ...interface{}) *search {
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
return s
}
func (s *search) Order(value interface{}, reorder ...bool) *search {
if len(reorder) > 0 && reorder[0] {
s.orders = []interface{}{}
}
if value != nil && value != "" {
s.orders = append(s.orders, value)
}
return s
}
func (s *search) Select(query interface{}, args ...interface{}) *search {
s.selects = map[string]interface{}{"query": query, "args": args}
return s
}
func (s *search) Omit(columns ...string) *search {
s.omits = columns
return s
}
func (s *search) Limit(limit interface{}) *search {
s.limit = limit
return s
}
func (s *search) Offset(offset interface{}) *search {
s.offset = offset
return s
}
func (s *search) Group(query string) *search {
s.group = s.getInterfaceAsSQL(query)
return s
}
func (s *search) Having(query interface{}, values ...interface{}) *search {
if val, ok := query.(*expr); ok {
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args})
} else {
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
}
return s
}
func (s *search) Joins(query string, values ...interface{}) *search {
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) Preload(schema string, values ...interface{}) *search {
var preloads []searchPreload
for _, preload := range s.preload {
if preload.schema != schema {
preloads = append(preloads, preload)
}
}
preloads = append(preloads, searchPreload{schema, values})
s.preload = preloads
return s
}
func (s *search) Raw(b bool) *search {
s.raw = b
return s
}
func (s *search) unscoped() *search {
s.Unscoped = true
return s
}
func (s *search) Table(name string) *search {
s.tableName = name
return s
}
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
switch value.(type) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
str = fmt.Sprintf("%v", value)
default:
s.db.AddError(ErrInvalidSQL)
}
if str == "-1" {
return ""
}
return
}

View File

@@ -1,226 +0,0 @@
package gorm
import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"runtime"
"strings"
"sync"
"time"
)
// NowFunc returns current time, this function is exported in order to be able
// to give the flexibility to the developer to customize it according to their
// needs, e.g:
// gorm.NowFunc = func() time.Time {
// return time.Now().UTC()
// }
var NowFunc = func() time.Time {
return time.Now()
}
// Copied from golint
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer
var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`)
var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`)
func init() {
var commonInitialismsForReplacer []string
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
type safeMap struct {
m map[string]string
l *sync.RWMutex
}
func (s *safeMap) Set(key string, value string) {
s.l.Lock()
defer s.l.Unlock()
s.m[key] = value
}
func (s *safeMap) Get(key string) string {
s.l.RLock()
defer s.l.RUnlock()
return s.m[key]
}
func newSafeMap() *safeMap {
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
}
// SQL expression
type expr struct {
expr string
args []interface{}
}
// Expr generate raw SQL expression, for example:
// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
func Expr(expression string, args ...interface{}) *expr {
return &expr{expr: expression, args: args}
}
func indirect(reflectValue reflect.Value) reflect.Value {
for reflectValue.Kind() == reflect.Ptr {
reflectValue = reflectValue.Elem()
}
return reflectValue
}
func toQueryMarks(primaryValues [][]interface{}) string {
var results []string
for _, primaryValue := range primaryValues {
var marks []string
for range primaryValue {
marks = append(marks, "?")
}
if len(marks) > 1 {
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
} else {
results = append(results, strings.Join(marks, ""))
}
}
return strings.Join(results, ",")
}
func toQueryCondition(scope *Scope, columns []string) string {
var newColumns []string
for _, column := range columns {
newColumns = append(newColumns, scope.Quote(column))
}
if len(columns) > 1 {
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
}
return strings.Join(newColumns, ",")
}
func toQueryValues(values [][]interface{}) (results []interface{}) {
for _, value := range values {
for _, v := range value {
results = append(results, v)
}
}
return
}
func fileWithLineNum() string {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) {
return fmt.Sprintf("%v:%v", file, line)
}
}
return ""
}
func isBlank(value reflect.Value) bool {
switch value.Kind() {
case reflect.String:
return value.Len() == 0
case reflect.Bool:
return !value.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return value.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return value.Uint() == 0
case reflect.Float32, reflect.Float64:
return value.Float() == 0
case reflect.Interface, reflect.Ptr:
return value.IsNil()
}
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}
func toSearchableMap(attrs ...interface{}) (result interface{}) {
if len(attrs) > 1 {
if str, ok := attrs[0].(string); ok {
result = map[string]interface{}{str: attrs[1]}
}
} else if len(attrs) == 1 {
if attr, ok := attrs[0].(map[string]interface{}); ok {
result = attr
}
if attr, ok := attrs[0].(interface{}); ok {
result = attr
}
}
return
}
func equalAsString(a interface{}, b interface{}) bool {
return toString(a) == toString(b)
}
func toString(str interface{}) string {
if values, ok := str.([]interface{}); ok {
var results []string
for _, value := range values {
results = append(results, toString(value))
}
return strings.Join(results, "_")
} else if bytes, ok := str.([]byte); ok {
return string(bytes)
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
return fmt.Sprintf("%v", reflectValue.Interface())
}
return ""
}
func makeSlice(elemType reflect.Type) interface{} {
if elemType.Kind() == reflect.Slice {
elemType = elemType.Elem()
}
sliceType := reflect.SliceOf(elemType)
slice := reflect.New(sliceType)
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
return slice.Interface()
}
func strInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}
// getValueFromFields return given fields's value
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
// If value is a nil pointer, Indirect returns a zero Value!
// Therefor we need to check for a zero value,
// as FieldByName could panic
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
for _, fieldName := range fieldNames {
if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() {
result := fieldValue.Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
}
results = append(results, result)
}
}
}
return
}
func addExtraSpaceIfExist(str string) string {
if str != "" {
return " " + str
}
return ""
}

View File

@@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2015 - Jinzhu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,273 +0,0 @@
/*
Package inflection pluralizes and singularizes English nouns.
inflection.Plural("person") => "people"
inflection.Plural("Person") => "People"
inflection.Plural("PERSON") => "PEOPLE"
inflection.Singular("people") => "person"
inflection.Singular("People") => "Person"
inflection.Singular("PEOPLE") => "PERSON"
inflection.Plural("FancyPerson") => "FancydPeople"
inflection.Singular("FancyPeople") => "FancydPerson"
Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb)
If you want to register more rules, follow:
inflection.AddUncountable("fish")
inflection.AddIrregular("person", "people")
inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses"
inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS"
*/
package inflection
import (
"regexp"
"strings"
)
type inflection struct {
regexp *regexp.Regexp
replace string
}
// Regular is a regexp find replace inflection
type Regular struct {
find string
replace string
}
// Irregular is a hard replace inflection,
// containing both singular and plural forms
type Irregular struct {
singular string
plural string
}
// RegularSlice is a slice of Regular inflections
type RegularSlice []Regular
// IrregularSlice is a slice of Irregular inflections
type IrregularSlice []Irregular
var pluralInflections = RegularSlice{
{"([a-z])$", "${1}s"},
{"s$", "s"},
{"^(ax|test)is$", "${1}es"},
{"(octop|vir)us$", "${1}i"},
{"(octop|vir)i$", "${1}i"},
{"(alias|status)$", "${1}es"},
{"(bu)s$", "${1}ses"},
{"(buffal|tomat)o$", "${1}oes"},
{"([ti])um$", "${1}a"},
{"([ti])a$", "${1}a"},
{"sis$", "ses"},
{"(?:([^f])fe|([lr])f)$", "${1}${2}ves"},
{"(hive)$", "${1}s"},
{"([^aeiouy]|qu)y$", "${1}ies"},
{"(x|ch|ss|sh)$", "${1}es"},
{"(matr|vert|ind)(?:ix|ex)$", "${1}ices"},
{"^(m|l)ouse$", "${1}ice"},
{"^(m|l)ice$", "${1}ice"},
{"^(ox)$", "${1}en"},
{"^(oxen)$", "${1}"},
{"(quiz)$", "${1}zes"},
}
var singularInflections = RegularSlice{
{"s$", ""},
{"(ss)$", "${1}"},
{"(n)ews$", "${1}ews"},
{"([ti])a$", "${1}um"},
{"((a)naly|(b)a|(d)iagno|(p)arenthe|(p)rogno|(s)ynop|(t)he)(sis|ses)$", "${1}sis"},
{"(^analy)(sis|ses)$", "${1}sis"},
{"([^f])ves$", "${1}fe"},
{"(hive)s$", "${1}"},
{"(tive)s$", "${1}"},
{"([lr])ves$", "${1}f"},
{"([^aeiouy]|qu)ies$", "${1}y"},
{"(s)eries$", "${1}eries"},
{"(m)ovies$", "${1}ovie"},
{"(c)ookies$", "${1}ookie"},
{"(x|ch|ss|sh)es$", "${1}"},
{"^(m|l)ice$", "${1}ouse"},
{"(bus)(es)?$", "${1}"},
{"(o)es$", "${1}"},
{"(shoe)s$", "${1}"},
{"(cris|test)(is|es)$", "${1}is"},
{"^(a)x[ie]s$", "${1}xis"},
{"(octop|vir)(us|i)$", "${1}us"},
{"(alias|status)(es)?$", "${1}"},
{"^(ox)en", "${1}"},
{"(vert|ind)ices$", "${1}ex"},
{"(matr)ices$", "${1}ix"},
{"(quiz)zes$", "${1}"},
{"(database)s$", "${1}"},
}
var irregularInflections = IrregularSlice{
{"person", "people"},
{"man", "men"},
{"child", "children"},
{"sex", "sexes"},
{"move", "moves"},
{"mombie", "mombies"},
}
var uncountableInflections = []string{"equipment", "information", "rice", "money", "species", "series", "fish", "sheep", "jeans", "police"}
var compiledPluralMaps []inflection
var compiledSingularMaps []inflection
func compile() {
compiledPluralMaps = []inflection{}
compiledSingularMaps = []inflection{}
for _, uncountable := range uncountableInflections {
inf := inflection{
regexp: regexp.MustCompile("^(?i)(" + uncountable + ")$"),
replace: "${1}",
}
compiledPluralMaps = append(compiledPluralMaps, inf)
compiledSingularMaps = append(compiledSingularMaps, inf)
}
for _, value := range irregularInflections {
infs := []inflection{
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.singular) + "$"), replace: strings.ToUpper(value.plural)},
inflection{regexp: regexp.MustCompile(strings.Title(value.singular) + "$"), replace: strings.Title(value.plural)},
inflection{regexp: regexp.MustCompile(value.singular + "$"), replace: value.plural},
}
compiledPluralMaps = append(compiledPluralMaps, infs...)
}
for _, value := range irregularInflections {
infs := []inflection{
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.plural) + "$"), replace: strings.ToUpper(value.singular)},
inflection{regexp: regexp.MustCompile(strings.Title(value.plural) + "$"), replace: strings.Title(value.singular)},
inflection{regexp: regexp.MustCompile(value.plural + "$"), replace: value.singular},
}
compiledSingularMaps = append(compiledSingularMaps, infs...)
}
for i := len(pluralInflections) - 1; i >= 0; i-- {
value := pluralInflections[i]
infs := []inflection{
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)},
inflection{regexp: regexp.MustCompile(value.find), replace: value.replace},
inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace},
}
compiledPluralMaps = append(compiledPluralMaps, infs...)
}
for i := len(singularInflections) - 1; i >= 0; i-- {
value := singularInflections[i]
infs := []inflection{
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)},
inflection{regexp: regexp.MustCompile(value.find), replace: value.replace},
inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace},
}
compiledSingularMaps = append(compiledSingularMaps, infs...)
}
}
func init() {
compile()
}
// AddPlural adds a plural inflection
func AddPlural(find, replace string) {
pluralInflections = append(pluralInflections, Regular{find, replace})
compile()
}
// AddSingular adds a singular inflection
func AddSingular(find, replace string) {
singularInflections = append(singularInflections, Regular{find, replace})
compile()
}
// AddIrregular adds an irregular inflection
func AddIrregular(singular, plural string) {
irregularInflections = append(irregularInflections, Irregular{singular, plural})
compile()
}
// AddUncountable adds an uncountable inflection
func AddUncountable(values ...string) {
uncountableInflections = append(uncountableInflections, values...)
compile()
}
// GetPlural retrieves the plural inflection values
func GetPlural() RegularSlice {
plurals := make(RegularSlice, len(pluralInflections))
copy(plurals, pluralInflections)
return plurals
}
// GetSingular retrieves the singular inflection values
func GetSingular() RegularSlice {
singulars := make(RegularSlice, len(singularInflections))
copy(singulars, singularInflections)
return singulars
}
// GetIrregular retrieves the irregular inflection values
func GetIrregular() IrregularSlice {
irregular := make(IrregularSlice, len(irregularInflections))
copy(irregular, irregularInflections)
return irregular
}
// GetUncountable retrieves the uncountable inflection values
func GetUncountable() []string {
uncountables := make([]string, len(uncountableInflections))
copy(uncountables, uncountableInflections)
return uncountables
}
// SetPlural sets the plural inflections slice
func SetPlural(inflections RegularSlice) {
pluralInflections = inflections
compile()
}
// SetSingular sets the singular inflections slice
func SetSingular(inflections RegularSlice) {
singularInflections = inflections
compile()
}
// SetIrregular sets the irregular inflections slice
func SetIrregular(inflections IrregularSlice) {
irregularInflections = inflections
compile()
}
// SetUncountable sets the uncountable inflections slice
func SetUncountable(inflections []string) {
uncountableInflections = inflections
compile()
}
// Plural converts a word to its plural form
func Plural(str string) string {
for _, inflection := range compiledPluralMaps {
if inflection.regexp.MatchString(str) {
return inflection.regexp.ReplaceAllString(str, inflection.replace)
}
}
return str
}
// Singular converts a word to its singular form
func Singular(str string) string {
for _, inflection := range compiledSingularMaps {
if inflection.regexp.MatchString(str) {
return inflection.regexp.ReplaceAllString(str, inflection.replace)
}
}
return str
}