From a266d0e9f6d379665b2dfa3d4de4f628c2dec057 Mon Sep 17 00:00:00 2001 From: runzexia Date: Thu, 25 Apr 2019 19:37:37 +0800 Subject: [PATCH] add pipeline apis --- Gopkg.lock | 27 +- pkg/apis/devops/v1alpha2/register.go | 38 + pkg/apiserver/devops/project_pipeline.go | 161 ++ pkg/models/devops/project_pipeline.go | 882 ++++++++++ pkg/models/devops/project_pipeline_handler.go | 258 +++ pkg/models/devops/project_pipeline_test.go | 510 ++++++ vendor/github.com/beevik/etree/CONTRIBUTORS | 10 + vendor/github.com/beevik/etree/LICENSE | 24 + vendor/github.com/beevik/etree/etree.go | 1453 +++++++++++++++++ vendor/github.com/beevik/etree/helpers.go | 276 ++++ vendor/github.com/beevik/etree/path.go | 582 +++++++ vendor/github.com/jinzhu/gorm/License | 21 - vendor/github.com/jinzhu/gorm/association.go | 377 ----- vendor/github.com/jinzhu/gorm/callback.go | 242 --- .../github.com/jinzhu/gorm/callback_create.go | 164 -- .../github.com/jinzhu/gorm/callback_delete.go | 63 - .../github.com/jinzhu/gorm/callback_query.go | 104 -- .../jinzhu/gorm/callback_query_preload.go | 404 ----- .../jinzhu/gorm/callback_row_query.go | 30 - .../github.com/jinzhu/gorm/callback_save.go | 170 -- .../github.com/jinzhu/gorm/callback_update.go | 121 -- vendor/github.com/jinzhu/gorm/dialect.go | 138 -- .../github.com/jinzhu/gorm/dialect_common.go | 176 -- .../github.com/jinzhu/gorm/dialect_mysql.go | 191 --- .../jinzhu/gorm/dialect_postgres.go | 143 -- .../github.com/jinzhu/gorm/dialect_sqlite3.go | 107 -- vendor/github.com/jinzhu/gorm/errors.go | 72 - vendor/github.com/jinzhu/gorm/field.go | 66 - vendor/github.com/jinzhu/gorm/interface.go | 20 - .../jinzhu/gorm/join_table_handler.go | 211 --- vendor/github.com/jinzhu/gorm/logger.go | 119 -- vendor/github.com/jinzhu/gorm/main.go | 792 --------- vendor/github.com/jinzhu/gorm/model.go | 14 - vendor/github.com/jinzhu/gorm/model_struct.go | 640 -------- vendor/github.com/jinzhu/gorm/naming.go | 124 -- vendor/github.com/jinzhu/gorm/scope.go | 1397 ---------------- vendor/github.com/jinzhu/gorm/search.go | 153 -- vendor/github.com/jinzhu/gorm/utils.go | 226 --- vendor/github.com/jinzhu/inflection/LICENSE | 21 - .../jinzhu/inflection/inflections.go | 273 ---- 40 files changed, 4203 insertions(+), 6597 deletions(-) create mode 100644 pkg/apiserver/devops/project_pipeline.go create mode 100644 pkg/models/devops/project_pipeline.go create mode 100644 pkg/models/devops/project_pipeline_handler.go create mode 100644 pkg/models/devops/project_pipeline_test.go create mode 100644 vendor/github.com/beevik/etree/CONTRIBUTORS create mode 100644 vendor/github.com/beevik/etree/LICENSE create mode 100644 vendor/github.com/beevik/etree/etree.go create mode 100644 vendor/github.com/beevik/etree/helpers.go create mode 100644 vendor/github.com/beevik/etree/path.go delete mode 100644 vendor/github.com/jinzhu/gorm/License delete mode 100644 vendor/github.com/jinzhu/gorm/association.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_create.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_delete.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_query.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_query_preload.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_row_query.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_save.go delete mode 100644 vendor/github.com/jinzhu/gorm/callback_update.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_common.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_mysql.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_postgres.go delete mode 100644 vendor/github.com/jinzhu/gorm/dialect_sqlite3.go delete mode 100644 vendor/github.com/jinzhu/gorm/errors.go delete mode 100644 vendor/github.com/jinzhu/gorm/field.go delete mode 100644 vendor/github.com/jinzhu/gorm/interface.go delete mode 100644 vendor/github.com/jinzhu/gorm/join_table_handler.go delete mode 100644 vendor/github.com/jinzhu/gorm/logger.go delete mode 100644 vendor/github.com/jinzhu/gorm/main.go delete mode 100644 vendor/github.com/jinzhu/gorm/model.go delete mode 100644 vendor/github.com/jinzhu/gorm/model_struct.go delete mode 100644 vendor/github.com/jinzhu/gorm/naming.go delete mode 100644 vendor/github.com/jinzhu/gorm/scope.go delete mode 100644 vendor/github.com/jinzhu/gorm/search.go delete mode 100644 vendor/github.com/jinzhu/gorm/utils.go delete mode 100644 vendor/github.com/jinzhu/inflection/LICENSE delete mode 100644 vendor/github.com/jinzhu/inflection/inflections.go diff --git a/Gopkg.lock b/Gopkg.lock index 1bd32310d..8e34e7f20 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -68,6 +68,14 @@ revision = "ccb8e960c48f04d6935e72476ae4a51028f9e22f" version = "v9" +[[projects]] + digest = "1:15e3271f463f2f40d98bf426aabb86941fc66b10272ccfdfebe548683e37acb1" + name = "github.com/beevik/etree" + packages = ["."] + pruneopts = "NUT" + revision = "8aee6516be3b1163bb6450c35c50e4969e3a3aa8" + version = "v1.1.0" + [[projects]] branch = "master" digest = "1:707ebe952a8b3d00b343c01536c79c73771d100f63ec6babeaed5c79e2b8a8dd" @@ -562,22 +570,6 @@ pruneopts = "NUT" revision = "3eca13d6893afd7ecabe15f4445f5d2872a1b012" -[[projects]] - digest = "1:2ddfc1382a659966038282873c9e33e7694fa503130d445e97c4fdc3b8c5db66" - name = "github.com/jinzhu/gorm" - packages = ["."] - pruneopts = "NUT" - revision = "472c70caa40267cb89fd8facb07fe6454b578626" - version = "v1.9.2" - -[[projects]] - branch = "master" - digest = "1:802f75230c29108e787d40679f9bf5da1a5673eaf5c10eb89afd993e18972909" - name = "github.com/jinzhu/inflection" - packages = ["."] - pruneopts = "NUT" - revision = "04140366298a54a039076d798123ffa108fff46c" - [[projects]] digest = "1:da62aa6632d04e080b8a8b85a59ed9ed1550842a0099a55f3ae3a20d02a3745a" name = "github.com/joho/godotenv" @@ -2011,6 +2003,7 @@ analyzer-version = 1 input-imports = [ "github.com/asaskevich/govalidator", + "github.com/beevik/etree", "github.com/dgrijalva/jwt-go", "github.com/docker/docker/api/types", "github.com/docker/docker/client", @@ -2026,7 +2019,6 @@ "github.com/golang/example/stringutil", "github.com/golang/glog", "github.com/google/uuid", - "github.com/jinzhu/gorm", "github.com/json-iterator/go", "github.com/kiali/kiali/config", "github.com/kiali/kiali/handlers", @@ -2059,7 +2051,6 @@ "gopkg.in/src-d/go-git.v4/storage/memory", "gopkg.in/yaml.v2", "k8s.io/api/apps/v1", - "k8s.io/api/apps/v1beta2", "k8s.io/api/batch/v1", "k8s.io/api/batch/v1beta1", "k8s.io/api/core/v1", diff --git a/pkg/apis/devops/v1alpha2/register.go b/pkg/apis/devops/v1alpha2/register.go index 1d92fc13e..71932250b 100644 --- a/pkg/apis/devops/v1alpha2/register.go +++ b/pkg/apis/devops/v1alpha2/register.go @@ -25,6 +25,7 @@ import ( devopsapi "kubesphere.io/kubesphere/pkg/apiserver/devops" "kubesphere.io/kubesphere/pkg/apiserver/runtime" "kubesphere.io/kubesphere/pkg/models/devops" + "net/http" ) const GroupName = "devops.kubesphere.io" @@ -79,6 +80,43 @@ func addWebService(c *restful.Container) error { Doc("add devops project members"). Metadata(restfulspec.KeyOpenAPITags, tags). Writes(&devops.DevOpsProjectMembership{})) + webservice.Route(webservice.POST("/devops/{devops}/pipelines"). + To(devopsapi.CreateDevOpsProjectPipelineHandler). + Doc("add devops project pipeline"). + Metadata(restfulspec.KeyOpenAPITags, tags). + Reads(devops.ProjectPipeline{})) + webservice.Route(webservice.PUT("/devops/{devops}/pipelines/{pipelines}"). + To(devopsapi.UpdateDevOpsProjectPipelineHandler). + Doc("add devops project pipeline"). + Metadata(restfulspec.KeyOpenAPITags, tags). + Reads(devops.ProjectPipeline{})) + webservice.Route(webservice.GET("/devops/{devops}/pipelines/{pipelines}/config"). + To(devopsapi.GetDevOpsProjectPipelineHandler). + Doc("get devops project pipeline config"). + Metadata(restfulspec.KeyOpenAPITags, tags). + Returns(http.StatusOK, "ok", devops.ProjectPipeline{}). + Writes(devops.ProjectPipeline{})) + webservice.Route(webservice.GET("/devops/{devops}/pipelines/{pipelines}/sonarStatus"). + To(devopsapi.GetPipelineSonarStatusHandler). + Doc("get devops project pipeline sonarStatus"). + Metadata(restfulspec.KeyOpenAPITags, tags). + Returns(http.StatusOK, "ok", []devops.SonarStatus{}). + Writes([]devops.SonarStatus{})) + webservice.Route(webservice.GET("/devops/{devops}/pipelines/{pipelines}/branches/{branches}/sonarStatus"). + To(devopsapi.GetMultiBranchesPipelineSonarStatusHandler). + Doc("get devops project pipeline sonarStatus"). + Metadata(restfulspec.KeyOpenAPITags, tags). + Returns(http.StatusOK, "ok", []devops.SonarStatus{}). + Writes([]devops.SonarStatus{})) + webservice.Route(webservice.DELETE("/devops/{devops}/pipelines/{pipelines}"). + To(devopsapi.DeleteDevOpsProjectPipelineHandler). + Doc("delete devops project pipeline"). + Metadata(restfulspec.KeyOpenAPITags, tags)) + webservice.Route(webservice.PUT("/devops/{devops}/pipelines"). + To(devopsapi.CreateDevOpsProjectPipelineHandler). + Doc("add devops project pipeline"). + Metadata(restfulspec.KeyOpenAPITags, tags). + Reads(devops.ProjectPipeline{})) webservice.Route(webservice.PATCH("/devops/{devops}/members/{members}"). To(devopsapi.UpdateDevOpsProjectMemberHandler). diff --git a/pkg/apiserver/devops/project_pipeline.go b/pkg/apiserver/devops/project_pipeline.go new file mode 100644 index 000000000..3e0460151 --- /dev/null +++ b/pkg/apiserver/devops/project_pipeline.go @@ -0,0 +1,161 @@ +package devops + +import ( + "github.com/emicklei/go-restful" + "github.com/golang/glog" + "kubesphere.io/kubesphere/pkg/constants" + "kubesphere.io/kubesphere/pkg/errors" + "kubesphere.io/kubesphere/pkg/models/devops" + "net/http" +) + +func CreateDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) { + + projectId := request.PathParameter("devops") + username := request.HeaderParameter(constants.UserNameHeader) + var pipeline *devops.ProjectPipeline + err := request.ReadEntity(&pipeline) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp) + return + } + err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer}) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp) + return + } + pipelineName, err := devops.CreateProjectPipeline(projectId, pipeline) + + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(err, resp) + return + } + + resp.WriteAsJson(struct { + Name string `json:"name"` + }{Name: pipelineName}) + return +} + +func DeleteDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) { + projectId := request.PathParameter("devops") + username := request.HeaderParameter(constants.UserNameHeader) + pipelineId := request.PathParameter("pipelines") + + err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer}) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp) + return + } + pipelineName, err := devops.DeleteProjectPipeline(projectId, pipelineId) + + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(err, resp) + return + } + + resp.WriteAsJson(struct { + Name string `json:"name"` + }{Name: pipelineName}) + return +} + +func UpdateDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) { + + projectId := request.PathParameter("devops") + username := request.HeaderParameter(constants.UserNameHeader) + pipelineId := request.PathParameter("pipelines") + var pipeline *devops.ProjectPipeline + err := request.ReadEntity(&pipeline) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusBadRequest, err.Error()), resp) + return + } + err = devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer}) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp) + return + } + pipelineName, err := devops.UpdateProjectPipeline(projectId, pipelineId, pipeline) + + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(err, resp) + return + } + + resp.WriteAsJson(struct { + Name string `json:"name"` + }{Name: pipelineName}) + return +} + +func GetDevOpsProjectPipelineHandler(request *restful.Request, resp *restful.Response) { + + projectId := request.PathParameter("devops") + username := request.HeaderParameter(constants.UserNameHeader) + pipelineId := request.PathParameter("pipelines") + + err := devops.CheckProjectUserInRole(username, projectId, []string{devops.ProjectOwner, devops.ProjectMaintainer}) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp) + return + } + pipeline, err := devops.GetProjectPipeline(projectId, pipelineId) + + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(err, resp) + return + } + + resp.WriteAsJson(pipeline) + return +} + +func GetPipelineSonarStatusHandler(request *restful.Request, resp *restful.Response) { + projectId := request.PathParameter("devops") + username := request.HeaderParameter(constants.UserNameHeader) + pipelineId := request.PathParameter("pipelines") + err := devops.CheckProjectUserInRole(username, projectId, devops.AllRoleSlice) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp) + return + } + sonarStatus, err := devops.GetPipelineSonar(projectId, pipelineId) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(err, resp) + return + } + resp.WriteAsJson(sonarStatus) +} + +func GetMultiBranchesPipelineSonarStatusHandler(request *restful.Request, resp *restful.Response) { + projectId := request.PathParameter("devops") + username := request.HeaderParameter(constants.UserNameHeader) + pipelineId := request.PathParameter("pipelines") + branchId := request.PathParameter("branches") + err := devops.CheckProjectUserInRole(username, projectId, devops.AllRoleSlice) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(restful.NewError(http.StatusForbidden, err.Error()), resp) + return + } + sonarStatus, err := devops.GetMultiBranchPipelineSonar(projectId, pipelineId, branchId) + if err != nil { + glog.Errorf("%+v", err) + errors.ParseSvcErr(err, resp) + return + } + resp.WriteAsJson(sonarStatus) +} diff --git a/pkg/models/devops/project_pipeline.go b/pkg/models/devops/project_pipeline.go new file mode 100644 index 000000000..b775f023a --- /dev/null +++ b/pkg/models/devops/project_pipeline.go @@ -0,0 +1,882 @@ +package devops + +import ( + "fmt" + "github.com/beevik/etree" + "github.com/golang/glog" + "github.com/kubesphere/sonargo/sonar" + "kubesphere.io/kubesphere/pkg/gojenkins" + "kubesphere.io/kubesphere/pkg/simple/client/sonarqube" + "strconv" + "strings" + "time" +) + +const ( + NoScmPipelineType = "pipeline" + MultiBranchPipelineType = "multi-branch-pipeline" +) + +var ParameterTypeMap = map[string]string{ + "hudson.model.StringParameterDefinition": "string", + "hudson.model.ChoiceParameterDefinition": "choice", + "hudson.model.TextParameterDefinition": "text", + "hudson.model.BooleanParameterDefinition": "boolean", + "hudson.model.FileParameterDefinition": "file", + "hudson.model.PasswordParameterDefinition": "password", +} + +const ( + SonarAnalysisActionClass = "hudson.plugins.sonar.action.SonarAnalysisAction" + SonarMetricKeys = "alert_status,quality_gate_details,bugs,new_bugs,reliability_rating,new_reliability_rating,vulnerabilities,new_vulnerabilities,security_rating,new_security_rating,code_smells,new_code_smells,sqale_rating,new_maintainability_rating,sqale_index,new_technical_debt,coverage,new_coverage,new_lines_to_cover,tests,duplicated_lines_density,new_duplicated_lines_density,duplicated_blocks,ncloc,ncloc_language_distribution,projects,new_lines" + SonarAdditionalFields = "metrics,periods" +) + +type SonarStatus struct { + Measures *sonargo.MeasuresComponentObject `json:"measures,omitempty"` + Issues *sonargo.IssuesSearchObject `json:"issues,omitempty"` + JenkinsAction *gojenkins.GeneralObj `json:"jenkinsAction,omitempty"` + Task *sonargo.CeTaskObject `json:"task,omitempty"` +} + +type ProjectPipeline struct { + Type string `json:"type"` + Pipeline *NoScmPipeline `json:"pipeline"` + MultiBranchPipeline *MultiBranchPipeline `json:"multi_branch_pipeline"` +} + +type NoScmPipeline struct { + Name string `json:"name"` + Description string `json:"description"` + Discarder *DiscarderProperty `json:"discarder"` + Parameters []*Parameter `json:"parameters"` + DisableConcurrent bool `json:"disable_concurrent" mapstructure:"disable_concurrent"` + TimerTrigger *TimerTrigger `json:"timer_trigger" mapstructure:"timer_trigger"` + RemoteTrigger *RemoteTrigger `json:"remote_trigger" mapstructure:"remote_trigger"` + Jenkinsfile string `json:"jenkinsfile"` +} + +type MultiBranchPipeline struct { + Name string `json:"name"` + Description string `json:"description"` + Discarder *DiscarderProperty `json:"discarder"` + TimerTrigger *TimerTrigger `json:"timer_trigger" mapstructure:"timer_trigger"` + SourceType string `json:"source_type"` + GitSource *GitSource `json:"git_source"` + GitHubSource *GithubSource `json:"github_source"` + SvnSource *SvnSource `json:"svn_source"` + SingleSvnSource *SingleSvnSource `json:"single_svn_source"` + ScriptPath string `json:"script_path" mapstructure:"script_path"` +} + +type GitSource struct { + Url string `json:"url,omitempty" mapstructure:"url"` + CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"` + DiscoverBranches bool `json:"discover_branches,omitempty" mapstructure:"discover_branches"` + CloneOption *GitCloneOption `json:"git_clone_option,omitempty" mapstructure:"git_clone_option"` + RegexFilter string `json:"regex_filter,omitempty" mapstructure:"regex_filter"` +} + +type GithubSource struct { + Owner string `json:"owner,omitempty" mapstructure:"owner"` + Repo string `json:"repo,omitempty" mapstructure:"repo"` + CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"` + ApiUri string `json:"api_uri,omitempty" mapstructure:"api_uri"` + DiscoverBranches int `json:"discover_branches,omitempty" mapstructure:"discover_branches"` + DiscoverPRFromOrigin int `json:"discover_pr_from_origin,omitempty" mapstructure:"discover_pr_from_origin"` + DiscoverPRFromForks *GithubDiscoverPRFromForks `json:"discover_pr_from_forks,omitempty" mapstructure:"discover_pr_from_forks"` + CloneOption *GitCloneOption `json:"git_clone_option,omitempty" mapstructure:"git_clone_option"` + RegexFilter string `json:"regex_filter,omitempty" mapstructure:"regex_filter"` +} + +type GitCloneOption struct { + Shallow bool `json:"shallow" mapstructure:"shallow"` + Timeout int `json:"timeout,omitempty" mapstructure:"timeout"` + Depth int `json:"depth,omitempty" mapstructure:"depth"` +} + +type SvnSource struct { + Remote string `json:"remote,omitempty"` + CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"` + Includes string `json:"includes,omitempty"` + Excludes string `json:"excludes,omitempty"` +} +type SingleSvnSource struct { + Remote string `json:"remote,omitempty"` + CredentialId string `json:"credential_id,omitempty" mapstructure:"credential_id"` +} + +type ScmInfo struct { + Type string `json:"type"` + Repo string `json:"repo"` + ApiUri string `json:"api_uri,omitempty"` + Path string `json:"path"` +} + +type GithubDiscoverPRFromForks struct { + Strategy int `json:"strategy" mapstructure:"strategy"` + Trust int `json:"trust" mapstructure:"trust"` +} + +type DiscarderProperty struct { + DaysToKeep string `json:"days_to_keep" mapstructure:"days_to_keep"` + NumToKeep string `json:"num_to_keep" mapstructure:"num_to_keep"` +} + +type Parameter struct { + Name string `json:"name"` + DefaultValue string `json:"default_value,omitempty" mapstructure:"default_value"` + Type string `json:"type"` + Description string `json:"description"` +} + +type TimerTrigger struct { + // user in no scm job + Cron string `json:"cron,omitempty"` + + // use in multi-branch job + Interval string `json:"interval,omitempty"` +} + +type RemoteTrigger struct { + Token string `json:"token"` +} + +func replaceXmlVersion(config, oldVersion, targetVersion string) string { + lines := strings.Split(string(config), "\n") + lines[0] = strings.Replace(lines[0], oldVersion, targetVersion, -1) + output := strings.Join(lines, "\n") + return output +} + +func createPipelineConfigXml(pipeline *NoScmPipeline) (string, error) { + doc := etree.NewDocument() + xmlString := ` + + + + + + + + + + + +` + doc.ReadFromString(xmlString) + flow := doc.SelectElement("flow-definition") + flow.CreateElement("description").SetText(pipeline.Description) + properties := flow.CreateElement("properties") + + if pipeline.DisableConcurrent { + properties.CreateElement("org.jenkinsci.plugins.workflow.job.properties.DisableConcurrentBuildsJobProperty") + } + + if pipeline.Discarder != nil { + discarder := properties.CreateElement("jenkins.model.BuildDiscarderProperty") + strategy := discarder.CreateElement("strategy") + strategy.CreateAttr("class", "hudson.tasks.LogRotator") + strategy.CreateElement("daysToKeep").SetText(pipeline.Discarder.DaysToKeep) + strategy.CreateElement("numToKeep").SetText(pipeline.Discarder.NumToKeep) + strategy.CreateElement("artifactDaysToKeep").SetText("-1") + strategy.CreateElement("artifactNumToKeep").SetText("-1") + } + if pipeline.Parameters != nil { + parameterDefinitions := properties.CreateElement("hudson.model.ParametersDefinitionProperty"). + CreateElement("parameterDefinitions") + for _, parameter := range pipeline.Parameters { + for className, typeName := range ParameterTypeMap { + if typeName == parameter.Type { + paramDefine := parameterDefinitions.CreateElement(className) + paramDefine.CreateElement("name").SetText(parameter.Name) + paramDefine.CreateElement("description").SetText(parameter.Description) + switch parameter.Type { + case "choice": + choices := paramDefine.CreateElement("choices") + choices.CreateAttr("class", "java.util.Arrays$ArrayList") + a := choices.CreateElement("a") + a.CreateAttr("class", "string-array") + choiceValues := strings.Split(parameter.DefaultValue, "\n") + for _, choiceValue := range choiceValues { + a.CreateElement("string").SetText(choiceValue) + } + case "file": + break + default: + paramDefine.CreateElement("defaultValue").SetText(parameter.DefaultValue) + } + } + } + } + } + + if pipeline.TimerTrigger != nil { + triggers := properties. + CreateElement("org.jenkinsci.plugins.workflow.job.properties.PipelineTriggersJobProperty"). + CreateElement("triggers") + triggers.CreateElement("hudson.triggers.TimerTrigger").CreateElement("spec").SetText(pipeline.TimerTrigger.Cron) + } + + pipelineDefine := flow.CreateElement("definition") + pipelineDefine.CreateAttr("class", "org.jenkinsci.plugins.workflow.cps.CpsFlowDefinition") + pipelineDefine.CreateAttr("plugin", "workflow-cps") + pipelineDefine.CreateElement("script").SetText(pipeline.Jenkinsfile) + + pipelineDefine.CreateElement("sandbox").SetText("true") + + flow.CreateElement("triggers") + + if pipeline.RemoteTrigger != nil { + flow.CreateElement("authToken").SetText(pipeline.RemoteTrigger.Token) + } + flow.CreateElement("disabled").SetText("false") + + doc.Indent(2) + stringXml, err := doc.WriteToString() + if err != nil { + return "", err + } + return replaceXmlVersion(stringXml, "1.0", "1.1"), err +} + +func parsePipelineConfigXml(config string) (*NoScmPipeline, error) { + pipeline := &NoScmPipeline{} + config = replaceXmlVersion(config, "1.1", "1.0") + doc := etree.NewDocument() + err := doc.ReadFromString(config) + if err != nil { + return nil, err + } + flow := doc.SelectElement("flow-definition") + if flow == nil { + return nil, fmt.Errorf("can not find pipeline definition") + } + pipeline.Description = flow.SelectElement("description").Text() + + properties := flow.SelectElement("properties") + if properties. + SelectElement( + "org.jenkinsci.plugins.workflow.job.properties.DisableConcurrentBuildsJobProperty") != nil { + pipeline.DisableConcurrent = true + } + if properties.SelectElement("jenkins.model.BuildDiscarderProperty") != nil { + strategy := properties. + SelectElement("jenkins.model.BuildDiscarderProperty"). + SelectElement("strategy") + pipeline.Discarder = &DiscarderProperty{ + DaysToKeep: strategy.SelectElement("daysToKeep").Text(), + NumToKeep: strategy.SelectElement("numToKeep").Text(), + } + } + if parametersProperty := properties.SelectElement("hudson.model.ParametersDefinitionProperty"); parametersProperty != nil { + params := parametersProperty.SelectElement("parameterDefinitions").ChildElements() + for _, param := range params { + switch param.Tag { + case "hudson.model.StringParameterDefinition": + pipeline.Parameters = append(pipeline.Parameters, &Parameter{ + Name: param.SelectElement("name").Text(), + Description: param.SelectElement("description").Text(), + DefaultValue: param.SelectElement("defaultValue").Text(), + Type: ParameterTypeMap["hudson.model.StringParameterDefinition"], + }) + case "hudson.model.BooleanParameterDefinition": + pipeline.Parameters = append(pipeline.Parameters, &Parameter{ + Name: param.SelectElement("name").Text(), + Description: param.SelectElement("description").Text(), + DefaultValue: param.SelectElement("defaultValue").Text(), + Type: ParameterTypeMap["hudson.model.BooleanParameterDefinition"], + }) + case "hudson.model.TextParameterDefinition": + pipeline.Parameters = append(pipeline.Parameters, &Parameter{ + Name: param.SelectElement("name").Text(), + Description: param.SelectElement("description").Text(), + DefaultValue: param.SelectElement("defaultValue").Text(), + Type: ParameterTypeMap["hudson.model.TextParameterDefinition"], + }) + case "hudson.model.FileParameterDefinition": + pipeline.Parameters = append(pipeline.Parameters, &Parameter{ + Name: param.SelectElement("name").Text(), + Description: param.SelectElement("description").Text(), + Type: ParameterTypeMap["hudson.model.FileParameterDefinition"], + }) + case "hudson.model.PasswordParameterDefinition": + pipeline.Parameters = append(pipeline.Parameters, &Parameter{ + Name: param.SelectElement("name").Text(), + Description: param.SelectElement("description").Text(), + DefaultValue: param.SelectElement("name").Text(), + Type: ParameterTypeMap["hudson.model.PasswordParameterDefinition"], + }) + case "hudson.model.ChoiceParameterDefinition": + choiceParameter := &Parameter{ + Name: param.SelectElement("name").Text(), + Description: param.SelectElement("description").Text(), + Type: ParameterTypeMap["hudson.model.ChoiceParameterDefinition"], + } + choices := param.SelectElement("choices").SelectElement("a").SelectElements("string") + for _, choice := range choices { + choiceParameter.DefaultValue += fmt.Sprintf("%s\n", choice.Text()) + } + choiceParameter.DefaultValue = strings.TrimSpace(choiceParameter.DefaultValue) + pipeline.Parameters = append(pipeline.Parameters, choiceParameter) + default: + pipeline.Parameters = append(pipeline.Parameters, &Parameter{ + Name: param.SelectElement("name").Text(), + Description: param.SelectElement("description").Text(), + DefaultValue: "unknown", + Type: param.Tag, + }) + } + } + } + + if triggerProperty := properties. + SelectElement( + "org.jenkinsci.plugins.workflow.job.properties.PipelineTriggersJobProperty"); triggerProperty != nil { + triggers := triggerProperty.SelectElement("triggers") + if timerTrigger := triggers.SelectElement("hudson.triggers.TimerTrigger"); timerTrigger != nil { + pipeline.TimerTrigger = &TimerTrigger{ + Cron: timerTrigger.SelectElement("spec").Text(), + } + } + } + if authToken := flow.SelectElement("authToken"); authToken != nil { + pipeline.RemoteTrigger = &RemoteTrigger{ + Token: authToken.Text(), + } + } + if definition := flow.SelectElement("definition"); definition != nil { + if script := definition.SelectElement("script"); script != nil { + pipeline.Jenkinsfile = script.Text() + } + } + return pipeline, nil +} + +func createMultiBranchPipelineConfigXml(projectName string, pipeline *MultiBranchPipeline) (string, error) { + doc := etree.NewDocument() + xmlString := ` + + + + + + + + + + + + + + + false + + + + + +` + err := doc.ReadFromString(xmlString) + if err != nil { + return "", err + } + + project := doc.SelectElement("org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject") + project.CreateElement("description").SetText(pipeline.Description) + + if pipeline.Discarder != nil { + discarder := project.CreateElement("orphanedItemStrategy") + discarder.CreateAttr("class", "com.cloudbees.hudson.plugins.folder.computed.DefaultOrphanedItemStrategy") + discarder.CreateAttr("plugin", "cloudbees-folder") + discarder.CreateElement("pruneDeadBranches").SetText("true") + discarder.CreateElement("daysToKeep").SetText(pipeline.Discarder.DaysToKeep) + discarder.CreateElement("numToKeep").SetText(pipeline.Discarder.NumToKeep) + } + + triggers := project.CreateElement("triggers") + if pipeline.TimerTrigger != nil { + timeTrigger := triggers.CreateElement( + "com.cloudbees.hudson.plugins.folder.computed.PeriodicFolderTrigger") + timeTrigger.CreateAttr("plugin", "cloudbees-folder") + millis, err := strconv.ParseInt(pipeline.TimerTrigger.Interval, 10, 64) + if err != nil { + return "", err + } + timeTrigger.CreateElement("spec").SetText(toCrontab(millis)) + timeTrigger.CreateElement("interval").SetText(pipeline.TimerTrigger.Interval) + + triggers.CreateElement("disabled").SetText("false") + } + + sources := project.CreateElement("sources") + sources.CreateAttr("class", "jenkins.branch.MultiBranchProject$BranchSourceList") + sources.CreateAttr("plugin", "branch-api") + sourcesOwner := sources.CreateElement("owner") + sourcesOwner.CreateAttr("class", "org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject") + sourcesOwner.CreateAttr("reference", "../..") + + branchSource := sources.CreateElement("data").CreateElement("jenkins.branch.BranchSource") + branchSourceStrategy := branchSource.CreateElement("strategy") + branchSourceStrategy.CreateAttr("class", "jenkins.branch.NamedExceptionsBranchPropertyStrategy") + branchSourceStrategy.CreateElement("defaultProperties").CreateAttr("class", "empty-list") + branchSourceStrategy.CreateElement("namedExceptions").CreateAttr("class", "empty-list") + + switch pipeline.SourceType { + case "git": + gitDefine := pipeline.GitSource + + gitSource := branchSource.CreateElement("source") + gitSource.CreateAttr("class", "jenkins.plugins.git.GitSCMSource") + gitSource.CreateAttr("plugin", "git") + gitSource.CreateElement("id").SetText(projectName + pipeline.Name) + gitSource.CreateElement("remote").SetText(gitDefine.Url) + if gitDefine.CredentialId != "" { + gitSource.CreateElement("credentialsId").SetText(gitDefine.CredentialId) + } + traits := gitSource.CreateElement("traits") + if gitDefine.DiscoverBranches { + traits.CreateElement("jenkins.plugins.git.traits.BranchDiscoveryTrait") + } + if gitDefine.CloneOption != nil { + cloneExtension := traits.CreateElement("jenkins.plugins.git.traits.CloneOptionTrait").CreateElement("extension") + cloneExtension.CreateAttr("class", "hudson.plugins.git.extensions.impl.CloneOption") + cloneExtension.CreateElement("shallow").SetText(strconv.FormatBool(gitDefine.CloneOption.Shallow)) + cloneExtension.CreateElement("noTags").SetText(strconv.FormatBool(false)) + cloneExtension.CreateElement("reference") + if gitDefine.CloneOption.Timeout >= 0 { + cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(gitDefine.CloneOption.Timeout)) + } else { + cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(10)) + } + + if gitDefine.CloneOption.Depth >= 0 { + cloneExtension.CreateElement("depth").SetText(strconv.Itoa(gitDefine.CloneOption.Depth)) + } else { + cloneExtension.CreateElement("depth").SetText(strconv.Itoa(1)) + } + } + + if gitDefine.RegexFilter != "" { + regexTraits := traits.CreateElement("jenkins.scm.impl.trait.RegexSCMHeadFilterTrait") + regexTraits.CreateAttr("plugin", "scm-api@2.4.0") + regexTraits.CreateElement("regex").SetText(gitDefine.RegexFilter) + } + + case "github": + githubDefine := pipeline.GitHubSource + + githubSource := branchSource.CreateElement("source") + githubSource.CreateAttr("class", "org.jenkinsci.plugins.github_branch_source.GitHubSCMSource") + githubSource.CreateAttr("plugin", "github-branch-source") + githubSource.CreateElement("id").SetText(projectName + pipeline.Name) + githubSource.CreateElement("credentialsId").SetText(githubDefine.CredentialId) + githubSource.CreateElement("repoOwner").SetText(githubDefine.Owner) + githubSource.CreateElement("repository").SetText(githubDefine.Repo) + if githubDefine.ApiUri != "" { + githubSource.CreateElement("apiUri").SetText(githubDefine.ApiUri) + } + traits := githubSource.CreateElement("traits") + if githubDefine.DiscoverBranches != 0 { + traits.CreateElement("org.jenkinsci.plugins.github__branch__source.BranchDiscoveryTrait"). + CreateElement("strategyId").SetText(strconv.Itoa(githubDefine.DiscoverBranches)) + } + if githubDefine.DiscoverPRFromOrigin != 0 { + traits.CreateElement("org.jenkinsci.plugins.github__branch__source.OriginPullRequestDiscoveryTrait"). + CreateElement("strategyId").SetText(strconv.Itoa(githubDefine.DiscoverPRFromOrigin)) + } + if githubDefine.DiscoverPRFromForks != nil { + forkTrait := traits.CreateElement("org.jenkinsci.plugins.github__branch__source.ForkPullRequestDiscoveryTrait") + forkTrait.CreateElement("strategyId").SetText(strconv.Itoa(githubDefine.DiscoverPRFromForks.Strategy)) + trustClass := "org.jenkinsci.plugins.github_branch_source.ForkPullRequestDiscoveryTrait$" + switch githubDefine.DiscoverPRFromForks.Trust { + case 1: + trustClass += "TrustContributors" + case 2: + trustClass += "TrustEveryone" + case 3: + trustClass += "TrustPermission" + case 4: + trustClass += "TrustNobody" + default: + return "", fmt.Errorf("unsupport trust choice") + } + forkTrait.CreateElement("trust").CreateAttr("class", trustClass) + } + if githubDefine.CloneOption != nil { + cloneExtension := traits.CreateElement("jenkins.plugins.git.traits.CloneOptionTrait").CreateElement("extension") + cloneExtension.CreateAttr("class", "hudson.plugins.git.extensions.impl.CloneOption") + cloneExtension.CreateElement("shallow").SetText(strconv.FormatBool(githubDefine.CloneOption.Shallow)) + cloneExtension.CreateElement("noTags").SetText(strconv.FormatBool(false)) + cloneExtension.CreateElement("reference") + if githubDefine.CloneOption.Timeout >= 0 { + cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(githubDefine.CloneOption.Timeout)) + } else { + cloneExtension.CreateElement("timeout").SetText(strconv.Itoa(10)) + } + + if githubDefine.CloneOption.Depth >= 0 { + cloneExtension.CreateElement("depth").SetText(strconv.Itoa(githubDefine.CloneOption.Depth)) + } else { + cloneExtension.CreateElement("depth").SetText(strconv.Itoa(1)) + } + } + if githubDefine.RegexFilter != "" { + regexTraits := traits.CreateElement("jenkins.scm.impl.trait.RegexSCMHeadFilterTrait") + regexTraits.CreateAttr("plugin", "scm-api@2.4.0") + regexTraits.CreateElement("regex").SetText(githubDefine.RegexFilter) + } + + case "svn": + svnDefine := pipeline.SvnSource + svnSource := branchSource.CreateElement("source") + svnSource.CreateAttr("class", "jenkins.scm.impl.subversion.SubversionSCMSource") + svnSource.CreateAttr("plugin", "subversion") + svnSource.CreateElement("id").SetText(projectName + pipeline.Name) + if svnDefine.CredentialId != "" { + svnSource.CreateElement("credentialsId").SetText(svnDefine.CredentialId) + } + if svnDefine.Remote != "" { + svnSource.CreateElement("remoteBase").SetText(svnDefine.Remote) + } + if svnDefine.Includes != "" { + svnSource.CreateElement("includes").SetText(svnDefine.Includes) + } + if svnDefine.Excludes != "" { + svnSource.CreateElement("excludes").SetText(svnDefine.Excludes) + } + + case "single_svn": + singleSvnDefine := pipeline.SingleSvnSource + if err != nil { + return "", err + } + svnSource := branchSource.CreateElement("source") + svnSource.CreateAttr("class", "jenkins.scm.impl.SingleSCMSource") + svnSource.CreateAttr("plugin", "scm-api") + + svnSource.CreateElement("id").SetText(projectName + pipeline.Name) + svnSource.CreateElement("name").SetText("master") + + scm := svnSource.CreateElement("scm") + scm.CreateAttr("class", "hudson.scm.SubversionSCM") + scm.CreateAttr("plugin", "subversion") + + location := scm.CreateElement("locations").CreateElement("hudson.scm.SubversionSCM_-ModuleLocation") + if singleSvnDefine.Remote != "" { + location.CreateElement("remote").SetText(singleSvnDefine.Remote) + } + if singleSvnDefine.CredentialId != "" { + location.CreateElement("credentialsId").SetText(singleSvnDefine.CredentialId) + } + location.CreateElement("local").SetText(".") + location.CreateElement("depthOption").SetText("infinity") + location.CreateElement("ignoreExternalsOption").SetText("true") + location.CreateElement("cancelProcessOnExternalsFail").SetText("true") + + svnSource.CreateElement("excludedRegions") + svnSource.CreateElement("includedRegions") + svnSource.CreateElement("excludedUsers") + svnSource.CreateElement("excludedRevprop") + svnSource.CreateElement("excludedCommitMessages") + svnSource.CreateElement("workspaceUpdater").CreateAttr("class", "hudson.scm.subversion.UpdateUpdater") + svnSource.CreateElement("ignoreDirPropChanges").SetText("false") + svnSource.CreateElement("filterChangelog").SetText("false") + svnSource.CreateElement("quietOperation").SetText("true") + + default: + return "", fmt.Errorf("unsupport source type") + } + factory := project.CreateElement("factory") + factory.CreateAttr("class", "org.jenkinsci.plugins.workflow.multibranch.WorkflowBranchProjectFactory") + + factoryOwner := factory.CreateElement("owner") + factoryOwner.CreateAttr("class", "org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject") + factoryOwner.CreateAttr("reference", "../..") + factory.CreateElement("scriptPath").SetText(pipeline.ScriptPath) + + doc.Indent(2) + stringXml, err := doc.WriteToString() + return replaceXmlVersion(stringXml, "1.0", "1.1"), err +} + +func parseMultiBranchPipelineConfigXml(config string) (*MultiBranchPipeline, error) { + pipeline := &MultiBranchPipeline{} + config = replaceXmlVersion(config, "1.1", "1.0") + doc := etree.NewDocument() + err := doc.ReadFromString(config) + if err != nil { + return nil, err + } + project := doc.SelectElement("org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject") + if project == nil { + return nil, fmt.Errorf("can not parse mutibranch pipeline config") + } + pipeline.Description = project.SelectElement("description").Text() + + if discarder := project.SelectElement("orphanedItemStrategy"); discarder != nil { + pipeline.Discarder = &DiscarderProperty{ + DaysToKeep: discarder.SelectElement("daysToKeep").Text(), + NumToKeep: discarder.SelectElement("numToKeep").Text(), + } + } + if triggers := project.SelectElement("triggers"); triggers != nil { + if timerTrigger := triggers.SelectElement( + "com.cloudbees.hudson.plugins.folder.computed.PeriodicFolderTrigger"); timerTrigger != nil { + pipeline.TimerTrigger = &TimerTrigger{ + Interval: timerTrigger.SelectElement("interval").Text(), + } + } + } + + if sources := project.SelectElement("sources"); sources != nil { + if sourcesData := sources.SelectElement("data"); sourcesData != nil { + if branchSource := sourcesData.SelectElement("jenkins.branch.BranchSource"); branchSource != nil { + source := branchSource.SelectElement("source") + switch source.SelectAttr("class").Value { + case "org.jenkinsci.plugins.github_branch_source.GitHubSCMSource": + githubSource := &GithubSource{} + if credential := source.SelectElement("credentialsId"); credential != nil { + githubSource.CredentialId = credential.Text() + } + if repoOwner := source.SelectElement("repoOwner"); repoOwner != nil { + githubSource.Owner = repoOwner.Text() + } + if repository := source.SelectElement("repository"); repository != nil { + githubSource.Repo = repository.Text() + } + if apiUri := source.SelectElement("apiUri"); apiUri != nil { + githubSource.ApiUri = apiUri.Text() + } + traits := source.SelectElement("traits") + if branchDiscoverTrait := traits.SelectElement( + "org.jenkinsci.plugins.github__branch__source.BranchDiscoveryTrait"); branchDiscoverTrait != nil { + strategyId, err := strconv.Atoi(branchDiscoverTrait.SelectElement("strategyId").Text()) + if err != nil { + return nil, err + } + githubSource.DiscoverBranches = strategyId + } + if originPRDiscoverTrait := traits.SelectElement( + "org.jenkinsci.plugins.github__branch__source.OriginPullRequestDiscoveryTrait"); originPRDiscoverTrait != nil { + strategyId, err := strconv.Atoi(originPRDiscoverTrait.SelectElement("strategyId").Text()) + if err != nil { + return nil, err + } + githubSource.DiscoverPRFromOrigin = strategyId + } + if forkPRDiscoverTrait := traits.SelectElement( + "org.jenkinsci.plugins.github__branch__source.ForkPullRequestDiscoveryTrait"); forkPRDiscoverTrait != nil { + strategyId, err := strconv.Atoi(forkPRDiscoverTrait.SelectElement("strategyId").Text()) + if err != nil { + return nil, err + } + trustClass := forkPRDiscoverTrait.SelectElement("trust").SelectAttr("class").Value + trust := strings.Split(trustClass, "$") + switch trust[1] { + case "TrustContributors": + githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{ + Strategy: strategyId, + Trust: 1, + } + case "TrustEveryone": + githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{ + Strategy: strategyId, + Trust: 2, + } + case "TrustPermission": + githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{ + Strategy: strategyId, + Trust: 3, + } + case "TrustNobody": + githubSource.DiscoverPRFromForks = &GithubDiscoverPRFromForks{ + Strategy: strategyId, + Trust: 4, + } + } + if cloneTrait := traits.SelectElement( + "jenkins.plugins.git.traits.CloneOptionTrait"); cloneTrait != nil { + if cloneExtension := cloneTrait.SelectElement( + "extension"); cloneExtension != nil { + githubSource.CloneOption = &GitCloneOption{} + if value, err := strconv.ParseBool(cloneExtension.SelectElement("shallow").Text()); err == nil { + githubSource.CloneOption.Shallow = value + } + if value, err := strconv.ParseInt(cloneExtension.SelectElement("timeout").Text(), 10, 32); err == nil { + githubSource.CloneOption.Timeout = int(value) + } + if value, err := strconv.ParseInt(cloneExtension.SelectElement("depth").Text(), 10, 32); err == nil { + githubSource.CloneOption.Depth = int(value) + } + } + } + + if regexTrait := traits.SelectElement( + "jenkins.scm.impl.trait.RegexSCMHeadFilterTrait"); regexTrait != nil { + if regex := regexTrait.SelectElement("regex"); regex != nil { + githubSource.RegexFilter = regex.Text() + } + } + } + + pipeline.GitHubSource = githubSource + pipeline.SourceType = "github" + case "jenkins.plugins.git.GitSCMSource": + gitSource := &GitSource{} + if credential := source.SelectElement("credentialsId"); credential != nil { + gitSource.CredentialId = credential.Text() + } + if remote := source.SelectElement("remote"); remote != nil { + gitSource.Url = remote.Text() + } + + traits := source.SelectElement("traits") + if branchDiscoverTrait := traits.SelectElement( + "jenkins.plugins.git.traits.BranchDiscoveryTrait"); branchDiscoverTrait != nil { + gitSource.DiscoverBranches = true + } + if cloneTrait := traits.SelectElement( + "jenkins.plugins.git.traits.CloneOptionTrait"); cloneTrait != nil { + if cloneExtension := cloneTrait.SelectElement( + "extension"); cloneExtension != nil { + gitSource.CloneOption = &GitCloneOption{} + if value, err := strconv.ParseBool(cloneExtension.SelectElement("shallow").Text()); err == nil { + gitSource.CloneOption.Shallow = value + } + if value, err := strconv.ParseInt(cloneExtension.SelectElement("timeout").Text(), 10, 32); err == nil { + gitSource.CloneOption.Timeout = int(value) + } + if value, err := strconv.ParseInt(cloneExtension.SelectElement("depth").Text(), 10, 32); err == nil { + gitSource.CloneOption.Depth = int(value) + } + } + } + if regexTrait := traits.SelectElement( + "jenkins.scm.impl.trait.RegexSCMHeadFilterTrait"); regexTrait != nil { + if regex := regexTrait.SelectElement("regex"); regex != nil { + gitSource.RegexFilter = regex.Text() + } + } + + pipeline.SourceType = "git" + pipeline.GitSource = gitSource + case "jenkins.scm.impl.SingleSCMSource": + singleSvnSource := &SingleSvnSource{} + + if scm := source.SelectElement("scm"); scm != nil { + if locations := scm.SelectElement("locations"); locations != nil { + if moduleLocations := locations.SelectElement("hudson.scm.SubversionSCM_-ModuleLocation"); moduleLocations != nil { + if remote := moduleLocations.SelectElement("remote"); remote != nil { + singleSvnSource.Remote = remote.Text() + } + if credentialId := moduleLocations.SelectElement("credentialsId"); credentialId != nil { + singleSvnSource.CredentialId = credentialId.Text() + } + } + } + } + pipeline.SourceType = "single_svn" + + pipeline.SingleSvnSource = singleSvnSource + + case "jenkins.scm.impl.subversion.SubversionSCMSource": + svnSource := &SvnSource{} + + if remote := source.SelectElement("remoteBase"); remote != nil { + svnSource.Remote = remote.Text() + } + + if credentialsId := source.SelectElement("credentialsId"); credentialsId != nil { + svnSource.CredentialId = credentialsId.Text() + } + + if includes := source.SelectElement("includes"); includes != nil { + svnSource.Includes = includes.Text() + } + + if excludes := source.SelectElement("excludes"); excludes != nil { + svnSource.Excludes = excludes.Text() + } + + pipeline.SourceType = "svn" + + pipeline.SvnSource = svnSource + } + } + } + } + + pipeline.ScriptPath = project.SelectElement("factory").SelectElement("scriptPath").Text() + return pipeline, nil +} + +func toCrontab(millis int64) string { + if millis*time.Millisecond.Nanoseconds() <= 5*time.Minute.Nanoseconds() { + return "* * * * *" + } + if millis*time.Millisecond.Nanoseconds() <= 30*time.Minute.Nanoseconds() { + return "H/5 * * * *" + } + if millis*time.Millisecond.Nanoseconds() <= 1*time.Hour.Nanoseconds() { + return "H/15 * * * *" + } + if millis*time.Millisecond.Nanoseconds() <= 8*time.Hour.Nanoseconds() { + return "H/30 * * * *" + } + if millis*time.Millisecond.Nanoseconds() <= 24*time.Hour.Nanoseconds() { + return "H H/4 * * *" + } + if millis*time.Millisecond.Nanoseconds() <= 48*time.Hour.Nanoseconds() { + return "H H/12 * * *" + } + return "H H * * *" + +} + +func getBuildSonarResults(build *gojenkins.Build) ([]*SonarStatus, error) { + sonarClient := sonarqube.Client() + actions := build.GetActions() + sonarStatuses := make([]*SonarStatus, 0) + for _, action := range actions { + if action.ClassName == SonarAnalysisActionClass { + sonarStatus := &SonarStatus{} + taskOptions := &sonargo.CeTaskOption{ + Id: action.SonarTaskId, + } + ceTask, _, err := sonarClient.Ce.Task(taskOptions) + if err != nil { + glog.Errorf("get sonar task error [%+v]", err) + continue + } + sonarStatus.Task = ceTask + measuresComponentOption := &sonargo.MeasuresComponentOption{ + Component: ceTask.Task.ComponentKey, + AdditionalFields: SonarAdditionalFields, + MetricKeys: SonarMetricKeys, + } + measures, _, err := sonarClient.Measures.Component(measuresComponentOption) + if err != nil { + glog.Errorf("get sonar task error [%+v]", err) + continue + } + sonarStatus.Measures = measures + + issuesSearchOption := &sonargo.IssuesSearchOption{ + AdditionalFields: "_all", + ComponentKeys: ceTask.Task.ComponentKey, + Resolved: "false", + Ps: "10", + S: "FILE_LINE", + Facets: "severities,types", + } + issuesSearch, _, err := sonarClient.Issues.Search(issuesSearchOption) + sonarStatus.Issues = issuesSearch + jenkinsAction := action + sonarStatus.JenkinsAction = &jenkinsAction + + sonarStatuses = append(sonarStatuses, sonarStatus) + } + } + return sonarStatuses, nil +} diff --git a/pkg/models/devops/project_pipeline_handler.go b/pkg/models/devops/project_pipeline_handler.go new file mode 100644 index 000000000..cf767f56b --- /dev/null +++ b/pkg/models/devops/project_pipeline_handler.go @@ -0,0 +1,258 @@ +package devops + +import ( + "fmt" + "github.com/emicklei/go-restful" + "github.com/golang/glog" + "kubesphere.io/devops/pkg/utils/stringutils" + "kubesphere.io/kubesphere/pkg/gojenkins/utils" + "kubesphere.io/kubesphere/pkg/simple/client/admin_jenkins" + "net/http" +) + +func CreateProjectPipeline(projectId string, pipeline *ProjectPipeline) (string, error) { + jenkinsClient := admin_jenkins.Client() + switch pipeline.Type { + case NoScmPipelineType: + + config, err := createPipelineConfigXml(pipeline.Pipeline) + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(http.StatusInternalServerError, err.Error()) + } + + job, err := jenkinsClient.GetJob(pipeline.Pipeline.Name, projectId) + if job != nil { + err := fmt.Errorf("job name [%s] has been used", job.GetName()) + glog.Warning(err.Error()) + return "", restful.NewError(http.StatusConflict, err.Error()) + } + + if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + _, err = jenkinsClient.CreateJobInFolder(config, pipeline.Pipeline.Name, projectId) + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + return pipeline.Pipeline.Name, nil + case MultiBranchPipelineType: + config, err := createMultiBranchPipelineConfigXml(projectId, pipeline.MultiBranchPipeline) + if err != nil { + glog.Errorf("%+v", err) + + return "", restful.NewError(http.StatusInternalServerError, err.Error()) + } + + job, err := jenkinsClient.GetJob(pipeline.MultiBranchPipeline.Name, projectId) + if job != nil { + err := fmt.Errorf("job name [%s] has been used", job.GetName()) + glog.Warning(err.Error()) + return "", restful.NewError(http.StatusConflict, err.Error()) + } + + if err != nil && utils.GetJenkinsStatusCode(err) != http.StatusNotFound { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + _, err = jenkinsClient.CreateJobInFolder(config, pipeline.MultiBranchPipeline.Name, projectId) + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + return pipeline.MultiBranchPipeline.Name, nil + + default: + err := fmt.Errorf("error unsupport job type") + glog.Errorf("%+v", err) + return "", restful.NewError(http.StatusBadRequest, err.Error()) + } +} + +func DeleteProjectPipeline(projectId string, pipelineId string) (string, error) { + jenkinsClient := admin_jenkins.Client() + _, err := jenkinsClient.DeleteJob(pipelineId, projectId) + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + return pipelineId, nil +} + +func UpdateProjectPipeline(projectId, pipelineId string, pipeline *ProjectPipeline) (string, error) { + jenkinsClient := admin_jenkins.Client() + switch pipeline.Type { + case NoScmPipelineType: + + config, err := createPipelineConfigXml(pipeline.Pipeline) + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(http.StatusInternalServerError, err.Error()) + } + + job, err := jenkinsClient.GetJob(pipelineId, projectId) + + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + err = job.UpdateConfig(config) + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + return pipeline.Pipeline.Name, nil + case MultiBranchPipelineType: + + config, err := createMultiBranchPipelineConfigXml(projectId, pipeline.MultiBranchPipeline) + if err != nil { + glog.Errorf("%+v", err) + + return "", restful.NewError(http.StatusInternalServerError, err.Error()) + } + + job, err := jenkinsClient.GetJob(pipelineId, projectId) + + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + err = job.UpdateConfig(config) + if err != nil { + glog.Errorf("%+v", err) + return "", restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + return pipeline.MultiBranchPipeline.Name, nil + + default: + err := fmt.Errorf("error unsupport job type") + glog.Errorf("%+v", err) + return "", restful.NewError(http.StatusBadRequest, err.Error()) + } +} + +func GetProjectPipeline(projectId, pipelineId string) (*ProjectPipeline, error) { + jenkinsClient := admin_jenkins.Client() + + job, err := jenkinsClient.GetJob(pipelineId, projectId) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + switch job.Raw.Class { + case "org.jenkinsci.plugins.workflow.job.WorkflowJob": + config, err := job.GetConfig() + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + pipeline, err := parsePipelineConfigXml(config) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + pipeline.Name = pipelineId + return &ProjectPipeline{ + Type: NoScmPipelineType, + Pipeline: pipeline, + }, nil + + case "org.jenkinsci.plugins.workflow.multibranch.WorkflowMultiBranchProject": + config, err := job.GetConfig() + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + pipeline, err := parseMultiBranchPipelineConfigXml(config) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + pipeline.Name = pipelineId + return &ProjectPipeline{ + Type: MultiBranchPipelineType, + MultiBranchPipeline: pipeline, + }, nil + default: + err := fmt.Errorf("error unsupport job type") + glog.Errorf("%+v", err) + return nil, restful.NewError(http.StatusBadRequest, err.Error()) + + } +} + +func GetPipelineSonar(projectId, pipelineId string) ([]*SonarStatus, error) { + jenkinsClient := admin_jenkins.Client() + job, err := jenkinsClient.GetJob(pipelineId, projectId) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + build, err := job.GetLastBuild() + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + sonarStatus, err := getBuildSonarResults(build) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(http.StatusBadRequest, err.Error()) + } + if len(sonarStatus) == 0 { + build, err := job.GetLastCompletedBuild() + if err != nil && stringutils.GetJenkinsStatusCode(err) != http.StatusNotFound { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + sonarStatus, err = getBuildSonarResults(build) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(http.StatusBadRequest, err.Error()) + } + } + return sonarStatus, nil +} + +func GetMultiBranchPipelineSonar(projectId, pipelineId, branchId string) ([]*SonarStatus, error) { + jenkinsClient := admin_jenkins.Client() + job, err := jenkinsClient.GetJob(branchId, projectId, pipelineId) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + build, err := job.GetLastBuild() + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + + sonarStatus, err := getBuildSonarResults(build) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(http.StatusBadRequest, err.Error()) + } + if len(sonarStatus) == 0 { + build, err := job.GetLastCompletedBuild() + if err != nil && stringutils.GetJenkinsStatusCode(err) != http.StatusNotFound { + glog.Errorf("%+v", err) + return nil, restful.NewError(utils.GetJenkinsStatusCode(err), err.Error()) + } + sonarStatus, err = getBuildSonarResults(build) + if err != nil { + glog.Errorf("%+v", err) + return nil, restful.NewError(http.StatusBadRequest, err.Error()) + } + } + return sonarStatus, nil +} diff --git a/pkg/models/devops/project_pipeline_test.go b/pkg/models/devops/project_pipeline_test.go new file mode 100644 index 000000000..516d2d7a1 --- /dev/null +++ b/pkg/models/devops/project_pipeline_test.go @@ -0,0 +1,510 @@ +package devops + +import ( + "reflect" + "testing" +) + +func Test_NoScmPipelineConfig(t *testing.T) { + inputs := []*NoScmPipeline{ + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + }, + &NoScmPipeline{ + Name: "", + Description: "", + Jenkinsfile: "node{echo 'hello'}", + }, + &NoScmPipeline{ + Name: "", + Description: "", + Jenkinsfile: "node{echo 'hello'}", + DisableConcurrent: true, + }, + } + for _, input := range inputs { + outputString, err := createPipelineConfigXml(input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parsePipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_NoScmPipelineConfig_Discarder(t *testing.T) { + inputs := []*NoScmPipeline{ + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + Discarder: &DiscarderProperty{ + "3", "5", + }, + }, + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + Discarder: &DiscarderProperty{ + "3", "", + }, + }, + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + Discarder: &DiscarderProperty{ + "", "21321", + }, + }, + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + Discarder: &DiscarderProperty{ + "", "", + }, + }, + } + for _, input := range inputs { + outputString, err := createPipelineConfigXml(input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parsePipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_NoScmPipelineConfig_Param(t *testing.T) { + inputs := []*NoScmPipeline{ + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + Parameters: []*Parameter{ + &Parameter{ + Name: "d", + DefaultValue: "a\nb", + Type: "choice", + Description: "fortest", + }, + }, + }, + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + Parameters: []*Parameter{ + &Parameter{ + Name: "a", + DefaultValue: "abc", + Type: "string", + Description: "fortest", + }, + &Parameter{ + Name: "b", + DefaultValue: "false", + Type: "boolean", + Description: "fortest", + }, + &Parameter{ + Name: "c", + DefaultValue: "password \n aaa", + Type: "text", + Description: "fortest", + }, + &Parameter{ + Name: "d", + DefaultValue: "a\nb", + Type: "choice", + Description: "fortest", + }, + }, + }, + } + for _, input := range inputs { + outputString, err := createPipelineConfigXml(input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parsePipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_NoScmPipelineConfig_Trigger(t *testing.T) { + inputs := []*NoScmPipeline{ + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + TimerTrigger: &TimerTrigger{ + Cron: "1 1 1 * * *", + }, + }, + + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + RemoteTrigger: &RemoteTrigger{ + Token: "abc", + }, + }, + &NoScmPipeline{ + Name: "", + Description: "for test", + Jenkinsfile: "node{echo 'hello'}", + TimerTrigger: &TimerTrigger{ + Cron: "1 1 1 * * *", + }, + RemoteTrigger: &RemoteTrigger{ + Token: "abc", + }, + }, + } + + for _, input := range inputs { + outputString, err := createPipelineConfigXml(input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parsePipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_MultiBranchPipelineConfig(t *testing.T) { + + inputs := []*MultiBranchPipeline{ + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "git", + GitSource: &GitSource{}, + }, + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "github", + GitHubSource: &GithubSource{}, + }, + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "single_svn", + SingleSvnSource: &SingleSvnSource{}, + }, + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "svn", + SvnSource: &SvnSource{}, + }, + } + for _, input := range inputs { + outputString, err := createMultiBranchPipelineConfigXml("", input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parseMultiBranchPipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_MultiBranchPipelineConfig_Discarder(t *testing.T) { + + inputs := []*MultiBranchPipeline{ + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "git", + Discarder: &DiscarderProperty{ + DaysToKeep: "1", + NumToKeep: "2", + }, + GitSource: &GitSource{}, + }, + } + for _, input := range inputs { + outputString, err := createMultiBranchPipelineConfigXml("", input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parseMultiBranchPipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_MultiBranchPipelineConfig_TimerTrigger(t *testing.T) { + inputs := []*MultiBranchPipeline{ + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "git", + TimerTrigger: &TimerTrigger{ + Interval: "12345566", + }, + GitSource: &GitSource{}, + }, + } + for _, input := range inputs { + outputString, err := createMultiBranchPipelineConfigXml("", input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parseMultiBranchPipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_MultiBranchPipelineConfig_Source(t *testing.T) { + + inputs := []*MultiBranchPipeline{ + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "git", + TimerTrigger: &TimerTrigger{ + Interval: "12345566", + }, + GitSource: &GitSource{ + Url: "https://github.com/kubesphere/devops", + CredentialId: "git", + DiscoverBranches: true, + }, + }, + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "github", + TimerTrigger: &TimerTrigger{ + Interval: "12345566", + }, + GitHubSource: &GithubSource{ + Owner: "kubesphere", + Repo: "devops", + CredentialId: "github", + ApiUri: "https://api.github.com", + DiscoverBranches: 1, + DiscoverPRFromOrigin: 2, + DiscoverPRFromForks: &GithubDiscoverPRFromForks{ + Strategy: 1, + Trust: 1, + }, + }, + }, + + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "svn", + TimerTrigger: &TimerTrigger{ + Interval: "12345566", + }, + SvnSource: &SvnSource{ + Remote: "https://api.svn.com/bcd", + CredentialId: "svn", + Excludes: "truck", + Includes: "tag/*", + }, + }, + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "single_svn", + TimerTrigger: &TimerTrigger{ + Interval: "12345566", + }, + SingleSvnSource: &SingleSvnSource{ + Remote: "https://api.svn.com/bcd", + CredentialId: "svn", + }, + }, + } + + for _, input := range inputs { + outputString, err := createMultiBranchPipelineConfigXml("", input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parseMultiBranchPipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } +} + +func Test_MultiBranchPipelineCloneConfig(t *testing.T) { + + inputs := []*MultiBranchPipeline{ + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "git", + GitSource: &GitSource{ + Url: "https://github.com/kubesphere/devops", + CredentialId: "git", + DiscoverBranches: true, + CloneOption: &GitCloneOption{ + Shallow: false, + Depth: 3, + Timeout: 20, + }, + }, + }, + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "github", + GitHubSource: &GithubSource{ + Owner: "kubesphere", + Repo: "devops", + CredentialId: "github", + ApiUri: "https://api.github.com", + DiscoverBranches: 1, + DiscoverPRFromOrigin: 2, + DiscoverPRFromForks: &GithubDiscoverPRFromForks{ + Strategy: 1, + Trust: 1, + }, + CloneOption: &GitCloneOption{ + Shallow: false, + Depth: 3, + Timeout: 20, + }, + }, + }, + } + + for _, input := range inputs { + outputString, err := createMultiBranchPipelineConfigXml("", input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parseMultiBranchPipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } + +} + +func Test_MultiBranchPipelineRegexFilter(t *testing.T) { + + inputs := []*MultiBranchPipeline{ + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "git", + GitSource: &GitSource{ + Url: "https://github.com/kubesphere/devops", + CredentialId: "git", + DiscoverBranches: true, + RegexFilter: ".*", + }, + }, + &MultiBranchPipeline{ + Name: "", + Description: "for test", + ScriptPath: "Jenkinsfile", + SourceType: "github", + GitHubSource: &GithubSource{ + Owner: "kubesphere", + Repo: "devops", + CredentialId: "github", + ApiUri: "https://api.github.com", + DiscoverBranches: 1, + DiscoverPRFromOrigin: 2, + DiscoverPRFromForks: &GithubDiscoverPRFromForks{ + Strategy: 1, + Trust: 1, + }, + RegexFilter: ".*", + }, + }, + } + + for _, input := range inputs { + outputString, err := createMultiBranchPipelineConfigXml("", input) + if err != nil { + t.Fatalf("should not get error %+v", err) + } + output, err := parseMultiBranchPipelineConfigXml(outputString) + + if err != nil { + t.Fatalf("should not get error %+v", err) + } + if !reflect.DeepEqual(input, output) { + t.Fatalf("input [%+v] output [%+v] should equal ", input, output) + } + } + +} diff --git a/vendor/github.com/beevik/etree/CONTRIBUTORS b/vendor/github.com/beevik/etree/CONTRIBUTORS new file mode 100644 index 000000000..03211a85e --- /dev/null +++ b/vendor/github.com/beevik/etree/CONTRIBUTORS @@ -0,0 +1,10 @@ +Brett Vickers (beevik) +Felix Geisendörfer (felixge) +Kamil Kisiel (kisielk) +Graham King (grahamking) +Matt Smith (ma314smith) +Michal Jemala (michaljemala) +Nicolas Piganeau (npiganeau) +Chris Brown (ccbrown) +Earncef Sequeira (earncef) +Gabriel de Labachelerie (wuzuf) diff --git a/vendor/github.com/beevik/etree/LICENSE b/vendor/github.com/beevik/etree/LICENSE new file mode 100644 index 000000000..26f1f7751 --- /dev/null +++ b/vendor/github.com/beevik/etree/LICENSE @@ -0,0 +1,24 @@ +Copyright 2015-2019 Brett Vickers. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY COPYRIGHT HOLDER ``AS IS'' AND ANY +EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT HOLDER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/beevik/etree/etree.go b/vendor/github.com/beevik/etree/etree.go new file mode 100644 index 000000000..9e24f9012 --- /dev/null +++ b/vendor/github.com/beevik/etree/etree.go @@ -0,0 +1,1453 @@ +// Copyright 2015-2019 Brett Vickers. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package etree provides XML services through an Element Tree +// abstraction. +package etree + +import ( + "bufio" + "bytes" + "encoding/xml" + "errors" + "io" + "os" + "sort" + "strings" +) + +const ( + // NoIndent is used with Indent to disable all indenting. + NoIndent = -1 +) + +// ErrXML is returned when XML parsing fails due to incorrect formatting. +var ErrXML = errors.New("etree: invalid XML format") + +// ReadSettings allow for changing the default behavior of the ReadFrom* +// methods. +type ReadSettings struct { + // CharsetReader to be passed to standard xml.Decoder. Default: nil. + CharsetReader func(charset string, input io.Reader) (io.Reader, error) + + // Permissive allows input containing common mistakes such as missing tags + // or attribute values. Default: false. + Permissive bool + + // Entity to be passed to standard xml.Decoder. Default: nil. + Entity map[string]string +} + +// newReadSettings creates a default ReadSettings record. +func newReadSettings() ReadSettings { + return ReadSettings{ + CharsetReader: func(label string, input io.Reader) (io.Reader, error) { + return input, nil + }, + Permissive: false, + } +} + +// WriteSettings allow for changing the serialization behavior of the WriteTo* +// methods. +type WriteSettings struct { + // CanonicalEndTags forces the production of XML end tags, even for + // elements that have no child elements. Default: false. + CanonicalEndTags bool + + // CanonicalText forces the production of XML character references for + // text data characters &, <, and >. If false, XML character references + // are also produced for " and '. Default: false. + CanonicalText bool + + // CanonicalAttrVal forces the production of XML character references for + // attribute value characters &, < and ". If false, XML character + // references are also produced for > and '. Default: false. + CanonicalAttrVal bool + + // When outputting indented XML, use a carriage return and linefeed + // ("\r\n") as a new-line delimiter instead of just a linefeed ("\n"). + // This is useful on Windows-based systems. + UseCRLF bool +} + +// newWriteSettings creates a default WriteSettings record. +func newWriteSettings() WriteSettings { + return WriteSettings{ + CanonicalEndTags: false, + CanonicalText: false, + CanonicalAttrVal: false, + UseCRLF: false, + } +} + +// A Token is an empty interface that represents an Element, CharData, +// Comment, Directive, or ProcInst. +type Token interface { + Parent() *Element + Index() int + dup(parent *Element) Token + setParent(parent *Element) + setIndex(index int) + writeTo(w *bufio.Writer, s *WriteSettings) +} + +// A Document is a container holding a complete XML hierarchy. Its embedded +// element contains zero or more children, one of which is usually the root +// element. The embedded element may include other children such as +// processing instructions or BOM CharData tokens. +type Document struct { + Element + ReadSettings ReadSettings + WriteSettings WriteSettings +} + +// An Element represents an XML element, its attributes, and its child tokens. +type Element struct { + Space, Tag string // namespace prefix and tag + Attr []Attr // key-value attribute pairs + Child []Token // child tokens (elements, comments, etc.) + parent *Element // parent element + index int // token index in parent's children +} + +// An Attr represents a key-value attribute of an XML element. +type Attr struct { + Space, Key string // The attribute's namespace prefix and key + Value string // The attribute value string + element *Element // element containing the attribute +} + +// charDataFlags are used with CharData tokens to store additional settings. +type charDataFlags uint8 + +const ( + // The CharData was created by an indent function as whitespace. + whitespaceFlag charDataFlags = 1 << iota + + // The CharData contains a CDATA section. + cdataFlag +) + +// CharData can be used to represent character data or a CDATA section within +// an XML document. +type CharData struct { + Data string + parent *Element + index int + flags charDataFlags +} + +// A Comment represents an XML comment. +type Comment struct { + Data string + parent *Element + index int +} + +// A Directive represents an XML directive. +type Directive struct { + Data string + parent *Element + index int +} + +// A ProcInst represents an XML processing instruction. +type ProcInst struct { + Target string + Inst string + parent *Element + index int +} + +// NewDocument creates an XML document without a root element. +func NewDocument() *Document { + return &Document{ + Element{Child: make([]Token, 0)}, + newReadSettings(), + newWriteSettings(), + } +} + +// Copy returns a recursive, deep copy of the document. +func (d *Document) Copy() *Document { + return &Document{*(d.dup(nil).(*Element)), d.ReadSettings, d.WriteSettings} +} + +// Root returns the root element of the document, or nil if there is no root +// element. +func (d *Document) Root() *Element { + for _, t := range d.Child { + if c, ok := t.(*Element); ok { + return c + } + } + return nil +} + +// SetRoot replaces the document's root element with e. If the document +// already has a root when this function is called, then the document's +// original root is unbound first. If the element e is bound to another +// document (or to another element within a document), then it is unbound +// first. +func (d *Document) SetRoot(e *Element) { + if e.parent != nil { + e.parent.RemoveChild(e) + } + + p := &d.Element + e.setParent(p) + + // If there is already a root element, replace it. + for i, t := range p.Child { + if _, ok := t.(*Element); ok { + t.setParent(nil) + t.setIndex(-1) + p.Child[i] = e + e.setIndex(i) + return + } + } + + // No existing root element, so add it. + p.addChild(e) +} + +// ReadFrom reads XML from the reader r into the document d. It returns the +// number of bytes read and any error encountered. +func (d *Document) ReadFrom(r io.Reader) (n int64, err error) { + return d.Element.readFrom(r, d.ReadSettings) +} + +// ReadFromFile reads XML from the string s into the document d. +func (d *Document) ReadFromFile(filename string) error { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + _, err = d.ReadFrom(f) + return err +} + +// ReadFromBytes reads XML from the byte slice b into the document d. +func (d *Document) ReadFromBytes(b []byte) error { + _, err := d.ReadFrom(bytes.NewReader(b)) + return err +} + +// ReadFromString reads XML from the string s into the document d. +func (d *Document) ReadFromString(s string) error { + _, err := d.ReadFrom(strings.NewReader(s)) + return err +} + +// WriteTo serializes an XML document into the writer w. It +// returns the number of bytes written and any error encountered. +func (d *Document) WriteTo(w io.Writer) (n int64, err error) { + cw := newCountWriter(w) + b := bufio.NewWriter(cw) + for _, c := range d.Child { + c.writeTo(b, &d.WriteSettings) + } + err, n = b.Flush(), cw.bytes + return +} + +// WriteToFile serializes an XML document into the file named +// filename. +func (d *Document) WriteToFile(filename string) error { + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + _, err = d.WriteTo(f) + return err +} + +// WriteToBytes serializes the XML document into a slice of +// bytes. +func (d *Document) WriteToBytes() (b []byte, err error) { + var buf bytes.Buffer + if _, err = d.WriteTo(&buf); err != nil { + return + } + return buf.Bytes(), nil +} + +// WriteToString serializes the XML document into a string. +func (d *Document) WriteToString() (s string, err error) { + var b []byte + if b, err = d.WriteToBytes(); err != nil { + return + } + return string(b), nil +} + +type indentFunc func(depth int) string + +// Indent modifies the document's element tree by inserting character data +// tokens containing newlines and indentation. The amount of indentation per +// depth level is given as spaces. Pass etree.NoIndent for spaces if you want +// no indentation at all. +func (d *Document) Indent(spaces int) { + var indent indentFunc + switch { + case spaces < 0: + indent = func(depth int) string { return "" } + case d.WriteSettings.UseCRLF == true: + indent = func(depth int) string { return indentCRLF(depth*spaces, indentSpaces) } + default: + indent = func(depth int) string { return indentLF(depth*spaces, indentSpaces) } + } + d.Element.indent(0, indent) +} + +// IndentTabs modifies the document's element tree by inserting CharData +// tokens containing newlines and tabs for indentation. One tab is used per +// indentation level. +func (d *Document) IndentTabs() { + var indent indentFunc + switch d.WriteSettings.UseCRLF { + case true: + indent = func(depth int) string { return indentCRLF(depth, indentTabs) } + default: + indent = func(depth int) string { return indentLF(depth, indentTabs) } + } + d.Element.indent(0, indent) +} + +// NewElement creates an unparented element with the specified tag. The tag +// may be prefixed by a namespace prefix and a colon. +func NewElement(tag string) *Element { + space, stag := spaceDecompose(tag) + return newElement(space, stag, nil) +} + +// newElement is a helper function that creates an element and binds it to +// a parent element if possible. +func newElement(space, tag string, parent *Element) *Element { + e := &Element{ + Space: space, + Tag: tag, + Attr: make([]Attr, 0), + Child: make([]Token, 0), + parent: parent, + index: -1, + } + if parent != nil { + parent.addChild(e) + } + return e +} + +// Copy creates a recursive, deep copy of the element and all its attributes +// and children. The returned element has no parent but can be parented to a +// another element using AddElement, or to a document using SetRoot. +func (e *Element) Copy() *Element { + return e.dup(nil).(*Element) +} + +// FullTag returns the element e's complete tag, including namespace prefix if +// present. +func (e *Element) FullTag() string { + if e.Space == "" { + return e.Tag + } + return e.Space + ":" + e.Tag +} + +// NamespaceURI returns the XML namespace URI associated with the element. If +// the element is part of the XML default namespace, NamespaceURI returns the +// empty string. +func (e *Element) NamespaceURI() string { + if e.Space == "" { + return e.findDefaultNamespaceURI() + } + return e.findLocalNamespaceURI(e.Space) +} + +// findLocalNamespaceURI finds the namespace URI corresponding to the +// requested prefix. +func (e *Element) findLocalNamespaceURI(prefix string) string { + for _, a := range e.Attr { + if a.Space == "xmlns" && a.Key == prefix { + return a.Value + } + } + + if e.parent == nil { + return "" + } + + return e.parent.findLocalNamespaceURI(prefix) +} + +// findDefaultNamespaceURI finds the default namespace URI of the element. +func (e *Element) findDefaultNamespaceURI() string { + for _, a := range e.Attr { + if a.Space == "" && a.Key == "xmlns" { + return a.Value + } + } + + if e.parent == nil { + return "" + } + + return e.parent.findDefaultNamespaceURI() +} + +// hasText returns true if the element has character data immediately +// folllowing the element's opening tag. +func (e *Element) hasText() bool { + if len(e.Child) == 0 { + return false + } + _, ok := e.Child[0].(*CharData) + return ok +} + +// namespacePrefix returns the namespace prefix associated with the element. +func (e *Element) namespacePrefix() string { + return e.Space +} + +// name returns the tag associated with the element. +func (e *Element) name() string { + return e.Tag +} + +// Text returns all character data immediately following the element's opening +// tag. +func (e *Element) Text() string { + if len(e.Child) == 0 { + return "" + } + + text := "" + for _, ch := range e.Child { + if cd, ok := ch.(*CharData); ok { + if text == "" { + text = cd.Data + } else { + text = text + cd.Data + } + } else { + break + } + } + return text +} + +// SetText replaces all character data immediately following an element's +// opening tag with the requested string. +func (e *Element) SetText(text string) { + e.replaceText(0, text, 0) +} + +// SetCData replaces all character data immediately following an element's +// opening tag with a CDATA section. +func (e *Element) SetCData(text string) { + e.replaceText(0, text, cdataFlag) +} + +// Tail returns all character data immediately following the element's end +// tag. +func (e *Element) Tail() string { + if e.Parent() == nil { + return "" + } + + p := e.Parent() + i := e.Index() + + text := "" + for _, ch := range p.Child[i+1:] { + if cd, ok := ch.(*CharData); ok { + if text == "" { + text = cd.Data + } else { + text = text + cd.Data + } + } else { + break + } + } + return text +} + +// SetTail replaces all character data immediately following the element's end +// tag with the requested string. +func (e *Element) SetTail(text string) { + if e.Parent() == nil { + return + } + + p := e.Parent() + p.replaceText(e.Index()+1, text, 0) +} + +// replaceText is a helper function that replaces a series of chardata tokens +// starting at index i with the requested text. +func (e *Element) replaceText(i int, text string, flags charDataFlags) { + end := e.findTermCharDataIndex(i) + + switch { + case end == i: + if text != "" { + // insert a new chardata token at index i + cd := newCharData(text, flags, nil) + e.InsertChildAt(i, cd) + } + + case end == i+1: + if text == "" { + // remove the chardata token at index i + e.RemoveChildAt(i) + } else { + // replace the first and only character token at index i + cd := e.Child[i].(*CharData) + cd.Data, cd.flags = text, flags + } + + default: + if text == "" { + // remove all chardata tokens starting from index i + copy(e.Child[i:], e.Child[end:]) + removed := end - i + e.Child = e.Child[:len(e.Child)-removed] + for j := i; j < len(e.Child); j++ { + e.Child[j].setIndex(j) + } + } else { + // replace the first chardata token at index i and remove all + // subsequent chardata tokens + cd := e.Child[i].(*CharData) + cd.Data, cd.flags = text, flags + copy(e.Child[i+1:], e.Child[end:]) + removed := end - (i + 1) + e.Child = e.Child[:len(e.Child)-removed] + for j := i + 1; j < len(e.Child); j++ { + e.Child[j].setIndex(j) + } + } + } +} + +// findTermCharDataIndex finds the index of the first child token that isn't +// a CharData token. It starts from the requested start index. +func (e *Element) findTermCharDataIndex(start int) int { + for i := start; i < len(e.Child); i++ { + if _, ok := e.Child[i].(*CharData); !ok { + return i + } + } + return len(e.Child) +} + +// CreateElement creates an element with the specified tag and adds it as the +// last child element of the element e. The tag may be prefixed by a namespace +// prefix and a colon. +func (e *Element) CreateElement(tag string) *Element { + space, stag := spaceDecompose(tag) + return newElement(space, stag, e) +} + +// AddChild adds the token t as the last child of element e. If token t was +// already the child of another element, it is first removed from its current +// parent element. +func (e *Element) AddChild(t Token) { + if t.Parent() != nil { + t.Parent().RemoveChild(t) + } + + t.setParent(e) + e.addChild(t) +} + +// InsertChild inserts the token t before e's existing child token ex. If ex +// is nil or ex is not a child of e, then t is added to the end of e's child +// token list. If token t was already the child of another element, it is +// first removed from its current parent element. +// +// Deprecated: InsertChild is deprecated. Use InsertChildAt instead. +func (e *Element) InsertChild(ex Token, t Token) { + if ex == nil || ex.Parent() != e { + e.AddChild(t) + return + } + + if t.Parent() != nil { + t.Parent().RemoveChild(t) + } + + t.setParent(e) + + i := ex.Index() + e.Child = append(e.Child, nil) + copy(e.Child[i+1:], e.Child[i:]) + e.Child[i] = t + + for j := i; j < len(e.Child); j++ { + e.Child[j].setIndex(j) + } +} + +// InsertChildAt inserts the token t into the element e's list of child tokens +// just before the requested index. If the index is greater than or equal to +// the length of the list of child tokens, the token t is added to the end of +// the list. +func (e *Element) InsertChildAt(index int, t Token) { + if index >= len(e.Child) { + e.AddChild(t) + return + } + + if t.Parent() != nil { + if t.Parent() == e && t.Index() > index { + index-- + } + t.Parent().RemoveChild(t) + } + + t.setParent(e) + + e.Child = append(e.Child, nil) + copy(e.Child[index+1:], e.Child[index:]) + e.Child[index] = t + + for j := index; j < len(e.Child); j++ { + e.Child[j].setIndex(j) + } +} + +// RemoveChild attempts to remove the token t from element e's list of +// children. If the token t is a child of e, then it is returned. Otherwise, +// nil is returned. +func (e *Element) RemoveChild(t Token) Token { + if t.Parent() != e { + return nil + } + return e.RemoveChildAt(t.Index()) +} + +// RemoveChildAt removes the index-th child token from the element e. The +// removed child token is returned. If the index is out of bounds, no child is +// removed and nil is returned. +func (e *Element) RemoveChildAt(index int) Token { + if index >= len(e.Child) { + return nil + } + + t := e.Child[index] + for j := index + 1; j < len(e.Child); j++ { + e.Child[j].setIndex(j - 1) + } + e.Child = append(e.Child[:index], e.Child[index+1:]...) + t.setIndex(-1) + t.setParent(nil) + return t +} + +// ReadFrom reads XML from the reader r and stores the result as a new child +// of element e. +func (e *Element) readFrom(ri io.Reader, settings ReadSettings) (n int64, err error) { + r := newCountReader(ri) + dec := xml.NewDecoder(r) + dec.CharsetReader = settings.CharsetReader + dec.Strict = !settings.Permissive + dec.Entity = settings.Entity + var stack stack + stack.push(e) + for { + t, err := dec.RawToken() + switch { + case err == io.EOF: + return r.bytes, nil + case err != nil: + return r.bytes, err + case stack.empty(): + return r.bytes, ErrXML + } + + top := stack.peek().(*Element) + + switch t := t.(type) { + case xml.StartElement: + e := newElement(t.Name.Space, t.Name.Local, top) + for _, a := range t.Attr { + e.createAttr(a.Name.Space, a.Name.Local, a.Value, e) + } + stack.push(e) + case xml.EndElement: + stack.pop() + case xml.CharData: + data := string(t) + var flags charDataFlags + if isWhitespace(data) { + flags = whitespaceFlag + } + newCharData(data, flags, top) + case xml.Comment: + newComment(string(t), top) + case xml.Directive: + newDirective(string(t), top) + case xml.ProcInst: + newProcInst(t.Target, string(t.Inst), top) + } + } +} + +// SelectAttr finds an element attribute matching the requested key and +// returns it if found. Returns nil if no matching attribute is found. The key +// may be prefixed by a namespace prefix and a colon. +func (e *Element) SelectAttr(key string) *Attr { + space, skey := spaceDecompose(key) + for i, a := range e.Attr { + if spaceMatch(space, a.Space) && skey == a.Key { + return &e.Attr[i] + } + } + return nil +} + +// SelectAttrValue finds an element attribute matching the requested key and +// returns its value if found. The key may be prefixed by a namespace prefix +// and a colon. If the key is not found, the dflt value is returned instead. +func (e *Element) SelectAttrValue(key, dflt string) string { + space, skey := spaceDecompose(key) + for _, a := range e.Attr { + if spaceMatch(space, a.Space) && skey == a.Key { + return a.Value + } + } + return dflt +} + +// ChildElements returns all elements that are children of element e. +func (e *Element) ChildElements() []*Element { + var elements []*Element + for _, t := range e.Child { + if c, ok := t.(*Element); ok { + elements = append(elements, c) + } + } + return elements +} + +// SelectElement returns the first child element with the given tag. The tag +// may be prefixed by a namespace prefix and a colon. Returns nil if no +// element with a matching tag was found. +func (e *Element) SelectElement(tag string) *Element { + space, stag := spaceDecompose(tag) + for _, t := range e.Child { + if c, ok := t.(*Element); ok && spaceMatch(space, c.Space) && stag == c.Tag { + return c + } + } + return nil +} + +// SelectElements returns a slice of all child elements with the given tag. +// The tag may be prefixed by a namespace prefix and a colon. +func (e *Element) SelectElements(tag string) []*Element { + space, stag := spaceDecompose(tag) + var elements []*Element + for _, t := range e.Child { + if c, ok := t.(*Element); ok && spaceMatch(space, c.Space) && stag == c.Tag { + elements = append(elements, c) + } + } + return elements +} + +// FindElement returns the first element matched by the XPath-like path +// string. Returns nil if no element is found using the path. Panics if an +// invalid path string is supplied. +func (e *Element) FindElement(path string) *Element { + return e.FindElementPath(MustCompilePath(path)) +} + +// FindElementPath returns the first element matched by the XPath-like path +// string. Returns nil if no element is found using the path. +func (e *Element) FindElementPath(path Path) *Element { + p := newPather() + elements := p.traverse(e, path) + switch { + case len(elements) > 0: + return elements[0] + default: + return nil + } +} + +// FindElements returns a slice of elements matched by the XPath-like path +// string. Panics if an invalid path string is supplied. +func (e *Element) FindElements(path string) []*Element { + return e.FindElementsPath(MustCompilePath(path)) +} + +// FindElementsPath returns a slice of elements matched by the Path object. +func (e *Element) FindElementsPath(path Path) []*Element { + p := newPather() + return p.traverse(e, path) +} + +// GetPath returns the absolute path of the element. +func (e *Element) GetPath() string { + path := []string{} + for seg := e; seg != nil; seg = seg.Parent() { + if seg.Tag != "" { + path = append(path, seg.Tag) + } + } + + // Reverse the path. + for i, j := 0, len(path)-1; i < j; i, j = i+1, j-1 { + path[i], path[j] = path[j], path[i] + } + + return "/" + strings.Join(path, "/") +} + +// GetRelativePath returns the path of the element relative to the source +// element. If the two elements are not part of the same element tree, then +// GetRelativePath returns the empty string. +func (e *Element) GetRelativePath(source *Element) string { + var path []*Element + + if source == nil { + return "" + } + + // Build a reverse path from the element toward the root. Stop if the + // source element is encountered. + var seg *Element + for seg = e; seg != nil && seg != source; seg = seg.Parent() { + path = append(path, seg) + } + + // If we found the source element, reverse the path and compose the + // string. + if seg == source { + if len(path) == 0 { + return "." + } + parts := []string{} + for i := len(path) - 1; i >= 0; i-- { + parts = append(parts, path[i].Tag) + } + return "./" + strings.Join(parts, "/") + } + + // The source wasn't encountered, so climb from the source element toward + // the root of the tree until an element in the reversed path is + // encountered. + + findPathIndex := func(e *Element, path []*Element) int { + for i, ee := range path { + if e == ee { + return i + } + } + return -1 + } + + climb := 0 + for seg = source; seg != nil; seg = seg.Parent() { + i := findPathIndex(seg, path) + if i >= 0 { + path = path[:i] // truncate at found segment + break + } + climb++ + } + + // No element in the reversed path was encountered, so the two elements + // must not be part of the same tree. + if seg == nil { + return "" + } + + // Reverse the (possibly truncated) path and prepend ".." segments to + // climb. + parts := []string{} + for i := 0; i < climb; i++ { + parts = append(parts, "..") + } + for i := len(path) - 1; i >= 0; i-- { + parts = append(parts, path[i].Tag) + } + return strings.Join(parts, "/") +} + +// indent recursively inserts proper indentation between an +// XML element's child tokens. +func (e *Element) indent(depth int, indent indentFunc) { + e.stripIndent() + n := len(e.Child) + if n == 0 { + return + } + + oldChild := e.Child + e.Child = make([]Token, 0, n*2+1) + isCharData, firstNonCharData := false, true + for _, c := range oldChild { + // Insert NL+indent before child if it's not character data. + // Exceptions: when it's the first non-character-data child, or when + // the child is at root depth. + _, isCharData = c.(*CharData) + if !isCharData { + if !firstNonCharData || depth > 0 { + s := indent(depth) + if s != "" { + newCharData(s, whitespaceFlag, e) + } + } + firstNonCharData = false + } + + e.addChild(c) + + // Recursively process child elements. + if ce, ok := c.(*Element); ok { + ce.indent(depth+1, indent) + } + } + + // Insert NL+indent before the last child. + if !isCharData { + if !firstNonCharData || depth > 0 { + s := indent(depth - 1) + if s != "" { + newCharData(s, whitespaceFlag, e) + } + } + } +} + +// stripIndent removes any previously inserted indentation. +func (e *Element) stripIndent() { + // Count the number of non-indent child tokens + n := len(e.Child) + for _, c := range e.Child { + if cd, ok := c.(*CharData); ok && cd.IsWhitespace() { + n-- + } + } + if n == len(e.Child) { + return + } + + // Strip out indent CharData + newChild := make([]Token, n) + j := 0 + for _, c := range e.Child { + if cd, ok := c.(*CharData); ok && cd.IsWhitespace() { + continue + } + newChild[j] = c + newChild[j].setIndex(j) + j++ + } + e.Child = newChild +} + +// dup duplicates the element. +func (e *Element) dup(parent *Element) Token { + ne := &Element{ + Space: e.Space, + Tag: e.Tag, + Attr: make([]Attr, len(e.Attr)), + Child: make([]Token, len(e.Child)), + parent: parent, + index: e.index, + } + for i, t := range e.Child { + ne.Child[i] = t.dup(ne) + } + for i, a := range e.Attr { + ne.Attr[i] = a + } + return ne +} + +// Parent returns the element token's parent element, or nil if it has no +// parent. +func (e *Element) Parent() *Element { + return e.parent +} + +// Index returns the index of this element within its parent element's +// list of child tokens. If this element has no parent element, the index +// is -1. +func (e *Element) Index() int { + return e.index +} + +// setParent replaces the element token's parent. +func (e *Element) setParent(parent *Element) { + e.parent = parent +} + +// setIndex sets the element token's index within its parent's Child slice. +func (e *Element) setIndex(index int) { + e.index = index +} + +// writeTo serializes the element to the writer w. +func (e *Element) writeTo(w *bufio.Writer, s *WriteSettings) { + w.WriteByte('<') + w.WriteString(e.FullTag()) + for _, a := range e.Attr { + w.WriteByte(' ') + a.writeTo(w, s) + } + if len(e.Child) > 0 { + w.WriteString(">") + for _, c := range e.Child { + c.writeTo(w, s) + } + w.Write([]byte{'<', '/'}) + w.WriteString(e.FullTag()) + w.WriteByte('>') + } else { + if s.CanonicalEndTags { + w.Write([]byte{'>', '<', '/'}) + w.WriteString(e.FullTag()) + w.WriteByte('>') + } else { + w.Write([]byte{'/', '>'}) + } + } +} + +// addChild adds a child token to the element e. +func (e *Element) addChild(t Token) { + t.setIndex(len(e.Child)) + e.Child = append(e.Child, t) +} + +// CreateAttr creates an attribute and adds it to element e. The key may be +// prefixed by a namespace prefix and a colon. If an attribute with the key +// already exists, its value is replaced. +func (e *Element) CreateAttr(key, value string) *Attr { + space, skey := spaceDecompose(key) + return e.createAttr(space, skey, value, e) +} + +// createAttr is a helper function that creates attributes. +func (e *Element) createAttr(space, key, value string, parent *Element) *Attr { + for i, a := range e.Attr { + if space == a.Space && key == a.Key { + e.Attr[i].Value = value + return &e.Attr[i] + } + } + a := Attr{ + Space: space, + Key: key, + Value: value, + element: parent, + } + e.Attr = append(e.Attr, a) + return &e.Attr[len(e.Attr)-1] +} + +// RemoveAttr removes and returns a copy of the first attribute of the element +// whose key matches the given key. The key may be prefixed by a namespace +// prefix and a colon. If a matching attribute does not exist, nil is +// returned. +func (e *Element) RemoveAttr(key string) *Attr { + space, skey := spaceDecompose(key) + for i, a := range e.Attr { + if space == a.Space && skey == a.Key { + e.Attr = append(e.Attr[0:i], e.Attr[i+1:]...) + return &Attr{ + Space: a.Space, + Key: a.Key, + Value: a.Value, + element: nil, + } + } + } + return nil +} + +// SortAttrs sorts the element's attributes lexicographically by key. +func (e *Element) SortAttrs() { + sort.Sort(byAttr(e.Attr)) +} + +type byAttr []Attr + +func (a byAttr) Len() int { + return len(a) +} + +func (a byAttr) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} + +func (a byAttr) Less(i, j int) bool { + sp := strings.Compare(a[i].Space, a[j].Space) + if sp == 0 { + return strings.Compare(a[i].Key, a[j].Key) < 0 + } + return sp < 0 +} + +// FullKey returns the attribute a's complete key, including namespace prefix +// if present. +func (a *Attr) FullKey() string { + if a.Space == "" { + return a.Key + } + return a.Space + ":" + a.Key +} + +// Element returns the element containing the attribute. +func (a *Attr) Element() *Element { + return a.element +} + +// NamespaceURI returns the XML namespace URI associated with the attribute. +// If the element is part of the XML default namespace, NamespaceURI returns +// the empty string. +func (a *Attr) NamespaceURI() string { + return a.element.NamespaceURI() +} + +// writeTo serializes the attribute to the writer. +func (a *Attr) writeTo(w *bufio.Writer, s *WriteSettings) { + w.WriteString(a.FullKey()) + w.WriteString(`="`) + var m escapeMode + if s.CanonicalAttrVal { + m = escapeCanonicalAttr + } else { + m = escapeNormal + } + escapeString(w, a.Value, m) + w.WriteByte('"') +} + +// NewText creates a parentless CharData token containing character data. +func NewText(text string) *CharData { + return newCharData(text, 0, nil) +} + +// NewCData creates a parentless XML character CDATA section. +func NewCData(data string) *CharData { + return newCharData(data, cdataFlag, nil) +} + +// NewCharData creates a parentless CharData token containing character data. +// +// Deprecated: NewCharData is deprecated. Instead, use NewText, which does the +// same thing. +func NewCharData(data string) *CharData { + return newCharData(data, 0, nil) +} + +// newCharData creates a character data token and binds it to a parent +// element. If parent is nil, the CharData token remains unbound. +func newCharData(data string, flags charDataFlags, parent *Element) *CharData { + c := &CharData{ + Data: data, + parent: parent, + index: -1, + flags: flags, + } + if parent != nil { + parent.addChild(c) + } + return c +} + +// CreateText creates a CharData token containing character data and adds it +// as a child of element e. +func (e *Element) CreateText(text string) *CharData { + return newCharData(text, 0, e) +} + +// CreateCData creates a CharData token containing a CDATA section and adds it +// as a child of element e. +func (e *Element) CreateCData(data string) *CharData { + return newCharData(data, cdataFlag, e) +} + +// CreateCharData creates a CharData token containing character data and adds +// it as a child of element e. +// +// Deprecated: CreateCharData is deprecated. Instead, use CreateText, which +// does the same thing. +func (e *Element) CreateCharData(data string) *CharData { + return newCharData(data, 0, e) +} + +// dup duplicates the character data. +func (c *CharData) dup(parent *Element) Token { + return &CharData{ + Data: c.Data, + flags: c.flags, + parent: parent, + index: c.index, + } +} + +// IsCData returns true if the character data token is to be encoded as a +// CDATA section. +func (c *CharData) IsCData() bool { + return (c.flags & cdataFlag) != 0 +} + +// IsWhitespace returns true if the character data token was created by one of +// the document Indent methods to contain only whitespace. +func (c *CharData) IsWhitespace() bool { + return (c.flags & whitespaceFlag) != 0 +} + +// Parent returns the character data token's parent element, or nil if it has +// no parent. +func (c *CharData) Parent() *Element { + return c.parent +} + +// Index returns the index of this CharData token within its parent element's +// list of child tokens. If this CharData token has no parent element, the +// index is -1. +func (c *CharData) Index() int { + return c.index +} + +// setParent replaces the character data token's parent. +func (c *CharData) setParent(parent *Element) { + c.parent = parent +} + +// setIndex sets the CharData token's index within its parent element's Child +// slice. +func (c *CharData) setIndex(index int) { + c.index = index +} + +// writeTo serializes character data to the writer. +func (c *CharData) writeTo(w *bufio.Writer, s *WriteSettings) { + if c.IsCData() { + w.WriteString(``) + } else { + var m escapeMode + if s.CanonicalText { + m = escapeCanonicalText + } else { + m = escapeNormal + } + escapeString(w, c.Data, m) + } +} + +// NewComment creates a parentless XML comment. +func NewComment(comment string) *Comment { + return newComment(comment, nil) +} + +// NewComment creates an XML comment and binds it to a parent element. If +// parent is nil, the Comment remains unbound. +func newComment(comment string, parent *Element) *Comment { + c := &Comment{ + Data: comment, + parent: parent, + index: -1, + } + if parent != nil { + parent.addChild(c) + } + return c +} + +// CreateComment creates an XML comment and adds it as a child of element e. +func (e *Element) CreateComment(comment string) *Comment { + return newComment(comment, e) +} + +// dup duplicates the comment. +func (c *Comment) dup(parent *Element) Token { + return &Comment{ + Data: c.Data, + parent: parent, + index: c.index, + } +} + +// Parent returns comment token's parent element, or nil if it has no parent. +func (c *Comment) Parent() *Element { + return c.parent +} + +// Index returns the index of this Comment token within its parent element's +// list of child tokens. If this Comment token has no parent element, the +// index is -1. +func (c *Comment) Index() int { + return c.index +} + +// setParent replaces the comment token's parent. +func (c *Comment) setParent(parent *Element) { + c.parent = parent +} + +// setIndex sets the Comment token's index within its parent element's Child +// slice. +func (c *Comment) setIndex(index int) { + c.index = index +} + +// writeTo serialies the comment to the writer. +func (c *Comment) writeTo(w *bufio.Writer, s *WriteSettings) { + w.WriteString("") +} + +// NewDirective creates a parentless XML directive. +func NewDirective(data string) *Directive { + return newDirective(data, nil) +} + +// newDirective creates an XML directive and binds it to a parent element. If +// parent is nil, the Directive remains unbound. +func newDirective(data string, parent *Element) *Directive { + d := &Directive{ + Data: data, + parent: parent, + index: -1, + } + if parent != nil { + parent.addChild(d) + } + return d +} + +// CreateDirective creates an XML directive and adds it as the last child of +// element e. +func (e *Element) CreateDirective(data string) *Directive { + return newDirective(data, e) +} + +// dup duplicates the directive. +func (d *Directive) dup(parent *Element) Token { + return &Directive{ + Data: d.Data, + parent: parent, + index: d.index, + } +} + +// Parent returns directive token's parent element, or nil if it has no +// parent. +func (d *Directive) Parent() *Element { + return d.parent +} + +// Index returns the index of this Directive token within its parent element's +// list of child tokens. If this Directive token has no parent element, the +// index is -1. +func (d *Directive) Index() int { + return d.index +} + +// setParent replaces the directive token's parent. +func (d *Directive) setParent(parent *Element) { + d.parent = parent +} + +// setIndex sets the Directive token's index within its parent element's Child +// slice. +func (d *Directive) setIndex(index int) { + d.index = index +} + +// writeTo serializes the XML directive to the writer. +func (d *Directive) writeTo(w *bufio.Writer, s *WriteSettings) { + w.WriteString("") +} + +// NewProcInst creates a parentless XML processing instruction. +func NewProcInst(target, inst string) *ProcInst { + return newProcInst(target, inst, nil) +} + +// newProcInst creates an XML processing instruction and binds it to a parent +// element. If parent is nil, the ProcInst remains unbound. +func newProcInst(target, inst string, parent *Element) *ProcInst { + p := &ProcInst{ + Target: target, + Inst: inst, + parent: parent, + index: -1, + } + if parent != nil { + parent.addChild(p) + } + return p +} + +// CreateProcInst creates a processing instruction and adds it as a child of +// element e. +func (e *Element) CreateProcInst(target, inst string) *ProcInst { + return newProcInst(target, inst, e) +} + +// dup duplicates the procinst. +func (p *ProcInst) dup(parent *Element) Token { + return &ProcInst{ + Target: p.Target, + Inst: p.Inst, + parent: parent, + index: p.index, + } +} + +// Parent returns processing instruction token's parent element, or nil if it +// has no parent. +func (p *ProcInst) Parent() *Element { + return p.parent +} + +// Index returns the index of this ProcInst token within its parent element's +// list of child tokens. If this ProcInst token has no parent element, the +// index is -1. +func (p *ProcInst) Index() int { + return p.index +} + +// setParent replaces the processing instruction token's parent. +func (p *ProcInst) setParent(parent *Element) { + p.parent = parent +} + +// setIndex sets the processing instruction token's index within its parent +// element's Child slice. +func (p *ProcInst) setIndex(index int) { + p.index = index +} + +// writeTo serializes the processing instruction to the writer. +func (p *ProcInst) writeTo(w *bufio.Writer, s *WriteSettings) { + w.WriteString("") +} diff --git a/vendor/github.com/beevik/etree/helpers.go b/vendor/github.com/beevik/etree/helpers.go new file mode 100644 index 000000000..825e14e91 --- /dev/null +++ b/vendor/github.com/beevik/etree/helpers.go @@ -0,0 +1,276 @@ +// Copyright 2015-2019 Brett Vickers. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package etree + +import ( + "bufio" + "io" + "strings" + "unicode/utf8" +) + +// A simple stack +type stack struct { + data []interface{} +} + +func (s *stack) empty() bool { + return len(s.data) == 0 +} + +func (s *stack) push(value interface{}) { + s.data = append(s.data, value) +} + +func (s *stack) pop() interface{} { + value := s.data[len(s.data)-1] + s.data[len(s.data)-1] = nil + s.data = s.data[:len(s.data)-1] + return value +} + +func (s *stack) peek() interface{} { + return s.data[len(s.data)-1] +} + +// A fifo is a simple first-in-first-out queue. +type fifo struct { + data []interface{} + head, tail int +} + +func (f *fifo) add(value interface{}) { + if f.len()+1 >= len(f.data) { + f.grow() + } + f.data[f.tail] = value + if f.tail++; f.tail == len(f.data) { + f.tail = 0 + } +} + +func (f *fifo) remove() interface{} { + value := f.data[f.head] + f.data[f.head] = nil + if f.head++; f.head == len(f.data) { + f.head = 0 + } + return value +} + +func (f *fifo) len() int { + if f.tail >= f.head { + return f.tail - f.head + } + return len(f.data) - f.head + f.tail +} + +func (f *fifo) grow() { + c := len(f.data) * 2 + if c == 0 { + c = 4 + } + buf, count := make([]interface{}, c), f.len() + if f.tail >= f.head { + copy(buf[0:count], f.data[f.head:f.tail]) + } else { + hindex := len(f.data) - f.head + copy(buf[0:hindex], f.data[f.head:]) + copy(buf[hindex:count], f.data[:f.tail]) + } + f.data, f.head, f.tail = buf, 0, count +} + +// countReader implements a proxy reader that counts the number of +// bytes read from its encapsulated reader. +type countReader struct { + r io.Reader + bytes int64 +} + +func newCountReader(r io.Reader) *countReader { + return &countReader{r: r} +} + +func (cr *countReader) Read(p []byte) (n int, err error) { + b, err := cr.r.Read(p) + cr.bytes += int64(b) + return b, err +} + +// countWriter implements a proxy writer that counts the number of +// bytes written by its encapsulated writer. +type countWriter struct { + w io.Writer + bytes int64 +} + +func newCountWriter(w io.Writer) *countWriter { + return &countWriter{w: w} +} + +func (cw *countWriter) Write(p []byte) (n int, err error) { + b, err := cw.w.Write(p) + cw.bytes += int64(b) + return b, err +} + +// isWhitespace returns true if the byte slice contains only +// whitespace characters. +func isWhitespace(s string) bool { + for i := 0; i < len(s); i++ { + if c := s[i]; c != ' ' && c != '\t' && c != '\n' && c != '\r' { + return false + } + } + return true +} + +// spaceMatch returns true if namespace a is the empty string +// or if namespace a equals namespace b. +func spaceMatch(a, b string) bool { + switch { + case a == "": + return true + default: + return a == b + } +} + +// spaceDecompose breaks a namespace:tag identifier at the ':' +// and returns the two parts. +func spaceDecompose(str string) (space, key string) { + colon := strings.IndexByte(str, ':') + if colon == -1 { + return "", str + } + return str[:colon], str[colon+1:] +} + +// Strings used by indentCRLF and indentLF +const ( + indentSpaces = "\r\n " + indentTabs = "\r\n\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t" +) + +// indentCRLF returns a CRLF newline followed by n copies of the first +// non-CRLF character in the source string. +func indentCRLF(n int, source string) string { + switch { + case n < 0: + return source[:2] + case n < len(source)-1: + return source[:n+2] + default: + return source + strings.Repeat(source[2:3], n-len(source)+2) + } +} + +// indentLF returns a LF newline followed by n copies of the first non-LF +// character in the source string. +func indentLF(n int, source string) string { + switch { + case n < 0: + return source[1:2] + case n < len(source)-1: + return source[1 : n+2] + default: + return source[1:] + strings.Repeat(source[2:3], n-len(source)+2) + } +} + +// nextIndex returns the index of the next occurrence of sep in s, +// starting from offset. It returns -1 if the sep string is not found. +func nextIndex(s, sep string, offset int) int { + switch i := strings.Index(s[offset:], sep); i { + case -1: + return -1 + default: + return offset + i + } +} + +// isInteger returns true if the string s contains an integer. +func isInteger(s string) bool { + for i := 0; i < len(s); i++ { + if (s[i] < '0' || s[i] > '9') && !(i == 0 && s[i] == '-') { + return false + } + } + return true +} + +type escapeMode byte + +const ( + escapeNormal escapeMode = iota + escapeCanonicalText + escapeCanonicalAttr +) + +// escapeString writes an escaped version of a string to the writer. +func escapeString(w *bufio.Writer, s string, m escapeMode) { + var esc []byte + last := 0 + for i := 0; i < len(s); { + r, width := utf8.DecodeRuneInString(s[i:]) + i += width + switch r { + case '&': + esc = []byte("&") + case '<': + esc = []byte("<") + case '>': + if m == escapeCanonicalAttr { + continue + } + esc = []byte(">") + case '\'': + if m != escapeNormal { + continue + } + esc = []byte("'") + case '"': + if m == escapeCanonicalText { + continue + } + esc = []byte(""") + case '\t': + if m != escapeCanonicalAttr { + continue + } + esc = []byte(" ") + case '\n': + if m != escapeCanonicalAttr { + continue + } + esc = []byte(" ") + case '\r': + if m == escapeNormal { + continue + } + esc = []byte(" ") + default: + if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) { + esc = []byte("\uFFFD") + break + } + continue + } + w.WriteString(s[last : i-width]) + w.Write(esc) + last = i + } + w.WriteString(s[last:]) +} + +func isInCharacterRange(r rune) bool { + return r == 0x09 || + r == 0x0A || + r == 0x0D || + r >= 0x20 && r <= 0xD7FF || + r >= 0xE000 && r <= 0xFFFD || + r >= 0x10000 && r <= 0x10FFFF +} diff --git a/vendor/github.com/beevik/etree/path.go b/vendor/github.com/beevik/etree/path.go new file mode 100644 index 000000000..82db0ac55 --- /dev/null +++ b/vendor/github.com/beevik/etree/path.go @@ -0,0 +1,582 @@ +// Copyright 2015-2019 Brett Vickers. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package etree + +import ( + "strconv" + "strings" +) + +/* +A Path is a string that represents a search path through an etree starting +from the document root or an arbitrary element. Paths are used with the +Element object's Find* methods to locate and return desired elements. + +A Path consists of a series of slash-separated "selectors", each of which may +be modified by one or more bracket-enclosed "filters". Selectors are used to +traverse the etree from element to element, while filters are used to narrow +the list of candidate elements at each node. + +Although etree Path strings are similar to XPath strings +(https://www.w3.org/TR/1999/REC-xpath-19991116/), they have a more limited set +of selectors and filtering options. + +The following selectors are supported by etree Path strings: + + . Select the current element. + .. Select the parent of the current element. + * Select all child elements of the current element. + / Select the root element when used at the start of a path. + // Select all descendants of the current element. + tag Select all child elements with a name matching the tag. + +The following basic filters are supported by etree Path strings: + + [@attrib] Keep elements with an attribute named attrib. + [@attrib='val'] Keep elements with an attribute named attrib and value matching val. + [tag] Keep elements with a child element named tag. + [tag='val'] Keep elements with a child element named tag and text matching val. + [n] Keep the n-th element, where n is a numeric index starting from 1. + +The following function filters are also supported: + + [text()] Keep elements with non-empty text. + [text()='val'] Keep elements whose text matches val. + [local-name()='val'] Keep elements whose un-prefixed tag matches val. + [name()='val'] Keep elements whose full tag exactly matches val. + [namespace-prefix()='val'] Keep elements whose namespace prefix matches val. + [namespace-uri()='val'] Keep elements whose namespace URI matches val. + +Here are some examples of Path strings: + +- Select the bookstore child element of the root element: + /bookstore + +- Beginning from the root element, select the title elements of all +descendant book elements having a 'category' attribute of 'WEB': + //book[@category='WEB']/title + +- Beginning from the current element, select the first descendant +book element with a title child element containing the text 'Great +Expectations': + .//book[title='Great Expectations'][1] + +- Beginning from the current element, select all child elements of +book elements with an attribute 'language' set to 'english': + ./book/*[@language='english'] + +- Beginning from the current element, select all child elements of +book elements containing the text 'special': + ./book/*[text()='special'] + +- Beginning from the current element, select all descendant book +elements whose title child element has a 'language' attribute of 'french': + .//book/title[@language='french']/.. + +- Beginning from the current element, select all book elements +belonging to the http://www.w3.org/TR/html4/ namespace: + .//book[namespace-uri()='http://www.w3.org/TR/html4/'] + +*/ +type Path struct { + segments []segment +} + +// ErrPath is returned by path functions when an invalid etree path is provided. +type ErrPath string + +// Error returns the string describing a path error. +func (err ErrPath) Error() string { + return "etree: " + string(err) +} + +// CompilePath creates an optimized version of an XPath-like string that +// can be used to query elements in an element tree. +func CompilePath(path string) (Path, error) { + var comp compiler + segments := comp.parsePath(path) + if comp.err != ErrPath("") { + return Path{nil}, comp.err + } + return Path{segments}, nil +} + +// MustCompilePath creates an optimized version of an XPath-like string that +// can be used to query elements in an element tree. Panics if an error +// occurs. Use this function to create Paths when you know the path is +// valid (i.e., if it's hard-coded). +func MustCompilePath(path string) Path { + p, err := CompilePath(path) + if err != nil { + panic(err) + } + return p +} + +// A segment is a portion of a path between "/" characters. +// It contains one selector and zero or more [filters]. +type segment struct { + sel selector + filters []filter +} + +func (seg *segment) apply(e *Element, p *pather) { + seg.sel.apply(e, p) + for _, f := range seg.filters { + f.apply(p) + } +} + +// A selector selects XML elements for consideration by the +// path traversal. +type selector interface { + apply(e *Element, p *pather) +} + +// A filter pares down a list of candidate XML elements based +// on a path filter in [brackets]. +type filter interface { + apply(p *pather) +} + +// A pather is helper object that traverses an element tree using +// a Path object. It collects and deduplicates all elements matching +// the path query. +type pather struct { + queue fifo + results []*Element + inResults map[*Element]bool + candidates []*Element + scratch []*Element // used by filters +} + +// A node represents an element and the remaining path segments that +// should be applied against it by the pather. +type node struct { + e *Element + segments []segment +} + +func newPather() *pather { + return &pather{ + results: make([]*Element, 0), + inResults: make(map[*Element]bool), + candidates: make([]*Element, 0), + scratch: make([]*Element, 0), + } +} + +// traverse follows the path from the element e, collecting +// and then returning all elements that match the path's selectors +// and filters. +func (p *pather) traverse(e *Element, path Path) []*Element { + for p.queue.add(node{e, path.segments}); p.queue.len() > 0; { + p.eval(p.queue.remove().(node)) + } + return p.results +} + +// eval evalutes the current path node by applying the remaining +// path's selector rules against the node's element. +func (p *pather) eval(n node) { + p.candidates = p.candidates[0:0] + seg, remain := n.segments[0], n.segments[1:] + seg.apply(n.e, p) + + if len(remain) == 0 { + for _, c := range p.candidates { + if in := p.inResults[c]; !in { + p.inResults[c] = true + p.results = append(p.results, c) + } + } + } else { + for _, c := range p.candidates { + p.queue.add(node{c, remain}) + } + } +} + +// A compiler generates a compiled path from a path string. +type compiler struct { + err ErrPath +} + +// parsePath parses an XPath-like string describing a path +// through an element tree and returns a slice of segment +// descriptors. +func (c *compiler) parsePath(path string) []segment { + // If path ends with //, fix it + if strings.HasSuffix(path, "//") { + path = path + "*" + } + + var segments []segment + + // Check for an absolute path + if strings.HasPrefix(path, "/") { + segments = append(segments, segment{new(selectRoot), []filter{}}) + path = path[1:] + } + + // Split path into segments + for _, s := range splitPath(path) { + segments = append(segments, c.parseSegment(s)) + if c.err != ErrPath("") { + break + } + } + return segments +} + +func splitPath(path string) []string { + pieces := make([]string, 0) + start := 0 + inquote := false + for i := 0; i+1 <= len(path); i++ { + if path[i] == '\'' { + inquote = !inquote + } else if path[i] == '/' && !inquote { + pieces = append(pieces, path[start:i]) + start = i + 1 + } + } + return append(pieces, path[start:]) +} + +// parseSegment parses a path segment between / characters. +func (c *compiler) parseSegment(path string) segment { + pieces := strings.Split(path, "[") + seg := segment{ + sel: c.parseSelector(pieces[0]), + filters: []filter{}, + } + for i := 1; i < len(pieces); i++ { + fpath := pieces[i] + if fpath[len(fpath)-1] != ']' { + c.err = ErrPath("path has invalid filter [brackets].") + break + } + seg.filters = append(seg.filters, c.parseFilter(fpath[:len(fpath)-1])) + } + return seg +} + +// parseSelector parses a selector at the start of a path segment. +func (c *compiler) parseSelector(path string) selector { + switch path { + case ".": + return new(selectSelf) + case "..": + return new(selectParent) + case "*": + return new(selectChildren) + case "": + return new(selectDescendants) + default: + return newSelectChildrenByTag(path) + } +} + +var fnTable = map[string]struct { + hasFn func(e *Element) bool + getValFn func(e *Element) string +}{ + "local-name": {nil, (*Element).name}, + "name": {nil, (*Element).FullTag}, + "namespace-prefix": {nil, (*Element).namespacePrefix}, + "namespace-uri": {nil, (*Element).NamespaceURI}, + "text": {(*Element).hasText, (*Element).Text}, +} + +// parseFilter parses a path filter contained within [brackets]. +func (c *compiler) parseFilter(path string) filter { + if len(path) == 0 { + c.err = ErrPath("path contains an empty filter expression.") + return nil + } + + // Filter contains [@attr='val'], [fn()='val'], or [tag='val']? + eqindex := strings.Index(path, "='") + if eqindex >= 0 { + rindex := nextIndex(path, "'", eqindex+2) + if rindex != len(path)-1 { + c.err = ErrPath("path has mismatched filter quotes.") + return nil + } + + key := path[:eqindex] + value := path[eqindex+2 : rindex] + + switch { + case key[0] == '@': + return newFilterAttrVal(key[1:], value) + case strings.HasSuffix(key, "()"): + fn := key[:len(key)-2] + if t, ok := fnTable[fn]; ok && t.getValFn != nil { + return newFilterFuncVal(t.getValFn, value) + } + c.err = ErrPath("path has unknown function " + fn) + return nil + default: + return newFilterChildText(key, value) + } + } + + // Filter contains [@attr], [N], [tag] or [fn()] + switch { + case path[0] == '@': + return newFilterAttr(path[1:]) + case strings.HasSuffix(path, "()"): + fn := path[:len(path)-2] + if t, ok := fnTable[fn]; ok && t.hasFn != nil { + return newFilterFunc(t.hasFn) + } + c.err = ErrPath("path has unknown function " + fn) + return nil + case isInteger(path): + pos, _ := strconv.Atoi(path) + switch { + case pos > 0: + return newFilterPos(pos - 1) + default: + return newFilterPos(pos) + } + default: + return newFilterChild(path) + } +} + +// selectSelf selects the current element into the candidate list. +type selectSelf struct{} + +func (s *selectSelf) apply(e *Element, p *pather) { + p.candidates = append(p.candidates, e) +} + +// selectRoot selects the element's root node. +type selectRoot struct{} + +func (s *selectRoot) apply(e *Element, p *pather) { + root := e + for root.parent != nil { + root = root.parent + } + p.candidates = append(p.candidates, root) +} + +// selectParent selects the element's parent into the candidate list. +type selectParent struct{} + +func (s *selectParent) apply(e *Element, p *pather) { + if e.parent != nil { + p.candidates = append(p.candidates, e.parent) + } +} + +// selectChildren selects the element's child elements into the +// candidate list. +type selectChildren struct{} + +func (s *selectChildren) apply(e *Element, p *pather) { + for _, c := range e.Child { + if c, ok := c.(*Element); ok { + p.candidates = append(p.candidates, c) + } + } +} + +// selectDescendants selects all descendant child elements +// of the element into the candidate list. +type selectDescendants struct{} + +func (s *selectDescendants) apply(e *Element, p *pather) { + var queue fifo + for queue.add(e); queue.len() > 0; { + e := queue.remove().(*Element) + p.candidates = append(p.candidates, e) + for _, c := range e.Child { + if c, ok := c.(*Element); ok { + queue.add(c) + } + } + } +} + +// selectChildrenByTag selects into the candidate list all child +// elements of the element having the specified tag. +type selectChildrenByTag struct { + space, tag string +} + +func newSelectChildrenByTag(path string) *selectChildrenByTag { + s, l := spaceDecompose(path) + return &selectChildrenByTag{s, l} +} + +func (s *selectChildrenByTag) apply(e *Element, p *pather) { + for _, c := range e.Child { + if c, ok := c.(*Element); ok && spaceMatch(s.space, c.Space) && s.tag == c.Tag { + p.candidates = append(p.candidates, c) + } + } +} + +// filterPos filters the candidate list, keeping only the +// candidate at the specified index. +type filterPos struct { + index int +} + +func newFilterPos(pos int) *filterPos { + return &filterPos{pos} +} + +func (f *filterPos) apply(p *pather) { + if f.index >= 0 { + if f.index < len(p.candidates) { + p.scratch = append(p.scratch, p.candidates[f.index]) + } + } else { + if -f.index <= len(p.candidates) { + p.scratch = append(p.scratch, p.candidates[len(p.candidates)+f.index]) + } + } + p.candidates, p.scratch = p.scratch, p.candidates[0:0] +} + +// filterAttr filters the candidate list for elements having +// the specified attribute. +type filterAttr struct { + space, key string +} + +func newFilterAttr(str string) *filterAttr { + s, l := spaceDecompose(str) + return &filterAttr{s, l} +} + +func (f *filterAttr) apply(p *pather) { + for _, c := range p.candidates { + for _, a := range c.Attr { + if spaceMatch(f.space, a.Space) && f.key == a.Key { + p.scratch = append(p.scratch, c) + break + } + } + } + p.candidates, p.scratch = p.scratch, p.candidates[0:0] +} + +// filterAttrVal filters the candidate list for elements having +// the specified attribute with the specified value. +type filterAttrVal struct { + space, key, val string +} + +func newFilterAttrVal(str, value string) *filterAttrVal { + s, l := spaceDecompose(str) + return &filterAttrVal{s, l, value} +} + +func (f *filterAttrVal) apply(p *pather) { + for _, c := range p.candidates { + for _, a := range c.Attr { + if spaceMatch(f.space, a.Space) && f.key == a.Key && f.val == a.Value { + p.scratch = append(p.scratch, c) + break + } + } + } + p.candidates, p.scratch = p.scratch, p.candidates[0:0] +} + +// filterFunc filters the candidate list for elements satisfying a custom +// boolean function. +type filterFunc struct { + fn func(e *Element) bool +} + +func newFilterFunc(fn func(e *Element) bool) *filterFunc { + return &filterFunc{fn} +} + +func (f *filterFunc) apply(p *pather) { + for _, c := range p.candidates { + if f.fn(c) { + p.scratch = append(p.scratch, c) + } + } + p.candidates, p.scratch = p.scratch, p.candidates[0:0] +} + +// filterFuncVal filters the candidate list for elements containing a value +// matching the result of a custom function. +type filterFuncVal struct { + fn func(e *Element) string + val string +} + +func newFilterFuncVal(fn func(e *Element) string, value string) *filterFuncVal { + return &filterFuncVal{fn, value} +} + +func (f *filterFuncVal) apply(p *pather) { + for _, c := range p.candidates { + if f.fn(c) == f.val { + p.scratch = append(p.scratch, c) + } + } + p.candidates, p.scratch = p.scratch, p.candidates[0:0] +} + +// filterChild filters the candidate list for elements having +// a child element with the specified tag. +type filterChild struct { + space, tag string +} + +func newFilterChild(str string) *filterChild { + s, l := spaceDecompose(str) + return &filterChild{s, l} +} + +func (f *filterChild) apply(p *pather) { + for _, c := range p.candidates { + for _, cc := range c.Child { + if cc, ok := cc.(*Element); ok && + spaceMatch(f.space, cc.Space) && + f.tag == cc.Tag { + p.scratch = append(p.scratch, c) + } + } + } + p.candidates, p.scratch = p.scratch, p.candidates[0:0] +} + +// filterChildText filters the candidate list for elements having +// a child element with the specified tag and text. +type filterChildText struct { + space, tag, text string +} + +func newFilterChildText(str, text string) *filterChildText { + s, l := spaceDecompose(str) + return &filterChildText{s, l, text} +} + +func (f *filterChildText) apply(p *pather) { + for _, c := range p.candidates { + for _, cc := range c.Child { + if cc, ok := cc.(*Element); ok && + spaceMatch(f.space, cc.Space) && + f.tag == cc.Tag && + f.text == cc.Text() { + p.scratch = append(p.scratch, c) + } + } + } + p.candidates, p.scratch = p.scratch, p.candidates[0:0] +} diff --git a/vendor/github.com/jinzhu/gorm/License b/vendor/github.com/jinzhu/gorm/License deleted file mode 100644 index 037e1653e..000000000 --- a/vendor/github.com/jinzhu/gorm/License +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2013-NOW Jinzhu - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/vendor/github.com/jinzhu/gorm/association.go b/vendor/github.com/jinzhu/gorm/association.go deleted file mode 100644 index a73344fe6..000000000 --- a/vendor/github.com/jinzhu/gorm/association.go +++ /dev/null @@ -1,377 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Association Mode contains some helper methods to handle relationship things easily. -type Association struct { - Error error - scope *Scope - column string - field *Field -} - -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} - -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} - -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) - - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) - } - } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames, associationForeignDBNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for idx, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) - } - } - } else { - // If has one/many relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, field.DBName) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) - var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - return association -} - -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field - if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) - - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break - } - } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) - } - } - - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break - } - } - } - } - - return association -} - -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { - return association.Replace() -} - -// Count return the count of current associations -func (association *Association) Count() int { - var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() - ) - - switch relationship.Kind { - case "many_to_many": - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - case "has_many", "has_one": - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - case "belongs_to": - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } - - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - relationship.PolymorphicValue, - ) - } - - if err := query.Model(fieldValue).Count(&count).Error; err != nil { - association.Error = err - } - return count -} - -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { - var ( - scope = association.scope - field = association.field - relationship = field.Relationship - ) - - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) - } - } - - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } - } - - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -// setErr set error when the error is not nil. And return Association. -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association -} diff --git a/vendor/github.com/jinzhu/gorm/callback.go b/vendor/github.com/jinzhu/gorm/callback.go deleted file mode 100644 index a4382147b..000000000 --- a/vendor/github.com/jinzhu/gorm/callback.go +++ /dev/null @@ -1,242 +0,0 @@ -package gorm - -import "log" - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} - -// Callback is a struct that contains all CRUD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... -// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone() *Callback { - return &Callback{ - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... -// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - if cp.kind == "row_query" { - if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) - cp.before = "gorm:row_query" - } - } - - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("Created", now) -// scope.SetColumn("Updated", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind && !cp.remove { - return *p.processor - } - } - return nil -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CRUD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - c.rowQueries = sortProcessors(rowQueries) -} diff --git a/vendor/github.com/jinzhu/gorm/callback_create.go b/vendor/github.com/jinzhu/gorm/callback_create.go deleted file mode 100644 index 2ab05d3b0..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_create.go +++ /dev/null @@ -1,164 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := NowFunc() - - if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { - if createdAtField.IsBlank { - createdAtField.Set(now) - } - } - - if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { - if updatedAtField.IsBlank { - updatedAtField.Set(now) - } - } - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal && !field.IsIgnored { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v %v%v%v", - quotedTableName, - scope.Dialect().DefaultValueStr(), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", - scope.QuotedTableName(), - strings.Join(columns, ","), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - } else { - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - } - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object -func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_delete.go b/vendor/github.com/jinzhu/gorm/callback_delete.go deleted file mode 100644 index 73d908806..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_delete.go +++ /dev/null @@ -1,63 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" -) - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while deleting")) - return - } - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") - - if !scope.Search.Unscoped && hasDeletedAtField { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v=%v%v%v", - scope.QuotedTableName(), - scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_query.go b/vendor/github.com/jinzhu/gorm/callback_query.go deleted file mode 100644 index 593e5d304..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_query.go +++ /dev/null @@ -1,104 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - //we are only preloading relations, dont touch base model - if _, skip := scope.InstanceGet("gorm:only_preload"); skip { - return - } - - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } else if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_query_preload.go b/vendor/github.com/jinzhu/gorm/callback_query_preload.go deleted file mode 100644 index d7c8a133e..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_query_preload.go +++ /dev/null @@ -1,404 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { - return - } - - if ap, ok := scope.Get("gorm:auto_preload"); ok { - // If gorm:auto_preload IS NOT a bool then auto preload. - // Else if it IS a bool, use the value - if apb, ok := ap.(bool); !ok { - autoPreload(scope) - } else if apb { - autoPreload(scope) - } - } - - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func autoPreload(scope *Scope) { - for _, field := range scope.Fields() { - if field.Relationship == nil { - continue - } - - if val, ok := field.TagSettingsGet("PRELOAD"); ok { - if preload, err := strconv.ParseBool(val); err != nil { - scope.Err(errors.New("invalid preload option")) - return - } else if !preload { - continue - } - } - - scope.Search.Preload(field.Name) - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - foreignValuesToResults := make(map[string]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) - foreignValuesToResults[foreignValues] = result - } - for j := 0; j < indirectScopeValue.Len(); j++ { - indirectValue := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) - if result, found := foreignValuesToResults[valueString]; found { - indirectValue.FieldByName(field.Name).Set(result) - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - f := object.FieldByName(field.Name) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f.Set(reflect.Append(f, results...)) - } else { - f.Set(reflect.MakeSlice(f.Type(), 0, 0)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - foreignFieldToObjects := make(map[string][]*reflect.Value) - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) - foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) - } - } - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) - if objects, found := foreignFieldToObjects[valueString]; found { - for _, object := range objects { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) - - if len(preloadDB.search.selects) == 0 { - preloadDB = preloadDB.Select("*") - } - - preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - scope.New(elem.Addr().Interface()). - InstanceSet("gorm:skip_query_callback", true). - callCallbacks(scope.db.parent.callbacks.queries) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string][]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - key := toString(getValueFromFields(object, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) - } - } else if indirectScopeValue.IsValid() { - key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) - } - for source, link := range linkHash { - for i, field := range fieldsSourceMap[source] { - //If not 0 this means Value is a pointer and we already added preloaded models to it - if fieldsSourceMap[source][i].Len() != 0 { - continue - } - field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) - } - - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_row_query.go b/vendor/github.com/jinzhu/gorm/callback_row_query.go deleted file mode 100644 index c2ff4a083..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_row_query.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import "database/sql" - -// Define callbacks for row query -func init() { - DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) -} - -type RowQueryResult struct { - Row *sql.Row -} - -type RowsQueryResult struct { - Rows *sql.Rows - Error error -} - -// queryCallback used to query data from database -func rowQueryCallback(scope *Scope) { - if result, ok := scope.InstanceGet("row_query_result"); ok { - scope.prepareQuerySQL() - - if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) - } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_save.go b/vendor/github.com/jinzhu/gorm/callback_save.go deleted file mode 100644 index 3b4e05895..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_save.go +++ /dev/null @@ -1,170 +0,0 @@ -package gorm - -import ( - "reflect" - "strings" -) - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { - checkTruth := func(value interface{}) bool { - if v, ok := value.(bool); ok && !v { - return false - } - - if v, ok := value.(string); ok { - v = strings.ToLower(v) - return v == "true" - } - - return true - } - - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if r = field.Relationship; r != nil { - autoUpdate, autoCreate, saveReference = true, true, true - - if value, ok := scope.Get("gorm:save_associations"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { - autoUpdate = checkTruth(value) - autoCreate = autoUpdate - saveReference = autoUpdate - } - - if value, ok := scope.Get("gorm:association_autoupdate"); ok { - autoUpdate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { - autoUpdate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_autocreate"); ok { - autoCreate = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { - autoCreate = checkTruth(value) - } - - if value, ok := scope.Get("gorm:association_save_reference"); ok { - saveReference = checkTruth(value) - } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { - saveReference = checkTruth(value) - } - } - } - - return -} - -func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - newScope := scope.New(fieldValue) - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - - if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if saveReference { - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(newDB.Save(elem).Error) - } - } else if autoUpdate { - scope.Err(newDB.Save(elem).Error) - } - - if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - - if saveReference { - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - } - - if newScope.PrimaryKeyZero() { - if autoCreate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } else if autoUpdate { - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/callback_update.go b/vendor/github.com/jinzhu/gorm/callback_update.go deleted file mode 100644 index f6ba0ffd9..000000000 --- a/vendor/github.com/jinzhu/gorm/callback_update.go +++ /dev/null @@ -1,121 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "sort" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while updating")) - return - } - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - // Sort the column names so that the generated SQL is the same every time. - updateMap := updateAttrs.(map[string]interface{}) - var columns []string - for c := range updateMap { - columns = append(columns, c) - } - sort.Strings(columns) - - for _, column := range columns { - value := updateMap[column] - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v%v%v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/vendor/github.com/jinzhu/gorm/dialect.go b/vendor/github.com/jinzhu/gorm/dialect.go deleted file mode 100644 index 27b308af3..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect.go +++ /dev/null @@ -1,138 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db SQLCommon) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - // ModifyColumn modify column's type - ModifyColumn(tableName string, columnName string, typ string) error - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) string - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string - // DefaultValueStr - DefaultValueStr() string - - // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference - BuildKeyName(kind, tableName string, fields ...string) string - - // CurrentDatabase return current database name - CurrentDatabase() string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db SQLCommon) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// GetDialect gets the dialect for the specified dialect name -func GetDialect(name string) (dialect Dialect, ok bool) { - dialect, ok = dialectsMap[name] - return -} - -// ParseFieldStructForDialect get field's sql data type -var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var ( - reflectType = field.Struct.Type - dataType, _ = field.TagSettingsGet("TYPE") - ) - - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - if gormDataType, ok := fieldValue.Interface().(interface { - GormDataType(Dialect) string - }); ok { - dataType = gormDataType.GormDataType(dialect) - } - - // Get scanner's real value - if dataType == "" { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - } - - // Default Size - if num, ok := field.TagSettingsGet("SIZE"); ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - notNull, _ := field.TagSettingsGet("NOT NULL") - unique, _ := field.TagSettingsGet("UNIQUE") - additionalType = notNull + " " + unique - if value, ok := field.TagSettingsGet("DEFAULT"); ok { - additionalType = additionalType + " DEFAULT " + value - } - - return fieldValue, dataType, size, strings.TrimSpace(additionalType) -} - -func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { - if strings.Contains(tableName, ".") { - splitStrings := strings.SplitN(tableName, ".", 2) - return splitStrings[0], splitStrings[1] - } - return dialect.CurrentDatabase(), tableName -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_common.go b/vendor/github.com/jinzhu/gorm/dialect_common.go deleted file mode 100644 index a479be79b..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_common.go +++ /dev/null @@ -1,176 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -// DefaultForeignKeyNamer contains the default foreign key name generator method -type DefaultForeignKeyNamer struct { -} - -type commonDialect struct { - db SQLCommon - DefaultForeignKeyNamer -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db SQLCommon) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$$" // ? -} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { - if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - return strings.ToLower(value) != "false" - } - return field.IsPrimaryKey -} - -func (s *commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) - return err -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (commonDialect) DefaultValueStr() string { - return "DEFAULT VALUES" -} - -// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference -func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) - keyName = regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(keyName, "_") - return keyName -} - -// IsByteArrayOrSlice returns true of the reflected value is an array or slice -func IsByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_mysql.go b/vendor/github.com/jinzhu/gorm/dialect_mysql.go deleted file mode 100644 index 5d63e5cd2..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_mysql.go +++ /dev/null @@ -1,191 +0,0 @@ -package gorm - -import ( - "crypto/sha1" - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" - "unicode/utf8" -) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (s *mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - // MySQL allows only one auto increment column per table, and it must - // be a KEY column. - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { - if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { - field.TagSettingsDelete("AUTO_INCREMENT") - } - } - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint AUTO_INCREMENT" - } else { - sqlType = "tinyint" - } - case reflect.Int, reflect.Int16, reflect.Int32: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint8: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "tinyint unsigned AUTO_INCREMENT" - } else { - sqlType = "tinyint unsigned" - } - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - precision := "" - if p, ok := field.TagSettingsGet("PRECISION"); ok { - precision = fmt.Sprintf("(%s)", p) - } - - if _, ok := field.TagSettingsGet("NOT NULL"); ok { - sqlType = fmt.Sprintf("timestamp%v", precision) - } else { - sqlType = fmt.Sprintf("timestamp%v NULL", precision) - } - } - default: - if IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { - _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) - return err -} - -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - } - } - return -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { - keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) - if utf8.RuneCountInString(keyName) <= 64 { - return keyName - } - h := sha1.New() - h.Write([]byte(keyName)) - bs := h.Sum(nil) - - // sha1 is 40 characters, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("[^a-zA-Z0-9]+").ReplaceAllString(fields[0], "_")) - if len(destRunes) > 24 { - destRunes = destRunes[:24] - } - - return fmt.Sprintf("%s%x", string(destRunes), bs) -} - -func (mysql) DefaultValueStr() string { - return "VALUES()" -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_postgres.go b/vendor/github.com/jinzhu/gorm/dialect_postgres.go deleted file mode 100644 index 53d31388e..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_postgres.go +++ /dev/null @@ -1,143 +0,0 @@ -package gorm - -import ( - "encoding/json" - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) - RegisterDialect("cloudsqlpostgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (s *postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint32, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettingsGet("SIZE"); !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "bytea" - - if isUUID(dataValue) { - sqlType = "uuid" - } - - if isJSON(dataValue) { - sqlType = "jsonb" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} - -func isJSON(value reflect.Value) bool { - _, ok := value.Interface().(json.RawMessage) - return ok -} diff --git a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go b/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go deleted file mode 100644 index 5f96c363a..000000000 --- a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go +++ /dev/null @@ -1,107 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func (s *sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if s.fieldCanAutoIncrement(field) { - field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "blob" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/vendor/github.com/jinzhu/gorm/errors.go b/vendor/github.com/jinzhu/gorm/errors.go deleted file mode 100644 index 27c9a92d0..000000000 --- a/vendor/github.com/jinzhu/gorm/errors.go +++ /dev/null @@ -1,72 +0,0 @@ -package gorm - -import ( - "errors" - "strings" -) - -var ( - // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` - ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") -) - -// Errors contains all happened errors -type Errors []error - -// IsRecordNotFoundError returns current error has record not found error or not -func IsRecordNotFoundError(err error) bool { - if errs, ok := err.(Errors); ok { - for _, err := range errs { - if err == ErrRecordNotFound { - return true - } - } - } - return err == ErrRecordNotFound -} - -// GetErrors gets all happened errors -func (errs Errors) GetErrors() []error { - return errs -} - -// Add adds an error -func (errs Errors) Add(newErrors ...error) Errors { - for _, err := range newErrors { - if err == nil { - continue - } - - if errors, ok := err.(Errors); ok { - errs = errs.Add(errors...) - } else { - ok = true - for _, e := range errs { - if err == e { - ok = false - } - } - if ok { - errs = append(errs, err) - } - } - } - return errs -} - -// Error format happened errors -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/vendor/github.com/jinzhu/gorm/field.go b/vendor/github.com/jinzhu/gorm/field.go deleted file mode 100644 index acd06e20d..000000000 --- a/vendor/github.com/jinzhu/gorm/field.go +++ /dev/null @@ -1,66 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - v := reflectValue.Interface() - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = scanner.Scan(v) - } - } else { - err = scanner.Scan(v) - } - } else { - err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/vendor/github.com/jinzhu/gorm/interface.go b/vendor/github.com/jinzhu/gorm/interface.go deleted file mode 100644 index 55128f7fc..000000000 --- a/vendor/github.com/jinzhu/gorm/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package gorm - -import "database/sql" - -// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. -type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/vendor/github.com/jinzhu/gorm/join_table_handler.go b/vendor/github.com/jinzhu/gorm/join_table_handler.go deleted file mode 100644 index a036d46d2..000000000 --- a/vendor/github.com/jinzhu/gorm/join_table_handler.go +++ /dev/null @@ -1,211 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys []JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - s.Source.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - s.Destination.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return DefaultTableNameHandler(db, s.TableName) -} - -func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - for _, joinTableSource := range joinTableSources { - if joinTableSource.ModelType == modelType { - for _, foreignKey := range joinTableSource.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - conditionMap[foreignKey.DBName] = field.Field.Interface() - } - } - break - } - } - } -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - var ( - scope = db.NewScope("") - conditionMap = map[string]interface{}{} - ) - - // Update condition map for source - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) - - // Update condition map for destination - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range conditionMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - conditionMap = map[string]interface{}{} - ) - - s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) - - for key, value := range conditionMap { - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). - Where(condString, toQueryValues(foreignFieldValues)...) - } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/vendor/github.com/jinzhu/gorm/logger.go b/vendor/github.com/jinzhu/gorm/logger.go deleted file mode 100644 index 4324a2e40..000000000 --- a/vendor/github.com/jinzhu/gorm/logger.go +++ /dev/null @@ -1,119 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "strconv" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`\?`) - numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) -) - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} - -var LogFormatter = func(values ...interface{}) (messages []interface{}) { - if len(values) > 1 { - var ( - sql string - formattedValues []string - level = values[0] - currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - ) - - messages = []interface{}{source, currentTime} - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := value.(time.Time); ok { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } else { - formattedValues = append(formattedValues, "NULL") - } - } - - // differentiate between $n placeholders or else treat like ? - if numericPlaceHolderRegexp.MatchString(values[3].(string)) { - sql = values[3].(string) - for index, value := range formattedValues { - placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) - sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") - } - } else { - formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - } - - messages = append(messages, sql) - messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) - messages = append(messages, "\033[0m") - } - } - - return -} - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - logger.Println(LogFormatter(values...)...) -} diff --git a/vendor/github.com/jinzhu/gorm/main.go b/vendor/github.com/jinzhu/gorm/main.go deleted file mode 100644 index 17c75ed30..000000000 --- a/vendor/github.com/jinzhu/gorm/main.go +++ /dev/null @@ -1,792 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "sync" - "time" -) - -// DB contains information for current db connection -type DB struct { - Value interface{} - Error error - RowsAffected int64 - - // single db - db SQLCommon - blockGlobalUpdate bool - logMode int - logger logger - search *search - values sync.Map - - // global db - parent *DB - callbacks *Callback - dialect Dialect - singularTable bool -} - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (db *DB, err error) { - if len(args) == 0 { - err = errors.New("invalid database source") - return nil, err - } - var source string - var dbSQL SQLCommon - var ownDbSQL bool - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - ownDbSQL = true - case SQLCommon: - dbSQL = value - ownDbSQL = false - default: - return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) - } - - db = &DB{ - db: dbSQL, - logger: defaultLogger, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), - } - db.parent = db - if err != nil { - return - } - // Send a ping to make sure the database connection is alive. - if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil && ownDbSQL { - d.Close() - } - } - return -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -type closer interface { - Close() error -} - -// Close close current db connection. If database connection is not an io.Closer, returns an error. -func (s *DB) Close() error { - if db, ok := s.parent.db.(closer); ok { - return db.Close() - } - return errors.New("can't close current db") -} - -// DB get `*sql.DB` from current connection -// If the underlying database connection is not a *sql.DB, returns nil -func (s *DB) DB() *sql.DB { - db, _ := s.db.(*sql.DB) - return db -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() SQLCommon { - return s.db -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.dialect -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = 2 - } else { - s.logMode = 1 - } - return s -} - -// BlockGlobalUpdate if true, generates an error on update/delete without where clause. -// This is to prevent eventual error with empty objects updates/deletions -func (s *DB) BlockGlobalUpdate(enable bool) *DB { - s.blockGlobalUpdate = enable - return s -} - -// HasBlockGlobalUpdate return state of block -func (s *DB) HasBlockGlobalUpdate() bool { - return s.blockGlobalUpdate -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - modelStructsMap = sync.Map{} - s.parent.singularTable = enable -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} -} - -// QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(scope.SQL, scope.SQLVars...) -} - -// SubQuery returns the query as sub query -func (s *DB) SubQuery() *expr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query interface{}, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Take return a record that match given conditions, the order will depend on the database implementation -func (s *DB) Take(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -//Preloads preloads relations, don`t touch out -func (s *DB) Preloads(out interface{}) *DB { - return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - scope = s.NewScope(result) - clone = scope.db - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := s.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db - } else if len(c.search.assignAttrs) > 0 { - return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.NewScope(value) - if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db - if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().FirstOrCreate(value) - } - return newDB - } - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.NewScope(nil) - generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Begin begin a transaction -func (s *DB) Begin() *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() - c.db = interface{}(tx).(SQLCommon) - - c.dialect.SetDB(c.db) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - var emptySQLTx *sql.Tx - if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { - s.AddError(db.Rollback()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) - return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// RemoveForeignKey Remove foreign key from the given scope, e.g: -// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") -func (s *DB) RemoveForeignKey(field string, dest string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeForeignKey(field, dest) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values.Store(name, value) - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values.Load(name) - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == 0 { - go s.print(fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors(s.GetErrors()) - errors = errors.Add(err) - if len(errors) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() []error { - if errs, ok := s.Error.(Errors); ok { - return errs - } else if s.Error != nil { - return []error{s.Error} - } - return []error{} -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := &DB{ - db: s.db, - parent: s.parent, - logger: s.logger, - logMode: s.logMode, - Value: s.Value, - Error: s.Error, - blockGlobalUpdate: s.blockGlobalUpdate, - dialect: newDialect(s.dialect.GetName(), s.db), - } - - s.values.Range(func(k, v interface{}) bool { - db.values.Store(k, v) - return true - }) - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = db - return db -} - -func (s *DB) print(v ...interface{}) { - s.logger.Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) - } -} diff --git a/vendor/github.com/jinzhu/gorm/model.go b/vendor/github.com/jinzhu/gorm/model.go deleted file mode 100644 index f37ff7eaa..000000000 --- a/vendor/github.com/jinzhu/gorm/model.go +++ /dev/null @@ -1,14 +0,0 @@ -package gorm - -import "time" - -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models -// type User struct { -// gorm.Model -// } -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` -} diff --git a/vendor/github.com/jinzhu/gorm/model_struct.go b/vendor/github.com/jinzhu/gorm/model_struct.go deleted file mode 100644 index 8c27e2097..000000000 --- a/vendor/github.com/jinzhu/gorm/model_struct.go +++ /dev/null @@ -1,640 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "go/ast" - "reflect" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" -) - -// DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { - return defaultTableName -} - -var modelStructsMap sync.Map - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - - defaultTableName string - l sync.Mutex -} - -// TableName returns model's table name -func (s *ModelStruct) TableName(db *DB) string { - s.l.Lock() - defer s.l.Unlock() - - if s.defaultTableName == "" && db != nil && s.ModelType != nil { - // Set default table name - if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { - s.defaultTableName = tabler.TableName() - } else { - tableName := ToTableName(s.ModelType.Name()) - if db == nil || !db.parent.singularTable { - tableName = inflection.Plural(tableName) - } - s.defaultTableName = tableName - } - } - - return DefaultTableNameHandler(db, s.defaultTableName) -} - -// StructField model field's struct definition -type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsNormal bool - IsIgnored bool - IsScanner bool - HasDefaultValue bool - Tag reflect.StructTag - TagSettings map[string]string - Struct reflect.StructField - IsForeignKey bool - Relationship *Relationship - - tagSettingsLock sync.RWMutex -} - -// TagSettingsSet Sets a tag in the tag settings map -func (s *StructField) TagSettingsSet(key, val string) { - s.tagSettingsLock.Lock() - defer s.tagSettingsLock.Unlock() - s.TagSettings[key] = val -} - -// TagSettingsGet returns a tag from the tag settings -func (s *StructField) TagSettingsGet(key string) (string, bool) { - s.tagSettingsLock.RLock() - defer s.tagSettingsLock.RUnlock() - val, ok := s.TagSettings[key] - return val, ok -} - -// TagSettingsDelete deletes a tag -func (s *StructField) TagSettingsDelete(key string) { - s.tagSettingsLock.Lock() - defer s.tagSettingsLock.Unlock() - delete(s.TagSettings, key) -} - -func (structField *StructField) clone() *StructField { - clone := &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - IsScanner: structField.IsScanner, - HasDefaultValue: structField.HasDefaultValue, - Tag: structField.Tag, - TagSettings: map[string]string{}, - Struct: structField.Struct, - IsForeignKey: structField.IsForeignKey, - } - - if structField.Relationship != nil { - relationship := *structField.Relationship - clone.Relationship = &relationship - } - - // copy the struct field tagSettings, they should be read-locked while they are copied - structField.tagSettingsLock.Lock() - defer structField.tagSettingsLock.Unlock() - for key, value := range structField.TagSettings { - clone.TagSettings[key] = value - } - - return clone -} - -// Relationship described the relationship between models -type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - PolymorphicValue string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface -} - -func getForeignField(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { - return field - } - } - return nil -} - -// GetModelStruct get value's model struct, relationships based on struct and tag definition -func (scope *Scope) GetModelStruct() *ModelStruct { - var modelStruct ModelStruct - // Scope value can't be nil - if scope.Value == nil { - return &modelStruct - } - - reflectType := reflect.ValueOf(scope.Value).Type() - for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Scope value need to be a struct - if reflectType.Kind() != reflect.Struct { - return &modelStruct - } - - // Get Cached model struct - if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { - return value.(*ModelStruct) - } - - modelStruct.ModelType = reflectType - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if _, ok := field.TagSettingsGet("-"); ok { - field.IsIgnored = true - } else { - if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettingsGet("DEFAULT"); ok { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i < indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettingsGet(key); !ok { - field.TagSettingsSet(key, value) - } - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) - if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { - subField.DBName = prefix + subField.DBName - } - - if subField.IsPrimaryKey { - if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } else { - subField.IsPrimaryKey = false - } - } - - if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { - if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { - newJoinTableHandler := &JoinTableHandler{} - newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) - subField.Relationship.JoinTableHandler = newJoinTableHandler - } - } - - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - foreignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - associationForeignKeys = strings.Split(foreignKey, ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { - relationship.Kind = "many_to_many" - - { // Foreign Keys for Source - joinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { - joinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - - // setup join table foreign keys for source - if len(joinTableDBNames) > idx { - // if defined join table's foreign key - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) - } else { - defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) - } - } - } - } - - { // Foreign Keys for Association (Destination) - associationJoinTableDBNames := []string{} - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { - associationJoinTableDBNames = strings.Split(foreignKey, ",") - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for idx, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // setup join table foreign keys for association - if len(associationJoinTableDBNames) > idx { - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) - } else { - // join table foreign keys for association - joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType) - relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys - foreignField.IsForeignKey = true - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { - tagForeignKeys = strings.Split(foreignKey, ",") - } - - if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { - tagAssociationForeignKeys = strings.Split(foreignKey, ",") - } - - if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - foreignField.IsForeignKey = true - // source foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { - foreignField.IsForeignKey = true - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" - field.Relationship = relationship - } - } - }(field) - default: - field.IsNormal = true - } - } - } - - // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettingsGet("COLUMN"); ok { - field.DBName = value - } else { - field.DBName = ToColumnName(fieldStruct.Name) - } - - modelStruct.StructFields = append(modelStruct.StructFields, field) - } - } - - if len(modelStruct.PrimaryFields) == 0 { - if field := getForeignField("id", modelStruct.StructFields); field != nil { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } - - modelStructsMap.Store(reflectType, &modelStruct) - - return &modelStruct -} - -// GetStructFields get model's field structs -func (scope *Scope) GetStructFields() (fields []*StructField) { - return scope.GetModelStruct().StructFields -} - -func parseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} - for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k - } - } - } - return setting -} diff --git a/vendor/github.com/jinzhu/gorm/naming.go b/vendor/github.com/jinzhu/gorm/naming.go deleted file mode 100644 index 6b0a4fddb..000000000 --- a/vendor/github.com/jinzhu/gorm/naming.go +++ /dev/null @@ -1,124 +0,0 @@ -package gorm - -import ( - "bytes" - "strings" -) - -// Namer is a function type which is given a string and return a string -type Namer func(string) string - -// NamingStrategy represents naming strategies -type NamingStrategy struct { - DB Namer - Table Namer - Column Namer -} - -// TheNamingStrategy is being initialized with defaultNamingStrategy -var TheNamingStrategy = &NamingStrategy{ - DB: defaultNamer, - Table: defaultNamer, - Column: defaultNamer, -} - -// AddNamingStrategy sets the naming strategy -func AddNamingStrategy(ns *NamingStrategy) { - if ns.DB == nil { - ns.DB = defaultNamer - } - if ns.Table == nil { - ns.Table = defaultNamer - } - if ns.Column == nil { - ns.Column = defaultNamer - } - TheNamingStrategy = ns -} - -// DBName alters the given name by DB -func (ns *NamingStrategy) DBName(name string) string { - return ns.DB(name) -} - -// TableName alters the given name by Table -func (ns *NamingStrategy) TableName(name string) string { - return ns.Table(name) -} - -// ColumnName alters the given name by Column -func (ns *NamingStrategy) ColumnName(name string) string { - return ns.Column(name) -} - -// ToDBName convert string to db name -func ToDBName(name string) string { - return TheNamingStrategy.DBName(name) -} - -// ToTableName convert string to table name -func ToTableName(name string) string { - return TheNamingStrategy.TableName(name) -} - -// ToColumnName convert string to db name -func ToColumnName(name string) string { - return TheNamingStrategy.ColumnName(name) -} - -var smap = newSafeMap() - -func defaultNamer(name string) string { - const ( - lower = false - upper = true - ) - - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase, nextNumber bool - ) - - for i, v := range value[:len(value)-1] { - nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') - nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') - - if i > 0 { - if currCase == upper { - if lastCase == upper && (nextCase == upper || nextNumber == upper) { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} diff --git a/vendor/github.com/jinzhu/gorm/scope.go b/vendor/github.com/jinzhu/gorm/scope.go deleted file mode 100644 index 806ccb7d8..000000000 --- a/vendor/github.com/jinzhu/gorm/scope.go +++ /dev/null @@ -1,1397 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "reflect" - "regexp" - "strings" - "time" -) - -// Scope contain current operation's information when you perform any operation on the database -type Scope struct { - Search *search - Value interface{} - SQL string - SQLVars []interface{} - db *DB - instanceID string - primaryKeyField *Field - skipLeft bool - fields *[]*Field - selectAttrs *[]string -} - -// IndirectValue return scope's reflect value's indirect value -func (scope *Scope) IndirectValue() reflect.Value { - return indirect(reflect.ValueOf(scope.Value)) -} - -// New create a new Scope without search information -func (scope *Scope) New(value interface{}) *Scope { - return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} -} - -//////////////////////////////////////////////////////////////////////////////// -// Scope DB -//////////////////////////////////////////////////////////////////////////////// - -// DB return scope's DB connection -func (scope *Scope) DB() *DB { - return scope.db -} - -// NewDB create a new DB without search information -func (scope *Scope) NewDB() *DB { - if scope.db != nil { - db := scope.db.clone() - db.search = nil - db.Value = nil - return db - } - return nil -} - -// SQLDB return *sql.DB -func (scope *Scope) SQLDB() SQLCommon { - return scope.db.db -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.dialect -} - -// Quote used to quote string to escape them for database -func (scope *Scope) Quote(str string) string { - if strings.Contains(str, ".") { - newStrs := []string{} - for _, str := range strings.Split(str, ".") { - newStrs = append(newStrs, scope.Dialect().Quote(str)) - } - return strings.Join(newStrs, ".") - } - - return scope.Dialect().Quote(str) -} - -// Err add error to Scope -func (scope *Scope) Err(err error) error { - if err != nil { - scope.db.AddError(err) - } - return err -} - -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil -} - -// Log print log message -func (scope *Scope) Log(v ...interface{}) { - scope.db.log(v...) -} - -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - if scope.fields == nil { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - scope.fields = &fields - } - - return *scope.fields -} - -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToColumnName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - -// PrimaryFields return scope's primary fields -func (scope *Scope) PrimaryFields() (fields []*Field) { - for _, field := range scope.Fields() { - if field.IsPrimaryKey { - fields = append(fields, field) - } - } - return fields -} - -// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one -func (scope *Scope) PrimaryField() *Field { - if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - if len(primaryFields) > 1 { - if field, ok := scope.FieldByName("id"); ok { - return field - } - } - return scope.PrimaryFields()[0] - } - return nil -} - -// PrimaryKey get main primary field's db name -func (scope *Scope) PrimaryKey() string { - if field := scope.PrimaryField(); field != nil { - return field.DBName - } - return "" -} - -// PrimaryKeyZero check main primary field's value is blank or not -func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryField() - return field == nil || field.IsBlank -} - -// PrimaryKeyValue get the primary key's value -func (scope *Scope) PrimaryKeyValue() interface{} { - if field := scope.PrimaryField(); field != nil && field.Field.IsValid() { - return field.Field.Interface() - } - return 0 -} - -// HasColumn to check if has column -func (scope *Scope) HasColumn(column string) bool { - for _, field := range scope.GetStructFields() { - if field.IsNormal && (field.Name == column || field.DBName == column) { - return true - } - } - return false -} - -// SetColumn to set the column's value, column could be field or field's name/dbname -func (scope *Scope) SetColumn(column interface{}, value interface{}) error { - var updateAttrs = map[string]interface{}{} - if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - updateAttrs = attrs.(map[string]interface{}) - defer scope.InstanceSet("gorm:update_attrs", updateAttrs) - } - - if field, ok := column.(*Field); ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - _, skipBindVar := scope.InstanceGet("skip_bindvar") - - if expr, ok := value.(*expr); ok { - exp := expr.expr - for _, arg := range expr.args { - if skipBindVar { - scope.AddToVars(arg) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - - if skipBindVar { - return "?" - } - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) - } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Contains(scope.Search.tableName, " ") { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - joinSQL := scope.joinsSQL() - whereSQL := scope.whereSQL() - if scope.Search.raw { - whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") - } - return joinSQL + whereSQL + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { - db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var ( - columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` - isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") - countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") -) - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - values = make([]interface{}, len(columns)) - selectFields []*Field - selectedColumnsMap = map[string]int{} - resetFields = map[int]*Field{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - offset := 0 - if idx, ok := selectedColumnsMap[column]; ok { - offset = idx + 1 - selectFields = selectFields[offset:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[index] = field - } - - selectedColumnsMap[column] = offset + fieldIndex - - if field.IsNormal { - break - } - } - } - } - - scope.Err(rows.Scan(values...)) - - for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { - var ( - quotedTableName = scope.QuotedTableName() - quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) - equalSQL = "=" - inSQL = "IN" - ) - - // If building not conditions - if !include { - equalSQL = "<>" - inSQL = "NOT IN" - } - - switch value := clause["query"].(type) { - case sql.NullInt64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - if !include && reflect.ValueOf(value).Len() == 0 { - return - } - str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) - clause["args"] = []interface{}{value} - case string: - if isNumberRegexp.MatchString(value) { - return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) - } - - if value != "" { - if !include { - if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) - } - } else { - str = fmt.Sprintf("(%v)", value) - } - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) - } else { - if !include { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) - } - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - newScope := scope.New(value) - - if len(newScope.Fields()) == 0 { - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - scopeQuotedTableName := newScope.QuotedTableName() - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - default: - scope.Err(fmt.Errorf("invalid query condition: %v", value)) - return - } - - replacements := []string{} - args := clause["args"].([]interface{}) - for _, arg := range args { - var err error - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = scanner.Value() - replacements = append(replacements, scope.AddToVars(arg)) - } else if b, ok := arg.([]byte); ok { - replacements = append(replacements, scope.AddToVars(b)) - } else if as, ok := arg.([][]interface{}); ok { - var tempMarks []string - for _, a := range as { - var arrayMarks []string - for _, v := range a { - arrayMarks = append(arrayMarks, scope.AddToVars(v)) - } - - if len(arrayMarks) > 0 { - tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) - } - } - - if len(tempMarks) > 0 { - replacements = append(replacements, strings.Join(tempMarks, ",")) - } - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - } else { - replacements = append(replacements, scope.AddToVars(Expr("NULL"))) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, err = valuer.Value() - } - - replacements = append(replacements, scope.AddToVars(arg)) - } - - if err != nil { - scope.Err(err) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for _, s := range str { - if s == '?' && len(replacements) > i { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(s) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - replacements := []string{} - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - replacements = append(replacements, strings.Join(tempMarks, ",")) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - replacements = append(replacements, scope.AddToVars(arg)) - } - } - - buff := bytes.NewBuffer([]byte{}) - i := 0 - for pos, char := range str { - if str[pos] == '?' { - buff.WriteString(replacements[i]) - i++ - } else { - buff.WriteRune(char) - } - } - - str = buff.String() - - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && hasDeletedAtField { - sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildCondition(clause, false); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { - return "" - } - - var orders []string - for _, order := range scope.Search.orders { - if str, ok := order.(string); ok { - orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*expr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - orders = append(orders, exp) - } - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildCondition(clause, true); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(scope.CombinedConditionSql()) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) - } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - defer func() { - if err := recover(); err != nil { - if db, ok := scope.db.db.(sqlTx); ok { - db.Rollback() - } - panic(err) - } - }() - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values}).Fields() { - if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value, true) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal && !field.IsIgnored { - hasUpdate = true - if err == ErrUnaddressable { - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - - result := &RowQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Row -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - - result := &RowsQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Rows, result.Error -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { - queryStr := strings.ToLower(fmt.Sprint(query)) - if queryStr == column { - return true - } - - if strings.HasSuffix(queryStr, "as "+column) { - return true - } - - if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { - return true - } - - return false -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { - scope.Search.Select(column) - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - if len(scope.Search.group) != 0 { - scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") - scope.Search.group += " ) AS count_table" - } else { - scope.Search.Select("count(*)") - } - } - scope.Search.ignoreOrderQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - tx := scope.db.Set("gorm:association:source", scope.Value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(tx.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - scope.Err(tx.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return " " + tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") - foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. - if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v%s", scope.QuotedTableName(), scope.getTableOptions())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - // Compatible with old generated key - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeForeignKey(field string, dest string) { - keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var mysql mysql - var query string - if scope.Dialect().GetName() == mysql.GetName() { - query = `ALTER TABLE %s DROP FOREIGN KEY %s;` - } else { - query = `ALTER TABLE %s DROP CONSTRAINT %s;` - } - - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettingsGet("INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) - } - indexes[name] = append(indexes[name], field.DBName) - } - } - - if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) - } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) - } - } - } - - for name, columns := range indexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - for name, columns := range uniqueIndexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { - scope.db.AddError(db.Error) - } - } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - for _, value := range values { - indirectValue := indirect(reflect.ValueOf(value)) - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - var hasValue = false - for _, column := range columns { - field := object.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - results = append(results, result) - } - } - case reflect.Struct: - var result []interface{} - var hasValue = false - for _, column := range columns { - field := indirectValue.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - results = append(results, result) - } - } - } - - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - resultsMap := map[interface{}]bool{} - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { - resultsMap[elem.Addr()] = true - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() && resultsMap[result.Addr()] != true { - resultsMap[result.Addr()] = true - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} - -func (scope *Scope) hasConditions() bool { - return !scope.PrimaryKeyZero() || - len(scope.Search.whereConditions) > 0 || - len(scope.Search.orConditions) > 0 || - len(scope.Search.notConditions) > 0 -} diff --git a/vendor/github.com/jinzhu/gorm/search.go b/vendor/github.com/jinzhu/gorm/search.go deleted file mode 100644 index 901385956..000000000 --- a/vendor/github.com/jinzhu/gorm/search.go +++ /dev/null @@ -1,153 +0,0 @@ -package gorm - -import ( - "fmt" -) - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []interface{} - preload []searchPreload - offset interface{} - limit interface{} - group string - tableName string - raw bool - Unscoped bool - ignoreOrderQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := *s - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value interface{}, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - s.orders = []interface{}{} - } - - if value != nil && value != "" { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit interface{}) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset interface{}) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*expr); ok { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) - } else { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - } - return s -} - -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/vendor/github.com/jinzhu/gorm/utils.go b/vendor/github.com/jinzhu/gorm/utils.go deleted file mode 100644 index e58e57a56..000000000 --- a/vendor/github.com/jinzhu/gorm/utils.go +++ /dev/null @@ -1,226 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -// SQL expression -type expr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - switch value.Kind() { - case reflect.String: - return value.Len() == 0 - case reflect.Bool: - return !value.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return value.Uint() == 0 - case reflect.Float32, reflect.Float64: - return value.Float() == 0 - case reflect.Interface, reflect.Ptr: - return value.IsNil() - } - - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/vendor/github.com/jinzhu/inflection/LICENSE b/vendor/github.com/jinzhu/inflection/LICENSE deleted file mode 100644 index a1ca9a0ff..000000000 --- a/vendor/github.com/jinzhu/inflection/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2015 - Jinzhu - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/jinzhu/inflection/inflections.go b/vendor/github.com/jinzhu/inflection/inflections.go deleted file mode 100644 index 606263bb7..000000000 --- a/vendor/github.com/jinzhu/inflection/inflections.go +++ /dev/null @@ -1,273 +0,0 @@ -/* -Package inflection pluralizes and singularizes English nouns. - - inflection.Plural("person") => "people" - inflection.Plural("Person") => "People" - inflection.Plural("PERSON") => "PEOPLE" - - inflection.Singular("people") => "person" - inflection.Singular("People") => "Person" - inflection.Singular("PEOPLE") => "PERSON" - - inflection.Plural("FancyPerson") => "FancydPeople" - inflection.Singular("FancyPeople") => "FancydPerson" - -Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb) - -If you want to register more rules, follow: - - inflection.AddUncountable("fish") - inflection.AddIrregular("person", "people") - inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses" - inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS" -*/ -package inflection - -import ( - "regexp" - "strings" -) - -type inflection struct { - regexp *regexp.Regexp - replace string -} - -// Regular is a regexp find replace inflection -type Regular struct { - find string - replace string -} - -// Irregular is a hard replace inflection, -// containing both singular and plural forms -type Irregular struct { - singular string - plural string -} - -// RegularSlice is a slice of Regular inflections -type RegularSlice []Regular - -// IrregularSlice is a slice of Irregular inflections -type IrregularSlice []Irregular - -var pluralInflections = RegularSlice{ - {"([a-z])$", "${1}s"}, - {"s$", "s"}, - {"^(ax|test)is$", "${1}es"}, - {"(octop|vir)us$", "${1}i"}, - {"(octop|vir)i$", "${1}i"}, - {"(alias|status)$", "${1}es"}, - {"(bu)s$", "${1}ses"}, - {"(buffal|tomat)o$", "${1}oes"}, - {"([ti])um$", "${1}a"}, - {"([ti])a$", "${1}a"}, - {"sis$", "ses"}, - {"(?:([^f])fe|([lr])f)$", "${1}${2}ves"}, - {"(hive)$", "${1}s"}, - {"([^aeiouy]|qu)y$", "${1}ies"}, - {"(x|ch|ss|sh)$", "${1}es"}, - {"(matr|vert|ind)(?:ix|ex)$", "${1}ices"}, - {"^(m|l)ouse$", "${1}ice"}, - {"^(m|l)ice$", "${1}ice"}, - {"^(ox)$", "${1}en"}, - {"^(oxen)$", "${1}"}, - {"(quiz)$", "${1}zes"}, -} - -var singularInflections = RegularSlice{ - {"s$", ""}, - {"(ss)$", "${1}"}, - {"(n)ews$", "${1}ews"}, - {"([ti])a$", "${1}um"}, - {"((a)naly|(b)a|(d)iagno|(p)arenthe|(p)rogno|(s)ynop|(t)he)(sis|ses)$", "${1}sis"}, - {"(^analy)(sis|ses)$", "${1}sis"}, - {"([^f])ves$", "${1}fe"}, - {"(hive)s$", "${1}"}, - {"(tive)s$", "${1}"}, - {"([lr])ves$", "${1}f"}, - {"([^aeiouy]|qu)ies$", "${1}y"}, - {"(s)eries$", "${1}eries"}, - {"(m)ovies$", "${1}ovie"}, - {"(c)ookies$", "${1}ookie"}, - {"(x|ch|ss|sh)es$", "${1}"}, - {"^(m|l)ice$", "${1}ouse"}, - {"(bus)(es)?$", "${1}"}, - {"(o)es$", "${1}"}, - {"(shoe)s$", "${1}"}, - {"(cris|test)(is|es)$", "${1}is"}, - {"^(a)x[ie]s$", "${1}xis"}, - {"(octop|vir)(us|i)$", "${1}us"}, - {"(alias|status)(es)?$", "${1}"}, - {"^(ox)en", "${1}"}, - {"(vert|ind)ices$", "${1}ex"}, - {"(matr)ices$", "${1}ix"}, - {"(quiz)zes$", "${1}"}, - {"(database)s$", "${1}"}, -} - -var irregularInflections = IrregularSlice{ - {"person", "people"}, - {"man", "men"}, - {"child", "children"}, - {"sex", "sexes"}, - {"move", "moves"}, - {"mombie", "mombies"}, -} - -var uncountableInflections = []string{"equipment", "information", "rice", "money", "species", "series", "fish", "sheep", "jeans", "police"} - -var compiledPluralMaps []inflection -var compiledSingularMaps []inflection - -func compile() { - compiledPluralMaps = []inflection{} - compiledSingularMaps = []inflection{} - for _, uncountable := range uncountableInflections { - inf := inflection{ - regexp: regexp.MustCompile("^(?i)(" + uncountable + ")$"), - replace: "${1}", - } - compiledPluralMaps = append(compiledPluralMaps, inf) - compiledSingularMaps = append(compiledSingularMaps, inf) - } - - for _, value := range irregularInflections { - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.singular) + "$"), replace: strings.ToUpper(value.plural)}, - inflection{regexp: regexp.MustCompile(strings.Title(value.singular) + "$"), replace: strings.Title(value.plural)}, - inflection{regexp: regexp.MustCompile(value.singular + "$"), replace: value.plural}, - } - compiledPluralMaps = append(compiledPluralMaps, infs...) - } - - for _, value := range irregularInflections { - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.plural) + "$"), replace: strings.ToUpper(value.singular)}, - inflection{regexp: regexp.MustCompile(strings.Title(value.plural) + "$"), replace: strings.Title(value.singular)}, - inflection{regexp: regexp.MustCompile(value.plural + "$"), replace: value.singular}, - } - compiledSingularMaps = append(compiledSingularMaps, infs...) - } - - for i := len(pluralInflections) - 1; i >= 0; i-- { - value := pluralInflections[i] - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)}, - inflection{regexp: regexp.MustCompile(value.find), replace: value.replace}, - inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace}, - } - compiledPluralMaps = append(compiledPluralMaps, infs...) - } - - for i := len(singularInflections) - 1; i >= 0; i-- { - value := singularInflections[i] - infs := []inflection{ - inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)}, - inflection{regexp: regexp.MustCompile(value.find), replace: value.replace}, - inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace}, - } - compiledSingularMaps = append(compiledSingularMaps, infs...) - } -} - -func init() { - compile() -} - -// AddPlural adds a plural inflection -func AddPlural(find, replace string) { - pluralInflections = append(pluralInflections, Regular{find, replace}) - compile() -} - -// AddSingular adds a singular inflection -func AddSingular(find, replace string) { - singularInflections = append(singularInflections, Regular{find, replace}) - compile() -} - -// AddIrregular adds an irregular inflection -func AddIrregular(singular, plural string) { - irregularInflections = append(irregularInflections, Irregular{singular, plural}) - compile() -} - -// AddUncountable adds an uncountable inflection -func AddUncountable(values ...string) { - uncountableInflections = append(uncountableInflections, values...) - compile() -} - -// GetPlural retrieves the plural inflection values -func GetPlural() RegularSlice { - plurals := make(RegularSlice, len(pluralInflections)) - copy(plurals, pluralInflections) - return plurals -} - -// GetSingular retrieves the singular inflection values -func GetSingular() RegularSlice { - singulars := make(RegularSlice, len(singularInflections)) - copy(singulars, singularInflections) - return singulars -} - -// GetIrregular retrieves the irregular inflection values -func GetIrregular() IrregularSlice { - irregular := make(IrregularSlice, len(irregularInflections)) - copy(irregular, irregularInflections) - return irregular -} - -// GetUncountable retrieves the uncountable inflection values -func GetUncountable() []string { - uncountables := make([]string, len(uncountableInflections)) - copy(uncountables, uncountableInflections) - return uncountables -} - -// SetPlural sets the plural inflections slice -func SetPlural(inflections RegularSlice) { - pluralInflections = inflections - compile() -} - -// SetSingular sets the singular inflections slice -func SetSingular(inflections RegularSlice) { - singularInflections = inflections - compile() -} - -// SetIrregular sets the irregular inflections slice -func SetIrregular(inflections IrregularSlice) { - irregularInflections = inflections - compile() -} - -// SetUncountable sets the uncountable inflections slice -func SetUncountable(inflections []string) { - uncountableInflections = inflections - compile() -} - -// Plural converts a word to its plural form -func Plural(str string) string { - for _, inflection := range compiledPluralMaps { - if inflection.regexp.MatchString(str) { - return inflection.regexp.ReplaceAllString(str, inflection.replace) - } - } - return str -} - -// Singular converts a word to its singular form -func Singular(str string) string { - for _, inflection := range compiledSingularMaps { - if inflection.regexp.MatchString(str) { - return inflection.regexp.ReplaceAllString(str, inflection.replace) - } - } - return str -}