Remove vendored dependencies (#4110)

This commit is contained in:
its-josh4
2023-09-11 17:36:48 -07:00
committed by GitHub
parent b36aa745d8
commit 4a9fdc8b55
2675 changed files with 20 additions and 1259580 deletions

View File

@@ -33,7 +33,7 @@ jobs:
run: docker exec -t build /bin/bash -c "make generate-backend"
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v2
uses: golangci/golangci-lint-action@v3
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: latest
@@ -42,19 +42,26 @@ jobs:
# working-directory: somedir
# Optional: golangci-lint command line arguments.
args: --modules-download-mode=vendor --timeout=5m
#
# Note: By default, the `.golangci.yml` file should be at the root of the repository.
# The location of the configuration file can be changed by using `--config=`
args: --timeout=5m
# Optional: show only new issues if it's a pull request. The default value is `false`.
# only-new-issues: true
# Optional: if set to true then the action will use pre-installed Go.
# skip-go-installation: true
# Optional: if set to true, then all caching functionality will be completely disabled,
# takes precedence over all other caching options.
# skip-cache: true
# Optional: if set to true then the action don't cache or restore ~/go/pkg.
skip-pkg-cache: true
# Optional: if set to true, then the action won't cache or restore ~/go/pkg.
# skip-pkg-cache: true
# Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
skip-build-cache: true
# Optional: if set to true, then the action won't cache or restore ~/.cache/go-build.
# skip-build-cache: true
# Optional: The mode to install golangci-lint. It can be 'binary' or 'goinstall'.
# install-mode: "goinstall"
- name: Cleanup build container
run: docker rm -f -v build

3
.gitignore vendored
View File

@@ -2,6 +2,9 @@
# Go
####
# Vendored dependencies
vendor
# Binaries for programs and plugins
*.exe
*.exe~

View File

@@ -1,7 +1,6 @@
# options for analysis running
run:
timeout: 5m
modules-download-mode: vendor
linters:
disable-all: true

View File

@@ -1,4 +1,4 @@
//go:generate go run -mod=vendor github.com/99designs/gqlgen
//go:generate go run github.com/99designs/gqlgen
package main
import (

View File

@@ -21,7 +21,6 @@ RUN apk add --no-cache make alpine-sdk
WORKDIR /stash
COPY ./go* ./*.go Makefile gqlgen.yml .gqlgenc.yml /stash/
COPY ./scripts /stash/scripts/
COPY ./vendor /stash/vendor/
COPY ./pkg /stash/pkg/
COPY ./cmd /stash/cmd
COPY ./internal /stash/internal

View File

@@ -21,7 +21,6 @@ RUN apt update && apt install -y build-essential golang
WORKDIR /stash
COPY ./go* ./*.go Makefile gqlgen.yml .gqlgenc.yml /stash/
COPY ./scripts /stash/scripts/
COPY ./vendor /stash/vendor/
COPY ./pkg /stash/pkg/
COPY ./cmd /stash/cmd
COPY ./internal /stash/internal

View File

@@ -2,7 +2,7 @@
COMPILER_CONTAINER="stashapp/compiler:7"
BUILD_DATE=`go run -mod=vendor scripts/getDate.go`
BUILD_DATE=`go run scripts/getDate.go`
GITHASH=`git rev-parse --short HEAD`
STASH_VERSION=`git describe --tags --exclude latest_develop`

View File

@@ -1,3 +0,0 @@
/**/node_modules
/codegen/tests/gen
/vendor

View File

@@ -1,20 +0,0 @@
root = true
[*]
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
indent_style = space
indent_size = 4
[*.{go,gotpl}]
indent_style = tab
[*.yml]
indent_size = 2
# These often end up with go code inside, so lets keep tabs
[*.{html,md}]
indent_size = 2
indent_style = tab

View File

@@ -1,3 +0,0 @@
/codegen/templates/data.go linguist-generated
/_examples/dataloader/*_gen.go linguist-generated
generated.go linguist-generated

View File

@@ -1,16 +0,0 @@
/vendor
/docs/public
/docs/.hugo_build.lock
/_examples/chat/node_modules
/integration/node_modules
/integration/schema-fetched.graphql
/_examples/chat/package-lock.json
/_examples/federation/package-lock.json
/_examples/federation/node_modules
/codegen/gen
/gen
/.vscode
.idea/
*.test
*.out

View File

@@ -1,39 +0,0 @@
run:
tests: true
skip-dirs:
- bin
linters-settings:
errcheck:
ignore: fmt:.*,[rR]ead|[wW]rite|[cC]lose,io:Copy
linters:
disable-all: true
enable:
- bodyclose
- deadcode
- depguard
- dupl
- errcheck
- gocritic
- gofmt
- goimports
- gosimple
- govet
- ineffassign
- misspell
- nakedret
- prealloc
- staticcheck
- structcheck
- typecheck
- unconvert
- unused
- varcheck
issues:
exclude-rules:
# Exclude some linters from running on tests files.
- path: _test\.go
linters:
- dupl

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +0,0 @@
# Contribution Guidelines
Want to contribute to gqlgen? Here are some guidelines for how we accept help.
## Getting in Touch
Our [discord](https://discord.gg/DYEq3EMs4U) server is the best place to ask questions or get advice on using gqlgen.
## Reporting Bugs and Issues
We use [GitHub Issues](https://github.com/99designs/gqlgen/issues) to track bugs, so please do a search before submitting to ensure your problem isn't already tracked.
### New Issues
Please provide the expected and observed behaviours in your issue. A minimal GraphQL schema or configuration file should be provided where appropriate.
## Proposing a Change
If you intend to implement a feature for gqlgen, or make a non-trivial change to the current implementation, we recommend [first filing an issue](https://github.com/99designs/gqlgen/issues/new) marked with the `proposal` tag, so that the engineering team can provide guidance and feedback on the direction of an implementation. This also help ensure that other people aren't also working on the same thing.
Bug fixes are welcome and should come with appropriate test coverage.
New features should be made against the `next` branch.
### License
By contributing to gqlgen, you agree that your contributions will be licensed under its MIT license.

View File

@@ -1,19 +0,0 @@
Copyright (c) 2020 gqlgen authors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,150 +0,0 @@
![gqlgen](https://user-images.githubusercontent.com/980499/133180111-d064b38c-6eb9-444b-a60f-7005a6e68222.png)
# gqlgen [![Integration](https://github.com/99designs/gqlgen/actions/workflows/integration.yml/badge.svg)](https://github.com/99designs/gqlgen/actions) [![Coverage Status](https://coveralls.io/repos/github/99designs/gqlgen/badge.svg?branch=master)](https://coveralls.io/github/99designs/gqlgen?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/99designs/gqlgen)](https://goreportcard.com/report/github.com/99designs/gqlgen) [![Go Reference](https://pkg.go.dev/badge/github.com/99designs/gqlgen.svg)](https://pkg.go.dev/github.com/99designs/gqlgen) [![Read the Docs](https://badgen.net/badge/docs/available/green)](http://gqlgen.com/)
## What is gqlgen?
[gqlgen](https://github.com/99designs/gqlgen) is a Go library for building GraphQL servers without any fuss.<br/>
- **gqlgen is based on a Schema first approach** — You get to Define your API using the GraphQL [Schema Definition Language](http://graphql.org/learn/schema/).
- **gqlgen prioritizes Type safety** — You should never see `map[string]interface{}` here.
- **gqlgen enables Codegen** — We generate the boring bits, so you can focus on building your app quickly.
Still not convinced enough to use **gqlgen**? Compare **gqlgen** with other Go graphql [implementations](https://gqlgen.com/feature-comparison/)
## Quick start
1. [Initialise a new go module](https://golang.org/doc/tutorial/create-module)
mkdir example
cd example
go mod init example
2. Add `github.com/99designs/gqlgen` to your [project's tools.go](https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module)
printf '// +build tools\npackage tools\nimport _ "github.com/99designs/gqlgen"' | gofmt > tools.go
go mod tidy
3. Initialise gqlgen config and generate models
go run github.com/99designs/gqlgen init
4. Start the graphql server
go run server.go
More help to get started:
- [Getting started tutorial](https://gqlgen.com/getting-started/) - a comprehensive guide to help you get started
- [Real-world examples](https://github.com/99designs/gqlgen/tree/master/_examples) show how to create GraphQL applications
- [Reference docs](https://pkg.go.dev/github.com/99designs/gqlgen) for the APIs
## Reporting Issues
If you think you've found a bug, or something isn't behaving the way you think it should, please raise an [issue](https://github.com/99designs/gqlgen/issues) on GitHub.
## Contributing
We welcome contributions, Read our [Contribution Guidelines](https://github.com/99designs/gqlgen/blob/master/CONTRIBUTING.md) to learn more about contributing to **gqlgen**
## Frequently asked questions
### How do I prevent fetching child objects that might not be used?
When you have nested or recursive schema like this:
```graphql
type User {
id: ID!
name: String!
friends: [User!]!
}
```
You need to tell gqlgen that it should only fetch friends if the user requested it. There are two ways to do this;
- #### Using Custom Models
Write a custom model that omits the friends field:
```go
type User struct {
ID int
Name string
}
```
And reference the model in `gqlgen.yml`:
```yaml
# gqlgen.yml
models:
User:
model: github.com/you/pkg/model.User # go import path to the User struct above
```
- #### Using Explicit Resolvers
If you want to Keep using the generated model, mark the field as requiring a resolver explicitly in `gqlgen.yml` like this:
```yaml
# gqlgen.yml
models:
User:
fields:
friends:
resolver: true # force a resolver to be generated
```
After doing either of the above and running generate we will need to provide a resolver for friends:
```go
func (r *userResolver) Friends(ctx context.Context, obj *User) ([]*User, error) {
// select * from user where friendid = obj.ID
return friends, nil
}
```
You can also use inline config with directives to achieve the same result
```graphql
directive @goModel(model: String, models: [String!]) on OBJECT
| INPUT_OBJECT
| SCALAR
| ENUM
| INTERFACE
| UNION
directive @goField(forceResolver: Boolean, name: String) on INPUT_FIELD_DEFINITION
| FIELD_DEFINITION
type User @goModel(model: "github.com/you/pkg/model.User") {
id: ID! @goField(name: "todoId")
friends: [User!]! @goField(forceResolver: true)
}
```
### Can I change the type of the ID from type String to Type Int?
Yes! You can by remapping it in config as seen below:
```yaml
models:
ID: # The GraphQL type ID is backed by
model:
- github.com/99designs/gqlgen/graphql.IntID # a go integer
- github.com/99designs/gqlgen/graphql.ID # or a go string
```
This means gqlgen will be able to automatically bind to strings or ints for models you have written yourself, but the
first model in this list is used as the default type and it will always be used when:
- Generating models based on schema
- As arguments in resolvers
There isn't any way around this, gqlgen has no way to know what you want in a given context.
## Other Resources
- [Christopher Biscardi @ Gophercon UK 2018](https://youtu.be/FdURVezcdcw)
- [Introducing gqlgen: a GraphQL Server Generator for Go](https://99designs.com.au/blog/engineering/gqlgen-a-graphql-server-generator-for-go/)
- [Dive into GraphQL by Iván Corrales Solera](https://medium.com/@ivan.corrales.solera/dive-into-graphql-9bfedf22e1a)
- [Sample Project built on gqlgen with Postgres by Oleg Shalygin](https://github.com/oshalygin/gqlgen-pg-todo-example)

View File

@@ -1,14 +0,0 @@
# When gqlgen gets released, the following things need to happen
Assuming the next version is $NEW_VERSION=v0.16.0 or something like that.
1. Run the https://github.com/99designs/gqlgen/blob/master/bin/release:
```
./bin/release $NEW_VERSION
```
2. git-chglog -o CHANGELOG.md
3. git commit and push the CHANGELOG.md
4. Go to https://github.com/99designs/gqlgen/releases and draft new release, autogenerate the release notes, and Create a discussion for this release
5. Comment on the release discussion with any really important notes (breaking changes)
I used https://github.com/git-chglog/git-chglog to automate the changelog maintenance process for now. We could just as easily use go releaser to make the whole thing automated.

View File

@@ -1,40 +0,0 @@
How to write tests for gqlgen
===
Testing generated code is a little tricky, heres how its currently set up.
### Testing responses from a server
There is a server in `codegen/testserver` that is generated as part
of `go generate ./...`, and tests written against it.
There are also a bunch of tests in against the examples, feel free to take examples from there.
### Testing the errors generated by the binary
These tests are **really** slow, because they need to run the whole codegen step. Use them very sparingly. If you can, find a way to unit test it instead.
Take a look at `codegen/testserver/input_test.go` for an example.
### Testing introspection
Introspection is tested by diffing the output of `graphql get-schema` against an expected output.
Setting up the integration environment is a little tricky:
```bash
cd integration
go generate ./...
go run ./server/server.go
```
in another terminal
```bash
cd integration
npm install
./node_modules/.bin/graphql-codegen
```
will write the schema to `integration/schema-fetched.graphql`, compare that with `schema-expected.graphql`
CI will run this and fail the build if the two files dont match.

View File

@@ -1,125 +0,0 @@
package api
import (
"fmt"
"syscall"
"github.com/99designs/gqlgen/codegen"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/plugin"
"github.com/99designs/gqlgen/plugin/federation"
"github.com/99designs/gqlgen/plugin/modelgen"
"github.com/99designs/gqlgen/plugin/resolvergen"
)
func Generate(cfg *config.Config, option ...Option) error {
_ = syscall.Unlink(cfg.Exec.Filename)
if cfg.Model.IsDefined() {
_ = syscall.Unlink(cfg.Model.Filename)
}
plugins := []plugin.Plugin{}
if cfg.Model.IsDefined() {
plugins = append(plugins, modelgen.New())
}
plugins = append(plugins, resolvergen.New())
if cfg.Federation.IsDefined() {
plugins = append([]plugin.Plugin{federation.New()}, plugins...)
}
for _, o := range option {
o(cfg, &plugins)
}
for _, p := range plugins {
if inj, ok := p.(plugin.EarlySourceInjector); ok {
if s := inj.InjectSourceEarly(); s != nil {
cfg.Sources = append(cfg.Sources, s)
}
}
}
if err := cfg.LoadSchema(); err != nil {
return fmt.Errorf("failed to load schema: %w", err)
}
for _, p := range plugins {
if inj, ok := p.(plugin.LateSourceInjector); ok {
if s := inj.InjectSourceLate(cfg.Schema); s != nil {
cfg.Sources = append(cfg.Sources, s)
}
}
}
// LoadSchema again now we have everything
if err := cfg.LoadSchema(); err != nil {
return fmt.Errorf("failed to load schema: %w", err)
}
if err := cfg.Init(); err != nil {
return fmt.Errorf("generating core failed: %w", err)
}
for _, p := range plugins {
if mut, ok := p.(plugin.ConfigMutator); ok {
err := mut.MutateConfig(cfg)
if err != nil {
return fmt.Errorf("%s: %w", p.Name(), err)
}
}
}
// Merge again now that the generated models have been injected into the typemap
data, err := codegen.BuildData(cfg)
if err != nil {
return fmt.Errorf("merging type systems failed: %w", err)
}
if err = codegen.GenerateCode(data); err != nil {
return fmt.Errorf("generating core failed: %w", err)
}
if !cfg.SkipModTidy {
if err = cfg.Packages.ModTidy(); err != nil {
return fmt.Errorf("tidy failed: %w", err)
}
}
for _, p := range plugins {
if mut, ok := p.(plugin.CodeGenerator); ok {
err := mut.GenerateCode(data)
if err != nil {
return fmt.Errorf("%s: %w", p.Name(), err)
}
}
}
if err = codegen.GenerateCode(data); err != nil {
return fmt.Errorf("generating core failed: %w", err)
}
if !cfg.SkipValidation {
if err := validate(cfg); err != nil {
return fmt.Errorf("validation failed: %w", err)
}
}
return nil
}
func validate(cfg *config.Config) error {
roots := []string{cfg.Exec.ImportPath()}
if cfg.Model.IsDefined() {
roots = append(roots, cfg.Model.ImportPath())
}
if cfg.Resolver.IsDefined() {
roots = append(roots, cfg.Resolver.ImportPath())
}
cfg.Packages.LoadAll(roots...)
errs := cfg.Packages.Errors()
if len(errs) > 0 {
return errs
}
return nil
}

View File

@@ -1,47 +0,0 @@
package api
import (
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/plugin"
)
type Option func(cfg *config.Config, plugins *[]plugin.Plugin)
func NoPlugins() Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
*plugins = nil
}
}
func AddPlugin(p plugin.Plugin) Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
*plugins = append(*plugins, p)
}
}
// PrependPlugin prepends plugin any existing plugins
func PrependPlugin(p plugin.Plugin) Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
*plugins = append([]plugin.Plugin{p}, *plugins...)
}
}
// ReplacePlugin replaces any existing plugin with a matching plugin name
func ReplacePlugin(p plugin.Plugin) Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
if plugins != nil {
found := false
ps := *plugins
for i, o := range ps {
if p.Name() == o.Name() {
ps[i] = p
found = true
}
}
if !found {
ps = append(ps, p)
}
*plugins = ps
}
}
}

View File

@@ -1,118 +0,0 @@
package codegen
import (
"fmt"
"go/types"
"strings"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
)
type ArgSet struct {
Args []*FieldArgument
FuncDecl string
}
type FieldArgument struct {
*ast.ArgumentDefinition
TypeReference *config.TypeReference
VarName string // The name of the var in go
Object *Object // A link back to the parent object
Default interface{} // The default value
Directives []*Directive
Value interface{} // value set in Data
}
// ImplDirectives get not Builtin and location ARGUMENT_DEFINITION directive
func (f *FieldArgument) ImplDirectives() []*Directive {
d := make([]*Directive, 0)
for i := range f.Directives {
if !f.Directives[i].Builtin && f.Directives[i].IsLocation(ast.LocationArgumentDefinition) {
d = append(d, f.Directives[i])
}
}
return d
}
func (f *FieldArgument) DirectiveObjName() string {
return "rawArgs"
}
func (f *FieldArgument) Stream() bool {
return f.Object != nil && f.Object.Stream
}
func (b *builder) buildArg(obj *Object, arg *ast.ArgumentDefinition) (*FieldArgument, error) {
tr, err := b.Binder.TypeReference(arg.Type, nil)
if err != nil {
return nil, err
}
argDirs, err := b.getDirectives(arg.Directives)
if err != nil {
return nil, err
}
newArg := FieldArgument{
ArgumentDefinition: arg,
TypeReference: tr,
Object: obj,
VarName: templates.ToGoPrivate(arg.Name),
Directives: argDirs,
}
if arg.DefaultValue != nil {
newArg.Default, err = arg.DefaultValue.Value(nil)
if err != nil {
return nil, fmt.Errorf("default value is not valid: %w", err)
}
}
return &newArg, nil
}
func (b *builder) bindArgs(field *Field, params *types.Tuple) ([]*FieldArgument, error) {
var newArgs []*FieldArgument
nextArg:
for j := 0; j < params.Len(); j++ {
param := params.At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.Name, param.Name()) {
tr, err := b.Binder.TypeReference(oldArg.Type, param.Type())
if err != nil {
return nil, err
}
oldArg.TypeReference = tr
newArgs = append(newArgs, oldArg)
continue nextArg
}
}
// no matching arg found, abort
return nil, fmt.Errorf("arg %s not in schema", param.Name())
}
return newArgs, nil
}
func (a *Data) Args() map[string][]*FieldArgument {
ret := map[string][]*FieldArgument{}
for _, o := range a.Objects {
for _, f := range o.Fields {
if len(f.Args) > 0 {
ret[f.ArgsFunc()] = f.Args
}
}
}
for _, d := range a.Directives() {
if len(d.Args) > 0 {
ret[d.ArgsFunc()] = d.Args
}
}
return ret
}

View File

@@ -1,36 +0,0 @@
{{ range $name, $args := .Args }}
func (ec *executionContext) {{ $name }}(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
{{- range $i, $arg := . }}
var arg{{$i}} {{ $arg.TypeReference.GO | ref}}
if tmp, ok := rawArgs[{{$arg.Name|quote}}]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField({{$arg.Name|quote}}))
{{- if $arg.ImplDirectives }}
directive0 := func(ctx context.Context) (interface{}, error) { return ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, tmp) }
{{ template "implDirectives" $arg }}
tmp, err = directive{{$arg.ImplDirectives|len}}(ctx)
if err != nil {
return nil, graphql.ErrorOnPath(ctx, err)
}
if data, ok := tmp.({{ $arg.TypeReference.GO | ref }}) ; ok {
arg{{$i}} = data
{{- if $arg.TypeReference.IsNilable }}
} else if tmp == nil {
arg{{$i}} = nil
{{- end }}
} else {
return nil, graphql.ErrorOnPath(ctx, fmt.Errorf(`unexpected type %T from directive, should be {{ $arg.TypeReference.GO }}`, tmp))
}
{{- else }}
arg{{$i}}, err = ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, tmp)
if err != nil {
return nil, err
}
{{- end }}
}
args[{{$arg.Name|quote}}] = arg{{$i}}
{{- end }}
return args, nil
}
{{ end }}

View File

@@ -1,11 +0,0 @@
package codegen
func (o *Object) UniqueFields() map[string][]*Field {
m := map[string][]*Field{}
for _, f := range o.Fields {
m[f.GoFieldName] = append(m[f.GoFieldName], f)
}
return m
}

View File

@@ -1,494 +0,0 @@
package config
import (
"errors"
"fmt"
"go/token"
"go/types"
"golang.org/x/tools/go/packages"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/internal/code"
"github.com/vektah/gqlparser/v2/ast"
)
var ErrTypeNotFound = errors.New("unable to find type")
// Binder connects graphql types to golang types using static analysis
type Binder struct {
pkgs *code.Packages
schema *ast.Schema
cfg *Config
References []*TypeReference
SawInvalid bool
objectCache map[string]map[string]types.Object
}
func (c *Config) NewBinder() *Binder {
return &Binder{
pkgs: c.Packages,
schema: c.Schema,
cfg: c,
}
}
func (b *Binder) TypePosition(typ types.Type) token.Position {
named, isNamed := typ.(*types.Named)
if !isNamed {
return token.Position{
Filename: "unknown",
}
}
return b.ObjectPosition(named.Obj())
}
func (b *Binder) ObjectPosition(typ types.Object) token.Position {
if typ == nil {
return token.Position{
Filename: "unknown",
}
}
pkg := b.pkgs.Load(typ.Pkg().Path())
return pkg.Fset.Position(typ.Pos())
}
func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
pkgName, typeName := code.PkgAndType(name)
return b.FindType(pkgName, typeName)
}
func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
if pkgName == "" {
if typeName == "map[string]interface{}" {
return MapType, nil
}
if typeName == "interface{}" {
return InterfaceType, nil
}
}
obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}
if fun, isFunc := obj.(*types.Func); isFunc {
return fun.Type().(*types.Signature).Params().At(0).Type(), nil
}
return obj.Type(), nil
}
var (
MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
InterfaceType = types.NewInterfaceType(nil, nil)
)
func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
models := b.cfg.Models[name].Model
if len(models) == 0 {
return nil, fmt.Errorf(name + " not found in typemap")
}
if models[0] == "map[string]interface{}" {
return MapType, nil
}
if models[0] == "interface{}" {
return InterfaceType, nil
}
pkgName, typeName := code.PkgAndType(models[0])
if pkgName == "" {
return nil, fmt.Errorf("missing package name for %s", name)
}
obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}
return obj.Type(), nil
}
func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
if pkgName == "" {
return nil, fmt.Errorf("package cannot be nil")
}
pkg := b.pkgs.LoadWithTypes(pkgName)
if pkg == nil {
err := b.pkgs.Errors()
if err != nil {
return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err)
}
return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName)
}
if b.objectCache == nil {
b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count())
}
defsIndex, ok := b.objectCache[pkgName]
if !ok {
defsIndex = indexDefs(pkg)
b.objectCache[pkgName] = defsIndex
}
// function based marshalers take precedence
if val, ok := defsIndex["Marshal"+typeName]; ok {
return val, nil
}
if val, ok := defsIndex[typeName]; ok {
return val, nil
}
return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName)
}
func indexDefs(pkg *packages.Package) map[string]types.Object {
res := make(map[string]types.Object)
scope := pkg.Types.Scope()
for astNode, def := range pkg.TypesInfo.Defs {
// only look at defs in the top scope
if def == nil {
continue
}
parent := def.Parent()
if parent == nil || parent != scope {
continue
}
if _, ok := res[astNode.Name]; !ok {
// The above check may not be really needed, it is only here to have a consistent behavior with
// previous implementation of FindObject() function which only honored the first inclusion of a def.
// If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups.
res[astNode.Name] = def
}
}
return res
}
func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
newRef := *ref
newRef.GO = types.NewPointer(ref.GO)
b.References = append(b.References, &newRef)
return &newRef
}
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
type TypeReference struct {
Definition *ast.Definition
GQL *ast.Type
GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target.
Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
CastType types.Type // Before calling marshalling functions cast from/to this base type
Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
}
func (ref *TypeReference) Elem() *TypeReference {
if p, isPtr := ref.GO.(*types.Pointer); isPtr {
newRef := *ref
newRef.GO = p.Elem()
return &newRef
}
if ref.IsSlice() {
newRef := *ref
newRef.GO = ref.GO.(*types.Slice).Elem()
newRef.GQL = ref.GQL.Elem
return &newRef
}
return nil
}
func (t *TypeReference) IsPtr() bool {
_, isPtr := t.GO.(*types.Pointer)
return isPtr
}
// fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful)
//
func (t *TypeReference) IsPtrToPtr() bool {
if p, isPtr := t.GO.(*types.Pointer); isPtr {
_, isPtr := p.Elem().(*types.Pointer)
return isPtr
}
return false
}
func (t *TypeReference) IsNilable() bool {
return IsNilable(t.GO)
}
func (t *TypeReference) IsSlice() bool {
_, isSlice := t.GO.(*types.Slice)
return t.GQL.Elem != nil && isSlice
}
func (t *TypeReference) IsPtrToSlice() bool {
if t.IsPtr() {
_, isPointerToSlice := t.GO.(*types.Pointer).Elem().(*types.Slice)
return isPointerToSlice
}
return false
}
func (t *TypeReference) IsNamed() bool {
_, isSlice := t.GO.(*types.Named)
return isSlice
}
func (t *TypeReference) IsStruct() bool {
_, isStruct := t.GO.Underlying().(*types.Struct)
return isStruct
}
func (t *TypeReference) IsScalar() bool {
return t.Definition.Kind == ast.Scalar
}
func (t *TypeReference) UniquenessKey() string {
nullability := "O"
if t.GQL.NonNull {
nullability = "N"
}
elemNullability := ""
if t.GQL.Elem != nil && t.GQL.Elem.NonNull {
// Fix for #896
elemNullability = "ᚄ"
}
return nullability + t.Definition.Name + "2" + templates.TypeIdentifier(t.GO) + elemNullability
}
func (t *TypeReference) MarshalFunc() string {
if t.Definition == nil {
panic(errors.New("Definition missing for " + t.GQL.Name()))
}
if t.Definition.Kind == ast.InputObject {
return ""
}
return "marshal" + t.UniquenessKey()
}
func (t *TypeReference) UnmarshalFunc() string {
if t.Definition == nil {
panic(errors.New("Definition missing for " + t.GQL.Name()))
}
if !t.Definition.IsInputType() {
return ""
}
return "unmarshal" + t.UniquenessKey()
}
func (t *TypeReference) IsTargetNilable() bool {
return IsNilable(t.Target)
}
func (b *Binder) PushRef(ret *TypeReference) {
b.References = append(b.References, ret)
}
func isMap(t types.Type) bool {
if t == nil {
return true
}
_, ok := t.(*types.Map)
return ok
}
func isIntf(t types.Type) bool {
if t == nil {
return true
}
_, ok := t.(*types.Interface)
return ok
}
func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
if !isValid(bindTarget) {
b.SawInvalid = true
return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
}
var pkgName, typeName string
def := b.schema.Types[schemaType.Name()]
defer func() {
if err == nil && ret != nil {
b.PushRef(ret)
}
}()
if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
return nil, fmt.Errorf("%s was not found", schemaType.Name())
}
for _, model := range b.cfg.Models[schemaType.Name()].Model {
if model == "map[string]interface{}" {
if !isMap(bindTarget) {
continue
}
return &TypeReference{
Definition: def,
GQL: schemaType,
GO: MapType,
}, nil
}
if model == "interface{}" {
if !isIntf(bindTarget) {
continue
}
return &TypeReference{
Definition: def,
GQL: schemaType,
GO: InterfaceType,
}, nil
}
pkgName, typeName = code.PkgAndType(model)
if pkgName == "" {
return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
}
ref := &TypeReference{
Definition: def,
GQL: schemaType,
}
obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}
if fun, isFunc := obj.(*types.Func); isFunc {
ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler"
ref.Marshaler = fun
ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
} else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") {
ref.GO = obj.Type()
ref.IsContext = true
ref.IsMarshaler = true
} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
ref.GO = obj.Type()
ref.IsMarshaler = true
} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
// TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595)
ref.GO = obj.Type()
ref.CastType = underlying
underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
if err != nil {
return nil, err
}
ref.Marshaler = underlyingRef.Marshaler
ref.Unmarshaler = underlyingRef.Unmarshaler
} else {
ref.GO = obj.Type()
}
ref.Target = ref.GO
ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
if bindTarget != nil {
if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
continue
}
ref.GO = bindTarget
}
return ref, nil
}
return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
}
func isValid(t types.Type) bool {
basic, isBasic := t.(*types.Basic)
if !isBasic {
return true
}
return basic.Kind() != types.Invalid
}
func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
if t.Elem != nil {
child := b.CopyModifiersFromAst(t.Elem, base)
if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
child = types.NewPointer(child)
}
return types.NewSlice(child)
}
var isInterface bool
if named, ok := base.(*types.Named); ok {
_, isInterface = named.Underlying().(*types.Interface)
}
if !isInterface && !IsNilable(base) && !t.NonNull {
return types.NewPointer(base)
}
return base
}
func IsNilable(t types.Type) bool {
if namedType, isNamed := t.(*types.Named); isNamed {
return IsNilable(namedType.Underlying())
}
_, isPtr := t.(*types.Pointer)
_, isMap := t.(*types.Map)
_, isInterface := t.(*types.Interface)
_, isSlice := t.(*types.Slice)
_, isChan := t.(*types.Chan)
return isPtr || isMap || isInterface || isSlice || isChan
}
func hasMethod(it types.Type, name string) bool {
if ptr, isPtr := it.(*types.Pointer); isPtr {
it = ptr.Elem()
}
namedType, ok := it.(*types.Named)
if !ok {
return false
}
for i := 0; i < namedType.NumMethods(); i++ {
if namedType.Method(i).Name() == name {
return true
}
}
return false
}
func basicUnderlying(it types.Type) *types.Basic {
if ptr, isPtr := it.(*types.Pointer); isPtr {
it = ptr.Elem()
}
namedType, ok := it.(*types.Named)
if !ok {
return nil
}
if basic, ok := namedType.Underlying().(*types.Basic); ok {
return basic
}
return nil
}

View File

@@ -1,648 +0,0 @@
package config
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"github.com/99designs/gqlgen/internal/code"
"github.com/vektah/gqlparser/v2"
"github.com/vektah/gqlparser/v2/ast"
"gopkg.in/yaml.v2"
)
type Config struct {
SchemaFilename StringList `yaml:"schema,omitempty"`
Exec ExecConfig `yaml:"exec"`
Model PackageConfig `yaml:"model,omitempty"`
Federation PackageConfig `yaml:"federation,omitempty"`
Resolver ResolverConfig `yaml:"resolver,omitempty"`
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
Packages *code.Packages `yaml:"-"`
Schema *ast.Schema `yaml:"-"`
// Deprecated: use Federation instead. Will be removed next release
Federated bool `yaml:"federated,omitempty"`
}
var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
// DefaultConfig creates a copy of the default config
func DefaultConfig() *Config {
return &Config{
SchemaFilename: StringList{"schema.graphql"},
Model: PackageConfig{Filename: "models_gen.go"},
Exec: ExecConfig{Filename: "generated.go"},
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
}
}
// LoadDefaultConfig loads the default config so that it is ready to be used
func LoadDefaultConfig() (*Config, error) {
config := DefaultConfig()
for _, filename := range config.SchemaFilename {
filename = filepath.ToSlash(filename)
var err error
var schemaRaw []byte
schemaRaw, err = ioutil.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("unable to open schema: %w", err)
}
config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
}
return config, nil
}
// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
// walking up the tree. The closest config file will be returned.
func LoadConfigFromDefaultLocations() (*Config, error) {
cfgFile, err := findCfg()
if err != nil {
return nil, err
}
err = os.Chdir(filepath.Dir(cfgFile))
if err != nil {
return nil, fmt.Errorf("unable to enter config dir: %w", err)
}
return LoadConfig(cfgFile)
}
var path2regex = strings.NewReplacer(
`.`, `\.`,
`*`, `.+`,
`\`, `[\\/]`,
`/`, `[\\/]`,
)
// LoadConfig reads the gqlgen.yml config file
func LoadConfig(filename string) (*Config, error) {
config := DefaultConfig()
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("unable to read config: %w", err)
}
if err := yaml.UnmarshalStrict(b, config); err != nil {
return nil, fmt.Errorf("unable to parse config: %w", err)
}
if err := CompleteConfig(config); err != nil {
return nil, err
}
return config, nil
}
// CompleteConfig fills in the schema and other values to a config loaded from
// YAML.
func CompleteConfig(config *Config) error {
defaultDirectives := map[string]DirectiveConfig{
"skip": {SkipRuntime: true},
"include": {SkipRuntime: true},
"deprecated": {SkipRuntime: true},
"specifiedBy": {SkipRuntime: true},
}
for key, value := range defaultDirectives {
if _, defined := config.Directives[key]; !defined {
config.Directives[key] = value
}
}
preGlobbing := config.SchemaFilename
config.SchemaFilename = StringList{}
for _, f := range preGlobbing {
var matches []string
// for ** we want to override default globbing patterns and walk all
// subdirectories to match schema files.
if strings.Contains(f, "**") {
pathParts := strings.SplitN(f, "**", 2)
rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
// turn the rest of the glob into a regex, anchored only at the end because ** allows
// for any number of dirs in between and walk will let us match against the full path name
globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
matches = append(matches, path)
}
return nil
}); err != nil {
return fmt.Errorf("failed to walk schema at root %s: %w", pathParts[0], err)
}
} else {
var err error
matches, err = filepath.Glob(f)
if err != nil {
return fmt.Errorf("failed to glob schema filename %s: %w", f, err)
}
}
for _, m := range matches {
if config.SchemaFilename.Has(m) {
continue
}
config.SchemaFilename = append(config.SchemaFilename, m)
}
}
for _, filename := range config.SchemaFilename {
filename = filepath.ToSlash(filename)
var err error
var schemaRaw []byte
schemaRaw, err = ioutil.ReadFile(filename)
if err != nil {
return fmt.Errorf("unable to open schema: %w", err)
}
config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
}
return nil
}
func (c *Config) Init() error {
if c.Packages == nil {
c.Packages = &code.Packages{}
}
if c.Schema == nil {
if err := c.LoadSchema(); err != nil {
return err
}
}
err := c.injectTypesFromSchema()
if err != nil {
return err
}
err = c.autobind()
if err != nil {
return err
}
c.injectBuiltins()
// prefetch all packages in one big packages.Load call
c.Packages.LoadAll(c.packageList()...)
// check everything is valid on the way out
err = c.check()
if err != nil {
return err
}
return nil
}
func (c *Config) packageList() []string {
pkgs := []string{
"github.com/99designs/gqlgen/graphql",
"github.com/99designs/gqlgen/graphql/introspection",
}
pkgs = append(pkgs, c.Models.ReferencedPackages()...)
pkgs = append(pkgs, c.AutoBind...)
return pkgs
}
func (c *Config) ReloadAllPackages() {
c.Packages.ReloadAll(c.packageList()...)
}
func (c *Config) injectTypesFromSchema() error {
c.Directives["goModel"] = DirectiveConfig{
SkipRuntime: true,
}
c.Directives["goField"] = DirectiveConfig{
SkipRuntime: true,
}
c.Directives["goTag"] = DirectiveConfig{
SkipRuntime: true,
}
for _, schemaType := range c.Schema.Types {
if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription {
continue
}
if bd := schemaType.Directives.ForName("goModel"); bd != nil {
if ma := bd.Arguments.ForName("model"); ma != nil {
if mv, err := ma.Value.Value(nil); err == nil {
c.Models.Add(schemaType.Name, mv.(string))
}
}
if ma := bd.Arguments.ForName("models"); ma != nil {
if mvs, err := ma.Value.Value(nil); err == nil {
for _, mv := range mvs.([]interface{}) {
c.Models.Add(schemaType.Name, mv.(string))
}
}
}
}
if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
for _, field := range schemaType.Fields {
if fd := field.Directives.ForName("goField"); fd != nil {
forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
if fr, err := ra.Value.Value(nil); err == nil {
forceResolver = fr.(bool)
}
}
if na := fd.Arguments.ForName("name"); na != nil {
if fr, err := na.Value.Value(nil); err == nil {
fieldName = fr.(string)
}
}
if c.Models[schemaType.Name].Fields == nil {
c.Models[schemaType.Name] = TypeMapEntry{
Model: c.Models[schemaType.Name].Model,
Fields: map[string]TypeMapField{},
}
}
c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
FieldName: fieldName,
Resolver: forceResolver,
}
}
}
}
}
return nil
}
type TypeMapEntry struct {
Model StringList `yaml:"model"`
Fields map[string]TypeMapField `yaml:"fields,omitempty"`
}
type TypeMapField struct {
Resolver bool `yaml:"resolver"`
FieldName string `yaml:"fieldName"`
GeneratedMethod string `yaml:"-"`
}
type StringList []string
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
var single string
err := unmarshal(&single)
if err == nil {
*a = []string{single}
return nil
}
var multi []string
err = unmarshal(&multi)
if err != nil {
return err
}
*a = multi
return nil
}
func (a StringList) Has(file string) bool {
for _, existing := range a {
if existing == file {
return true
}
}
return false
}
func (c *Config) check() error {
if c.Models == nil {
c.Models = TypeMap{}
}
type FilenamePackage struct {
Filename string
Package string
Declaree string
}
fileList := map[string][]FilenamePackage{}
if err := c.Models.Check(); err != nil {
return fmt.Errorf("config.models: %w", err)
}
if err := c.Exec.Check(); err != nil {
return fmt.Errorf("config.exec: %w", err)
}
fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{
Filename: c.Exec.Filename,
Package: c.Exec.Package,
Declaree: "exec",
})
if c.Model.IsDefined() {
if err := c.Model.Check(); err != nil {
return fmt.Errorf("config.model: %w", err)
}
fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{
Filename: c.Model.Filename,
Package: c.Model.Package,
Declaree: "model",
})
}
if c.Resolver.IsDefined() {
if err := c.Resolver.Check(); err != nil {
return fmt.Errorf("config.resolver: %w", err)
}
fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{
Filename: c.Resolver.Filename,
Package: c.Resolver.Package,
Declaree: "resolver",
})
}
if c.Federation.IsDefined() {
if err := c.Federation.Check(); err != nil {
return fmt.Errorf("config.federation: %w", err)
}
fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{
Filename: c.Federation.Filename,
Package: c.Federation.Package,
Declaree: "federation",
})
if c.Federation.ImportPath() != c.Exec.ImportPath() {
return fmt.Errorf("federation and exec must be in the same package")
}
}
if c.Federated {
return fmt.Errorf("federated has been removed, instead use\nfederation:\n filename: path/to/federated.go")
}
for importPath, pkg := range fileList {
for _, file1 := range pkg {
for _, file2 := range pkg {
if file1.Package != file2.Package {
return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)",
file1.Declaree,
file2.Declaree,
importPath,
file1.Package,
file2.Package,
)
}
}
}
}
return nil
}
type TypeMap map[string]TypeMapEntry
func (tm TypeMap) Exists(typeName string) bool {
_, ok := tm[typeName]
return ok
}
func (tm TypeMap) UserDefined(typeName string) bool {
m, ok := tm[typeName]
return ok && len(m.Model) > 0
}
func (tm TypeMap) Check() error {
for typeName, entry := range tm {
for _, model := range entry.Model {
if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
}
}
}
return nil
}
func (tm TypeMap) ReferencedPackages() []string {
var pkgs []string
for _, typ := range tm {
for _, model := range typ.Model {
if model == "map[string]interface{}" || model == "interface{}" {
continue
}
pkg, _ := code.PkgAndType(model)
if pkg == "" || inStrSlice(pkgs, pkg) {
continue
}
pkgs = append(pkgs, code.QualifyPackagePath(pkg))
}
}
sort.Slice(pkgs, func(i, j int) bool {
return pkgs[i] > pkgs[j]
})
return pkgs
}
func (tm TypeMap) Add(name string, goType string) {
modelCfg := tm[name]
modelCfg.Model = append(modelCfg.Model, goType)
tm[name] = modelCfg
}
type DirectiveConfig struct {
SkipRuntime bool `yaml:"skip_runtime"`
}
func inStrSlice(haystack []string, needle string) bool {
for _, v := range haystack {
if needle == v {
return true
}
}
return false
}
// findCfg searches for the config file in this directory and all parents up the tree
// looking for the closest match
func findCfg() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("unable to get working dir to findCfg: %w", err)
}
cfg := findCfgInDir(dir)
for cfg == "" && dir != filepath.Dir(dir) {
dir = filepath.Dir(dir)
cfg = findCfgInDir(dir)
}
if cfg == "" {
return "", os.ErrNotExist
}
return cfg, nil
}
func findCfgInDir(dir string) string {
for _, cfgName := range cfgFilenames {
path := filepath.Join(dir, cfgName)
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func (c *Config) autobind() error {
if len(c.AutoBind) == 0 {
return nil
}
ps := c.Packages.LoadAll(c.AutoBind...)
for _, t := range c.Schema.Types {
if c.Models.UserDefined(t.Name) {
continue
}
for i, p := range ps {
if p == nil || p.Module == nil {
return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i])
}
if t := p.Types.Scope().Lookup(t.Name); t != nil {
c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
break
}
}
}
for i, t := range c.Models {
for j, m := range t.Model {
pkg, typename := code.PkgAndType(m)
// skip anything that looks like an import path
if strings.Contains(pkg, "/") {
continue
}
for _, p := range ps {
if p.Name != pkg {
continue
}
if t := p.Types.Scope().Lookup(typename); t != nil {
c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
break
}
}
}
}
return nil
}
func (c *Config) injectBuiltins() {
builtins := TypeMap{
"__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
"__Type": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
"__TypeKind": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
"__Field": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
"__EnumValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
"__InputValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
"__Schema": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
"Float": {Model: StringList{"github.com/99designs/gqlgen/graphql.FloatContext"}},
"String": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
"Boolean": {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
"Int": {Model: StringList{
"github.com/99designs/gqlgen/graphql.Int",
"github.com/99designs/gqlgen/graphql.Int32",
"github.com/99designs/gqlgen/graphql.Int64",
}},
"ID": {
Model: StringList{
"github.com/99designs/gqlgen/graphql.ID",
"github.com/99designs/gqlgen/graphql.IntID",
},
},
}
for typeName, entry := range builtins {
if !c.Models.Exists(typeName) {
c.Models[typeName] = entry
}
}
// These are additional types that are injected if defined in the schema as scalars.
extraBuiltins := TypeMap{
"Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
"Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
"Any": {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
}
for typeName, entry := range extraBuiltins {
if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
c.Models[typeName] = entry
}
}
}
func (c *Config) LoadSchema() error {
if c.Packages != nil {
c.Packages = &code.Packages{}
}
if err := c.check(); err != nil {
return err
}
schema, err := gqlparser.LoadSchema(c.Sources...)
if err != nil {
return err
}
if schema.Query == nil {
schema.Query = &ast.Definition{
Kind: ast.Object,
Name: "Query",
}
schema.Types["Query"] = schema.Query
}
c.Schema = schema
return nil
}
func abs(path string) string {
absPath, err := filepath.Abs(path)
if err != nil {
panic(err)
}
return filepath.ToSlash(absPath)
}

View File

@@ -1,97 +0,0 @@
package config
import (
"fmt"
"go/types"
"path/filepath"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type ExecConfig struct {
Package string `yaml:"package,omitempty"`
Layout ExecLayout `yaml:"layout,omitempty"` // Default: single-file
// Only for single-file layout:
Filename string `yaml:"filename,omitempty"`
// Only for follow-schema layout:
FilenameTemplate string `yaml:"filename_template,omitempty"` // String template with {name} as placeholder for base name.
DirName string `yaml:"dir"`
}
type ExecLayout string
var (
// Write all generated code to a single file.
ExecLayoutSingleFile ExecLayout = "single-file"
// Write generated code to a directory, generating one Go source file for each GraphQL schema file.
ExecLayoutFollowSchema ExecLayout = "follow-schema"
)
func (r *ExecConfig) Check() error {
if r.Layout == "" {
r.Layout = ExecLayoutSingleFile
}
switch r.Layout {
case ExecLayoutSingleFile:
if r.Filename == "" {
return fmt.Errorf("filename must be specified when using single-file layout")
}
if !strings.HasSuffix(r.Filename, ".go") {
return fmt.Errorf("filename should be path to a go source file when using single-file layout")
}
r.Filename = abs(r.Filename)
case ExecLayoutFollowSchema:
if r.DirName == "" {
return fmt.Errorf("dir must be specified when using follow-schema layout")
}
r.DirName = abs(r.DirName)
default:
return fmt.Errorf("invalid layout %s", r.Layout)
}
if strings.ContainsAny(r.Package, "./\\") {
return fmt.Errorf("package should be the output package name only, do not include the output filename")
}
if r.Package == "" && r.Dir() != "" {
r.Package = code.NameForDir(r.Dir())
}
return nil
}
func (r *ExecConfig) ImportPath() string {
if r.Dir() == "" {
return ""
}
return code.ImportPathForDir(r.Dir())
}
func (r *ExecConfig) Dir() string {
switch r.Layout {
case ExecLayoutSingleFile:
if r.Filename == "" {
return ""
}
return filepath.Dir(r.Filename)
case ExecLayoutFollowSchema:
return abs(r.DirName)
default:
panic("invalid layout " + r.Layout)
}
}
func (r *ExecConfig) Pkg() *types.Package {
if r.Dir() == "" {
return nil
}
return types.NewPackage(r.ImportPath(), r.Package)
}
func (r *ExecConfig) IsDefined() bool {
return r.Filename != "" || r.DirName != ""
}

View File

@@ -1,62 +0,0 @@
package config
import (
"fmt"
"go/types"
"path/filepath"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type PackageConfig struct {
Filename string `yaml:"filename,omitempty"`
Package string `yaml:"package,omitempty"`
}
func (c *PackageConfig) ImportPath() string {
if !c.IsDefined() {
return ""
}
return code.ImportPathForDir(c.Dir())
}
func (c *PackageConfig) Dir() string {
if !c.IsDefined() {
return ""
}
return filepath.Dir(c.Filename)
}
func (c *PackageConfig) Pkg() *types.Package {
if !c.IsDefined() {
return nil
}
return types.NewPackage(c.ImportPath(), c.Package)
}
func (c *PackageConfig) IsDefined() bool {
return c.Filename != ""
}
func (c *PackageConfig) Check() error {
if strings.ContainsAny(c.Package, "./\\") {
return fmt.Errorf("package should be the output package name only, do not include the output filename")
}
if c.Filename == "" {
return fmt.Errorf("filename must be specified")
}
if !strings.HasSuffix(c.Filename, ".go") {
return fmt.Errorf("filename should be path to a go source file")
}
c.Filename = abs(c.Filename)
// If Package is not set, first attempt to load the package at the output dir. If that fails
// fallback to just the base dir name of the output filename.
if c.Package == "" {
c.Package = code.NameForDir(c.Dir())
}
return nil
}

View File

@@ -1,100 +0,0 @@
package config
import (
"fmt"
"go/types"
"path/filepath"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type ResolverConfig struct {
Filename string `yaml:"filename,omitempty"`
FilenameTemplate string `yaml:"filename_template,omitempty"`
Package string `yaml:"package,omitempty"`
Type string `yaml:"type,omitempty"`
Layout ResolverLayout `yaml:"layout,omitempty"`
DirName string `yaml:"dir"`
}
type ResolverLayout string
var (
LayoutSingleFile ResolverLayout = "single-file"
LayoutFollowSchema ResolverLayout = "follow-schema"
)
func (r *ResolverConfig) Check() error {
if r.Layout == "" {
r.Layout = LayoutSingleFile
}
if r.Type == "" {
r.Type = "Resolver"
}
switch r.Layout {
case LayoutSingleFile:
if r.Filename == "" {
return fmt.Errorf("filename must be specified with layout=%s", r.Layout)
}
if !strings.HasSuffix(r.Filename, ".go") {
return fmt.Errorf("filename should be path to a go source file with layout=%s", r.Layout)
}
r.Filename = abs(r.Filename)
case LayoutFollowSchema:
if r.DirName == "" {
return fmt.Errorf("dirname must be specified with layout=%s", r.Layout)
}
r.DirName = abs(r.DirName)
if r.Filename == "" {
r.Filename = filepath.Join(r.DirName, "resolver.go")
} else {
r.Filename = abs(r.Filename)
}
default:
return fmt.Errorf("invalid layout %s. must be %s or %s", r.Layout, LayoutSingleFile, LayoutFollowSchema)
}
if strings.ContainsAny(r.Package, "./\\") {
return fmt.Errorf("package should be the output package name only, do not include the output filename")
}
if r.Package == "" && r.Dir() != "" {
r.Package = code.NameForDir(r.Dir())
}
return nil
}
func (r *ResolverConfig) ImportPath() string {
if r.Dir() == "" {
return ""
}
return code.ImportPathForDir(r.Dir())
}
func (r *ResolverConfig) Dir() string {
switch r.Layout {
case LayoutSingleFile:
if r.Filename == "" {
return ""
}
return filepath.Dir(r.Filename)
case LayoutFollowSchema:
return r.DirName
default:
panic("invalid layout " + r.Layout)
}
}
func (r *ResolverConfig) Pkg() *types.Package {
if r.Dir() == "" {
return nil
}
return types.NewPackage(r.ImportPath(), r.Package)
}
func (r *ResolverConfig) IsDefined() bool {
return r.Filename != "" || r.DirName != ""
}

View File

@@ -1,185 +0,0 @@
package codegen
import (
"fmt"
"sort"
"github.com/vektah/gqlparser/v2/ast"
"github.com/99designs/gqlgen/codegen/config"
)
// Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement
// resolvers or directives automatically (eg grpc, validation)
type Data struct {
Config *config.Config
Schema *ast.Schema
// If a schema is broken up into multiple Data instance, each representing part of the schema,
// AllDirectives should contain the directives for the entire schema. Directives() can
// then be used to get the directives that were defined in this Data instance's sources.
// If a single Data instance is used for the entire schema, AllDirectives and Directives()
// will be identical.
// AllDirectives should rarely be used directly.
AllDirectives DirectiveList
Objects Objects
Inputs Objects
Interfaces map[string]*Interface
ReferencedTypes map[string]*config.TypeReference
ComplexityRoots map[string]*Object
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object
}
type builder struct {
Config *config.Config
Schema *ast.Schema
Binder *config.Binder
Directives map[string]*Directive
}
// Get only the directives which are defined in the config's sources.
func (d *Data) Directives() DirectiveList {
res := DirectiveList{}
for k, directive := range d.AllDirectives {
for _, s := range d.Config.Sources {
if directive.Position.Src.Name == s.Name {
res[k] = directive
break
}
}
}
return res
}
func BuildData(cfg *config.Config) (*Data, error) {
// We reload all packages to allow packages to be compared correctly.
cfg.ReloadAllPackages()
b := builder{
Config: cfg,
Schema: cfg.Schema,
}
b.Binder = b.Config.NewBinder()
var err error
b.Directives, err = b.buildDirectives()
if err != nil {
return nil, err
}
dataDirectives := make(map[string]*Directive)
for name, d := range b.Directives {
if !d.Builtin {
dataDirectives[name] = d
}
}
s := Data{
Config: cfg,
AllDirectives: dataDirectives,
Schema: b.Schema,
Interfaces: map[string]*Interface{},
}
for _, schemaType := range b.Schema.Types {
switch schemaType.Kind {
case ast.Object:
obj, err := b.buildObject(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to build object definition: %w", err)
}
s.Objects = append(s.Objects, obj)
case ast.InputObject:
input, err := b.buildObject(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to build input definition: %w", err)
}
s.Inputs = append(s.Inputs, input)
case ast.Union, ast.Interface:
s.Interfaces[schemaType.Name], err = b.buildInterface(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to bind to interface: %w", err)
}
}
}
if s.Schema.Query != nil {
s.QueryRoot = s.Objects.ByName(s.Schema.Query.Name)
} else {
return nil, fmt.Errorf("query entry point missing")
}
if s.Schema.Mutation != nil {
s.MutationRoot = s.Objects.ByName(s.Schema.Mutation.Name)
}
if s.Schema.Subscription != nil {
s.SubscriptionRoot = s.Objects.ByName(s.Schema.Subscription.Name)
}
if err := b.injectIntrospectionRoots(&s); err != nil {
return nil, err
}
s.ReferencedTypes = b.buildTypes()
sort.Slice(s.Objects, func(i, j int) bool {
return s.Objects[i].Definition.Name < s.Objects[j].Definition.Name
})
sort.Slice(s.Inputs, func(i, j int) bool {
return s.Inputs[i].Definition.Name < s.Inputs[j].Definition.Name
})
if b.Binder.SawInvalid {
// if we have a syntax error, show it
err := cfg.Packages.Errors()
if len(err) > 0 {
return nil, err
}
// otherwise show a generic error message
return nil, fmt.Errorf("invalid types were encountered while traversing the go source code, this probably means the invalid code generated isnt correct. add try adding -v to debug")
}
return &s, nil
}
func (b *builder) injectIntrospectionRoots(s *Data) error {
obj := s.Objects.ByName(b.Schema.Query.Name)
if obj == nil {
return fmt.Errorf("root query type must be defined")
}
__type, err := b.buildField(obj, &ast.FieldDefinition{
Name: "__type",
Type: ast.NamedType("__Type", nil),
Arguments: []*ast.ArgumentDefinition{
{
Name: "name",
Type: ast.NonNullNamedType("String", nil),
},
},
})
if err != nil {
return err
}
__schema, err := b.buildField(obj, &ast.FieldDefinition{
Name: "__schema",
Type: ast.NamedType("__Schema", nil),
})
if err != nil {
return err
}
obj.Fields = append(obj.Fields, __type, __schema)
return nil
}

View File

@@ -1,174 +0,0 @@
package codegen
import (
"fmt"
"strconv"
"strings"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
)
type DirectiveList map[string]*Directive
// LocationDirectives filter directives by location
func (dl DirectiveList) LocationDirectives(location string) DirectiveList {
return locationDirectives(dl, ast.DirectiveLocation(location))
}
type Directive struct {
*ast.DirectiveDefinition
Name string
Args []*FieldArgument
Builtin bool
}
// IsLocation check location directive
func (d *Directive) IsLocation(location ...ast.DirectiveLocation) bool {
for _, l := range d.Locations {
for _, a := range location {
if l == a {
return true
}
}
}
return false
}
func locationDirectives(directives DirectiveList, location ...ast.DirectiveLocation) map[string]*Directive {
mDirectives := make(map[string]*Directive)
for name, d := range directives {
if d.IsLocation(location...) {
mDirectives[name] = d
}
}
return mDirectives
}
func (b *builder) buildDirectives() (map[string]*Directive, error) {
directives := make(map[string]*Directive, len(b.Schema.Directives))
for name, dir := range b.Schema.Directives {
if _, ok := directives[name]; ok {
return nil, fmt.Errorf("directive with name %s already exists", name)
}
var args []*FieldArgument
for _, arg := range dir.Arguments {
tr, err := b.Binder.TypeReference(arg.Type, nil)
if err != nil {
return nil, err
}
newArg := &FieldArgument{
ArgumentDefinition: arg,
TypeReference: tr,
VarName: templates.ToGoPrivate(arg.Name),
}
if arg.DefaultValue != nil {
var err error
newArg.Default, err = arg.DefaultValue.Value(nil)
if err != nil {
return nil, fmt.Errorf("default value for directive argument %s(%s) is not valid: %w", dir.Name, arg.Name, err)
}
}
args = append(args, newArg)
}
directives[name] = &Directive{
DirectiveDefinition: dir,
Name: name,
Args: args,
Builtin: b.Config.Directives[name].SkipRuntime,
}
}
return directives, nil
}
func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
dirs := make([]*Directive, len(list))
for i, d := range list {
argValues := make(map[string]interface{}, len(d.Arguments))
for _, da := range d.Arguments {
val, err := da.Value.Value(nil)
if err != nil {
return nil, err
}
argValues[da.Name] = val
}
def, ok := b.Directives[d.Name]
if !ok {
return nil, fmt.Errorf("directive %s not found", d.Name)
}
var args []*FieldArgument
for _, a := range def.Args {
value := a.Default
if argValue, ok := argValues[a.Name]; ok {
value = argValue
}
args = append(args, &FieldArgument{
ArgumentDefinition: a.ArgumentDefinition,
Value: value,
VarName: a.VarName,
TypeReference: a.TypeReference,
})
}
dirs[i] = &Directive{
Name: d.Name,
Args: args,
DirectiveDefinition: list[i].Definition,
Builtin: b.Config.Directives[d.Name].SkipRuntime,
}
}
return dirs, nil
}
func (d *Directive) ArgsFunc() string {
if len(d.Args) == 0 {
return ""
}
return "dir_" + d.Name + "_args"
}
func (d *Directive) CallArgs() string {
args := []string{"ctx", "obj", "n"}
for _, arg := range d.Args {
args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
}
return strings.Join(args, ", ")
}
func (d *Directive) ResolveArgs(obj string, next int) string {
args := []string{"ctx", obj, fmt.Sprintf("directive%d", next)}
for _, arg := range d.Args {
dArg := arg.VarName
if arg.Value == nil && arg.Default == nil {
dArg = "nil"
}
args = append(args, dArg)
}
return strings.Join(args, ", ")
}
func (d *Directive) Declaration() string {
res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
for _, arg := range d.Args {
res += fmt.Sprintf(", %s %s", templates.ToGoPrivate(arg.Name), templates.CurrentImports.LookupType(arg.TypeReference.GO))
}
res += ") (res interface{}, err error)"
return res
}

View File

@@ -1,149 +0,0 @@
{{ define "implDirectives" }}{{ $in := .DirectiveObjName }}
{{- range $i, $directive := .ImplDirectives -}}
directive{{add $i 1}} := func(ctx context.Context) (interface{}, error) {
{{- range $arg := $directive.Args }}
{{- if notNil "Value" $arg }}
{{ $arg.VarName }}, err := ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, {{ $arg.Value | dump }})
if err != nil{
return nil, err
}
{{- else if notNil "Default" $arg }}
{{ $arg.VarName }}, err := ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, {{ $arg.Default | dump }})
if err != nil{
return nil, err
}
{{- end }}
{{- end }}
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.ResolveArgs $in $i }})
}
{{ end -}}
{{ end }}
{{define "queryDirectives"}}
for _, d := range obj.Directives {
switch d.Name {
{{- range $directive := . }}
case "{{$directive.Name}}":
{{- if $directive.Args }}
rawArgs := d.ArgumentMap(ec.Variables)
args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
{{- end }}
n := next
next = func(ctx context.Context) (interface{}, error) {
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})
}
{{- end }}
}
}
tmp, err := next(ctx)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if data, ok := tmp.(graphql.Marshaler); ok {
return data
}
ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp)
return graphql.Null
{{end}}
{{ if .Directives.LocationDirectives "QUERY" }}
func (ec *executionContext) _queryMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) graphql.Marshaler {
{{ template "queryDirectives" .Directives.LocationDirectives "QUERY" }}
}
{{ end }}
{{ if .Directives.LocationDirectives "MUTATION" }}
func (ec *executionContext) _mutationMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) graphql.Marshaler {
{{ template "queryDirectives" .Directives.LocationDirectives "MUTATION" }}
}
{{ end }}
{{ if .Directives.LocationDirectives "SUBSCRIPTION" }}
func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func() graphql.Marshaler {
for _, d := range obj.Directives {
switch d.Name {
{{- range $directive := .Directives.LocationDirectives "SUBSCRIPTION" }}
case "{{$directive.Name}}":
{{- if $directive.Args }}
rawArgs := d.ArgumentMap(ec.Variables)
args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return func() graphql.Marshaler {
return graphql.Null
}
}
{{- end }}
n := next
next = func(ctx context.Context) (interface{}, error) {
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})
}
{{- end }}
}
}
tmp, err := next(ctx)
if err != nil {
ec.Error(ctx, err)
return func() graphql.Marshaler {
return graphql.Null
}
}
if data, ok := tmp.(func() graphql.Marshaler); ok {
return data
}
ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp)
return func() graphql.Marshaler {
return graphql.Null
}
}
{{ end }}
{{ if .Directives.LocationDirectives "FIELD" }}
func (ec *executionContext) _fieldMiddleware(ctx context.Context, obj interface{}, next graphql.Resolver) interface{} {
{{- if .Directives.LocationDirectives "FIELD" }}
fc := graphql.GetFieldContext(ctx)
for _, d := range fc.Field.Directives {
switch d.Name {
{{- range $directive := .Directives.LocationDirectives "FIELD" }}
case "{{$directive.Name}}":
{{- if $directive.Args }}
rawArgs := d.ArgumentMap(ec.Variables)
args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return nil
}
{{- end }}
n := next
next = func(ctx context.Context) (interface{}, error) {
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})
}
{{- end }}
}
}
{{- end }}
res, err := ec.ResolverMiddleware(ctx, next)
if err != nil {
ec.Error(ctx, err)
return nil
}
return res
}
{{ end }}

View File

@@ -1,556 +0,0 @@
package codegen
import (
"errors"
"fmt"
"go/types"
"log"
"reflect"
"strconv"
"strings"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
)
type Field struct {
*ast.FieldDefinition
TypeReference *config.TypeReference
GoFieldType GoFieldType // The field type in go, if any
GoReceiverName string // The name of method & var receiver in go, if any
GoFieldName string // The name of the method or var in go, if any
IsResolver bool // Does this field need a resolver
Args []*FieldArgument // A list of arguments to be passed to this field
MethodHasContext bool // If this is bound to a go method, does the method also take a context
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
VOkFunc bool // If this is bound to a go method, is it of shape (interface{}, bool)
Object *Object // A link back to the parent object
Default interface{} // The default value
Stream bool // does this field return a channel?
Directives []*Directive
}
func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
dirs, err := b.getDirectives(field.Directives)
if err != nil {
return nil, err
}
f := Field{
FieldDefinition: field,
Object: obj,
Directives: dirs,
GoFieldName: templates.ToGo(field.Name),
GoFieldType: GoFieldVariable,
GoReceiverName: "obj",
}
if field.DefaultValue != nil {
var err error
f.Default, err = field.DefaultValue.Value(nil)
if err != nil {
return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err)
}
}
for _, arg := range field.Arguments {
newArg, err := b.buildArg(obj, arg)
if err != nil {
return nil, err
}
f.Args = append(f.Args, newArg)
}
if err = b.bindField(obj, &f); err != nil {
f.IsResolver = true
if errors.Is(err, config.ErrTypeNotFound) {
return nil, err
}
log.Println(err.Error())
}
if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
f.TypeReference = b.Binder.PointerTo(f.TypeReference)
}
return &f, nil
}
func (b *builder) bindField(obj *Object, f *Field) (errret error) {
defer func() {
if f.TypeReference == nil {
tr, err := b.Binder.TypeReference(f.Type, nil)
if err != nil {
errret = err
}
f.TypeReference = tr
}
if f.TypeReference != nil {
dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
if err != nil {
errret = err
}
for _, dir := range obj.Directives {
if dir.IsLocation(ast.LocationInputObject) {
dirs = append(dirs, dir)
}
}
f.Directives = append(dirs, f.Directives...)
}
}()
f.Stream = obj.Stream
switch {
case f.Name == "__schema":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "introspectSchema"
return nil
case f.Name == "__type":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "introspectType"
return nil
case f.Name == "_entities":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "__resolve_entities"
f.MethodHasContext = true
f.NoErr = true
return nil
case f.Name == "_service":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "__resolve__service"
f.MethodHasContext = true
return nil
case obj.Root:
f.IsResolver = true
return nil
case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
f.IsResolver = true
return nil
case obj.Type == config.MapType:
f.GoFieldType = GoFieldMap
return nil
case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
}
target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
if err != nil {
return err
}
pos := b.Binder.ObjectPosition(target)
switch target := target.(type) {
case nil:
objPos := b.Binder.TypePosition(obj.Type)
return fmt.Errorf(
"%s:%d adding resolver method for %s.%s, nothing matched",
objPos.Filename,
objPos.Line,
obj.Name,
f.Name,
)
case *types.Func:
sig := target.Type().(*types.Signature)
if sig.Results().Len() == 1 {
f.NoErr = true
} else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" {
f.VOkFunc = true
} else if sig.Results().Len() != 2 {
return fmt.Errorf("method has wrong number of args")
}
params := sig.Params()
// If the first argument is the context, remove it from the comparison and set
// the MethodHasContext flag so that the context will be passed to this model's method
if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
f.MethodHasContext = true
vars := make([]*types.Var, params.Len()-1)
for i := 1; i < params.Len(); i++ {
vars[i-1] = params.At(i)
}
params = types.NewTuple(vars...)
}
// Try to match target function's arguments with GraphQL field arguments
newArgs, err := b.bindArgs(f, params)
if err != nil {
return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err)
}
// Try to match target function's return types with GraphQL field return type
result := sig.Results().At(0)
tr, err := b.Binder.TypeReference(f.Type, result.Type())
if err != nil {
return err
}
// success, args and return type match. Bind to method
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "obj"
f.GoFieldName = target.Name()
f.Args = newArgs
f.TypeReference = tr
return nil
case *types.Var:
tr, err := b.Binder.TypeReference(f.Type, target.Type())
if err != nil {
return err
}
// success, bind to var
f.GoFieldType = GoFieldVariable
f.GoReceiverName = "obj"
f.GoFieldName = target.Name()
f.TypeReference = tr
return nil
default:
panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
}
}
// findBindTarget attempts to match the name to a field or method on a Type
// with the following priorites:
// 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
// 2. Any method or field with a matching name. Errors if more than one match is found
// 3. Same logic again for embedded fields
func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
// NOTE: a struct tag will override both methods and fields
// Bind to struct tag
found, err := b.findBindStructTagTarget(t, name)
if found != nil || err != nil {
return found, err
}
// Search for a method to bind to
foundMethod, err := b.findBindMethodTarget(t, name)
if err != nil {
return nil, err
}
// Search for a field to bind to
foundField, err := b.findBindFieldTarget(t, name)
if err != nil {
return nil, err
}
switch {
case foundField == nil && foundMethod != nil:
// Bind to method
return foundMethod, nil
case foundField != nil && foundMethod == nil:
// Bind to field
return foundField, nil
case foundField != nil && foundMethod != nil:
// Error
return nil, fmt.Errorf("found more than one way to bind for %s", name)
}
// Search embeds
return b.findBindEmbedsTarget(t, name)
}
func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
if b.Config.StructTag == "" {
return nil, nil
}
switch t := in.(type) {
case *types.Named:
return b.findBindStructTagTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || field.Embedded() {
continue
}
tags := reflect.StructTag(t.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if found != nil {
return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}
found = field
}
}
return found, nil
}
return nil, nil
}
func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
if _, ok := t.Underlying().(*types.Interface); ok {
return b.findBindMethodTarget(t.Underlying(), name)
}
return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
case *types.Interface:
// FIX-ME: Should use ExplicitMethod here? What's the difference?
return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
}
return nil, nil
}
func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
var found types.Object
for i := 0; i < methodCount; i++ {
method := methodFunc(i)
if !method.Exported() || !strings.EqualFold(method.Name(), name) {
continue
}
if found != nil {
return nil, fmt.Errorf("found more than one matching method to bind for %s", name)
}
found = method
}
return found, nil
}
func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindFieldTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || !equalFieldName(field.Name(), name) {
continue
}
if found != nil {
return nil, fmt.Errorf("found more than one matching field to bind for %s", name)
}
found = field
}
return found, nil
}
return nil, nil
}
func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindEmbedsTarget(t.Underlying(), name)
case *types.Struct:
return b.findBindStructEmbedsTarget(t, name)
case *types.Interface:
return b.findBindInterfaceEmbedsTarget(t, name)
}
return nil, nil
}
func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Embedded() {
continue
}
fieldType := field.Type()
if ptr, ok := fieldType.(*types.Pointer); ok {
fieldType = ptr.Elem()
}
f, err := b.findBindTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, fmt.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
}
}
return found, nil
}
func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
var found types.Object
for i := 0; i < iface.NumEmbeddeds(); i++ {
embeddedType := iface.EmbeddedType(i)
f, err := b.findBindTarget(embeddedType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, fmt.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
}
}
return found, nil
}
func (f *Field) HasDirectives() bool {
return len(f.ImplDirectives()) > 0
}
func (f *Field) DirectiveObjName() string {
if f.Object.Root {
return "nil"
}
return f.GoReceiverName
}
func (f *Field) ImplDirectives() []*Directive {
var d []*Directive
loc := ast.LocationFieldDefinition
if f.Object.IsInputType() {
loc = ast.LocationInputFieldDefinition
}
for i := range f.Directives {
if !f.Directives[i].Builtin &&
(f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) {
d = append(d, f.Directives[i])
}
}
return d
}
func (f *Field) IsReserved() bool {
return strings.HasPrefix(f.Name, "__")
}
func (f *Field) IsMethod() bool {
return f.GoFieldType == GoFieldMethod
}
func (f *Field) IsVariable() bool {
return f.GoFieldType == GoFieldVariable
}
func (f *Field) IsMap() bool {
return f.GoFieldType == GoFieldMap
}
func (f *Field) IsConcurrent() bool {
if f.Object.DisableConcurrency {
return false
}
return f.MethodHasContext || f.IsResolver
}
func (f *Field) GoNameUnexported() string {
return templates.ToGoPrivate(f.Name)
}
func (f *Field) ShortInvocation() string {
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("%s().%s(ctx, &it, data)", strings.Title(f.Object.Definition.Name), f.GoFieldName)
}
return fmt.Sprintf("%s().%s(%s)", strings.Title(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
}
func (f *Field) ArgsFunc() string {
if len(f.Args) == 0 {
return ""
}
return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
}
func (f *Field) ResolverType() string {
if !f.IsResolver {
return ""
}
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
}
func (f *Field) ShortResolverDeclaration() string {
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
templates.CurrentImports.LookupType(f.Object.Reference()),
templates.CurrentImports.LookupType(f.TypeReference.GO),
)
}
res := "(ctx context.Context"
if !f.Object.Root {
res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
}
result := templates.CurrentImports.LookupType(f.TypeReference.GO)
if f.Object.Stream {
result = "<-chan " + result
}
res += fmt.Sprintf(") (%s, error)", result)
return res
}
func (f *Field) ComplexitySignature() string {
res := "func(childComplexity int"
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
}
res += ") int"
return res
}
func (f *Field) ComplexityArgs() string {
args := make([]string, len(f.Args))
for i, arg := range f.Args {
args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
}
return strings.Join(args, ", ")
}
func (f *Field) CallArgs() string {
args := make([]string, 0, len(f.Args)+2)
if f.IsResolver {
args = append(args, "rctx")
if !f.Object.Root {
args = append(args, "obj")
}
} else if f.MethodHasContext {
args = append(args, "ctx")
}
for _, arg := range f.Args {
args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
}
return strings.Join(args, ", ")
}

View File

@@ -1,129 +0,0 @@
{{- range $object := .Objects }}{{- range $field := $object.Fields }}
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(){{ end }}graphql.Marshaler) {
{{- $null := "graphql.Null" }}
{{- if $object.Stream }}
{{- $null = "nil" }}
{{- end }}
defer func () {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = {{ $null }}
}
}()
fc := &graphql.FieldContext{
Object: {{$object.Name|quote}},
Field: field,
Args: nil,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
IsResolver: {{ $field.IsResolver }},
}
ctx = graphql.WithFieldContext(ctx, fc)
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return {{ $null }}
}
fc.Args = args
{{- end }}
{{- if $.AllDirectives.LocationDirectives "FIELD" }}
resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{ else }}
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
if err != nil {
ec.Error(ctx, err)
return {{ $null }}
}
{{- end }}
if resTmp == nil {
{{- if $field.TypeReference.GQL.NonNull }}
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return {{ $null }}
}
{{- if $object.Stream }}
return func() graphql.Marshaler {
res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}})
if !ok {
return nil
}
return graphql.WriterFunc(func(w io.Writer) {
w.Write([]byte{'{'})
graphql.MarshalString(field.Alias).MarshalGQL(w)
w.Write([]byte{':'})
ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res).MarshalGQL(w)
w.Write([]byte{'}'})
})
}
{{- else }}
res := resTmp.({{$field.TypeReference.GO | ref}})
fc.Result = res
return ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res)
{{- end }}
}
{{- end }}{{- end}}
{{ define "field" }}
{{- if .HasDirectives -}}
directive0 := func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
{{ template "fieldDefinition" . }}
}
{{ template "implDirectives" . }}
tmp, err := directive{{.ImplDirectives|len}}(rctx)
if err != nil {
return nil, graphql.ErrorOnPath(ctx, err)
}
if tmp == nil {
return nil, nil
}
if data, ok := tmp.({{if .Stream}}<-chan {{end}}{{ .TypeReference.GO | ref }}) ; ok {
return data, nil
}
return nil, fmt.Errorf(`unexpected type %T from directive, should be {{if .Stream}}<-chan {{end}}{{ .TypeReference.GO }}`, tmp)
{{- else -}}
ctx = rctx // use context from middleware stack in children
{{ template "fieldDefinition" . }}
{{- end -}}
{{ end }}
{{ define "fieldDefinition" }}
{{- if .IsResolver -}}
return ec.resolvers.{{ .ShortInvocation }}
{{- else if .IsMap -}}
switch v := {{.GoReceiverName}}[{{.Name|quote}}].(type) {
case {{if .Stream}}<-chan {{end}}{{.TypeReference.GO | ref}}:
return v, nil
case {{if .Stream}}<-chan {{end}}{{.TypeReference.Elem.GO | ref}}:
return &v, nil
case nil:
return ({{.TypeReference.GO | ref}})(nil), nil
default:
return nil, fmt.Errorf("unexpected type %T for field %s", v, {{ .Name | quote}})
}
{{- else if .IsMethod -}}
{{- if .VOkFunc -}}
v, ok := {{.GoReceiverName}}.{{.GoFieldName}}({{ .CallArgs }})
if !ok {
return nil, nil
}
return v, nil
{{- else if .NoErr -}}
return {{.GoReceiverName}}.{{.GoFieldName}}({{ .CallArgs }}), nil
{{- else -}}
return {{.GoReceiverName}}.{{.GoFieldName}}({{ .CallArgs }})
{{- end -}}
{{- else if .IsVariable -}}
return {{.GoReceiverName}}.{{.GoFieldName}}, nil
{{- end }}
{{- end }}

View File

@@ -1,213 +0,0 @@
package codegen
import (
"errors"
"fmt"
"io/ioutil"
"path/filepath"
"runtime"
"strings"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
)
func GenerateCode(data *Data) error {
if !data.Config.Exec.IsDefined() {
return fmt.Errorf("missing exec config")
}
switch data.Config.Exec.Layout {
case config.ExecLayoutSingleFile:
return generateSingleFile(data)
case config.ExecLayoutFollowSchema:
return generatePerSchema(data)
}
return fmt.Errorf("unrecognized exec layout %s", data.Config.Exec.Layout)
}
func generateSingleFile(data *Data) error {
return templates.Render(templates.Options{
PackageName: data.Config.Exec.Package,
Filename: data.Config.Exec.Filename,
Data: data,
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
})
}
func generatePerSchema(data *Data) error {
err := generateRootFile(data)
if err != nil {
return err
}
builds := map[string]*Data{}
err = addObjects(data, &builds)
if err != nil {
return err
}
err = addInputs(data, &builds)
if err != nil {
return err
}
err = addInterfaces(data, &builds)
if err != nil {
return err
}
err = addReferencedTypes(data, &builds)
if err != nil {
return err
}
for filename, build := range builds {
if filename == "" {
continue
}
dir := data.Config.Exec.DirName
path := filepath.Join(dir, filename)
err = templates.Render(templates.Options{
PackageName: data.Config.Exec.Package,
Filename: path,
Data: build,
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
})
if err != nil {
return err
}
}
return nil
}
func filename(p *ast.Position, config *config.Config) string {
name := "common!"
if p != nil && p.Src != nil {
gqlname := filepath.Base(p.Src.Name)
ext := filepath.Ext(p.Src.Name)
name = strings.TrimSuffix(gqlname, ext)
}
filenameTempl := config.Exec.FilenameTemplate
if filenameTempl == "" {
filenameTempl = "{name}.generated.go"
}
return strings.ReplaceAll(filenameTempl, "{name}", name)
}
func addBuild(filename string, p *ast.Position, data *Data, builds *map[string]*Data) {
buildConfig := *data.Config
if p != nil {
buildConfig.Sources = []*ast.Source{p.Src}
}
(*builds)[filename] = &Data{
Config: &buildConfig,
QueryRoot: data.QueryRoot,
MutationRoot: data.MutationRoot,
SubscriptionRoot: data.SubscriptionRoot,
AllDirectives: data.AllDirectives,
}
}
// Root file contains top-level definitions that should not be duplicated across the generated
// files for each schema file.
func generateRootFile(data *Data) error {
dir := data.Config.Exec.DirName
path := filepath.Join(dir, "root_.generated.go")
_, thisFile, _, _ := runtime.Caller(0)
rootDir := filepath.Dir(thisFile)
templatePath := filepath.Join(rootDir, "root_.gotpl")
templateBytes, err := ioutil.ReadFile(templatePath)
if err != nil {
return err
}
template := string(templateBytes)
return templates.Render(templates.Options{
PackageName: data.Config.Exec.Package,
Template: template,
Filename: path,
Data: data,
RegionTags: false,
GeneratedHeader: true,
Packages: data.Config.Packages,
})
}
func addObjects(data *Data, builds *map[string]*Data) error {
for _, o := range data.Objects {
filename := filename(o.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, o.Position, data, builds)
}
(*builds)[filename].Objects = append((*builds)[filename].Objects, o)
}
return nil
}
func addInputs(data *Data, builds *map[string]*Data) error {
for _, in := range data.Inputs {
filename := filename(in.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, in.Position, data, builds)
}
(*builds)[filename].Inputs = append((*builds)[filename].Inputs, in)
}
return nil
}
func addInterfaces(data *Data, builds *map[string]*Data) error {
for k, inf := range data.Interfaces {
filename := filename(inf.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, inf.Position, data, builds)
}
build := (*builds)[filename]
if build.Interfaces == nil {
build.Interfaces = map[string]*Interface{}
}
if build.Interfaces[k] != nil {
return errors.New("conflicting interface keys")
}
build.Interfaces[k] = inf
}
return nil
}
func addReferencedTypes(data *Data, builds *map[string]*Data) error {
for k, rt := range data.ReferencedTypes {
filename := filename(rt.Definition.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, rt.Definition.Position, data, builds)
}
build := (*builds)[filename]
if build.ReferencedTypes == nil {
build.ReferencedTypes = map[string]*config.TypeReference{}
}
if build.ReferencedTypes[k] != nil {
return errors.New("conflicting referenced type keys")
}
build.ReferencedTypes[k] = rt
}
return nil
}

View File

@@ -1,235 +0,0 @@
{{ reserveImport "context" }}
{{ reserveImport "fmt" }}
{{ reserveImport "io" }}
{{ reserveImport "strconv" }}
{{ reserveImport "time" }}
{{ reserveImport "sync" }}
{{ reserveImport "sync/atomic" }}
{{ reserveImport "errors" }}
{{ reserveImport "bytes" }}
{{ reserveImport "github.com/vektah/gqlparser/v2" "gqlparser" }}
{{ reserveImport "github.com/vektah/gqlparser/v2/ast" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql/introspection" }}
{{ if eq .Config.Exec.Layout "single-file" }}
// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.
func NewExecutableSchema(cfg Config) graphql.ExecutableSchema {
return &executableSchema{
resolvers: cfg.Resolvers,
directives: cfg.Directives,
complexity: cfg.Complexity,
}
}
type Config struct {
Resolvers ResolverRoot
Directives DirectiveRoot
Complexity ComplexityRoot
}
type ResolverRoot interface {
{{- range $object := .Objects -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
{{- range $object := .Inputs -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
}
type DirectiveRoot struct {
{{ range $directive := .Directives }}
{{- $directive.Declaration }}
{{ end }}
}
type ComplexityRoot struct {
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
{{ range $_, $fields := $object.UniqueFields }}
{{- $field := index $fields 0 -}}
{{ if not $field.IsReserved -}}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
{{- end }}
}
{{- end }}
{{ end }}
}
{{ end }}
{{ range $object := .Objects -}}
{{ if $object.HasResolvers }}
type {{ucFirst $object.Name}}Resolver interface {
{{ range $field := $object.Fields -}}
{{- if $field.IsResolver }}
{{- $field.GoFieldName}}{{ $field.ShortResolverDeclaration }}
{{- end }}
{{ end }}
}
{{- end }}
{{- end }}
{{ range $object := .Inputs -}}
{{ if $object.HasResolvers }}
type {{$object.Name}}Resolver interface {
{{ range $field := $object.Fields -}}
{{- if $field.IsResolver }}
{{- $field.GoFieldName}}{{ $field.ShortResolverDeclaration }}
{{- end }}
{{ end }}
}
{{- end }}
{{- end }}
{{ if eq .Config.Exec.Layout "single-file" }}
type executableSchema struct {
resolvers ResolverRoot
directives DirectiveRoot
complexity ComplexityRoot
}
func (e *executableSchema) Schema() *ast.Schema {
return parsedSchema
}
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e}
_ = ec
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
{{ range $_, $fields := $object.UniqueFields }}
{{- $len := len $fields }}
{{- range $i, $field := $fields }}
{{- $last := eq (add $i 1) $len }}
{{- if not $field.IsReserved }}
{{- if eq $i 0 }}case {{ end }}"{{$object.Name}}.{{$field.Name}}"{{ if not $last }},{{ else }}:
if e.complexity.{{ucFirst $object.Name}}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
}
{{ end }}
return e.complexity.{{ucFirst $object.Name}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{ end }}), true
{{ end }}
{{- end }}
{{- end }}
{{ end }}
{{ end }}
{{ end }}
}
return 0, false
}
func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
rc := graphql.GetOperationContext(ctx)
ec := executionContext{rc, e}
first := true
switch rc.Operation.Operation {
{{- if .QueryRoot }} case ast.Query:
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
{{ if .Directives.LocationDirectives "QUERY" -}}
data := ec._queryMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data := ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
{{- if .MutationRoot }} case ast.Mutation:
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
{{ if .Directives.LocationDirectives "MUTATION" -}}
data := ec._mutationMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data := ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
{{- if .SubscriptionRoot }} case ast.Subscription:
{{ if .Directives.LocationDirectives "SUBSCRIPTION" -}}
next := ec._subscriptionMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet),nil
})
{{- else -}}
next := ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
return func(ctx context.Context) *graphql.Response {
buf.Reset()
data := next()
if data == nil {
return nil
}
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
default:
return graphql.OneShot(graphql.ErrorResponse(ctx, "unsupported GraphQL operation"))
}
}
type executionContext struct {
*graphql.OperationContext
*executableSchema
}
func (ec *executionContext) introspectSchema() (*introspection.Schema, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapSchema(parsedSchema), nil
}
func (ec *executionContext) introspectType(name string) (*introspection.Type, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name]), nil
}
var sources = []*ast.Source{
{{- range $source := .Config.Sources }}
{Name: {{$source.Name|quote}}, Input: {{$source.Input|rawQuote}}, BuiltIn: {{$source.BuiltIn}}},
{{- end }}
}
var parsedSchema = gqlparser.MustLoadSchema(sources...)
{{ end }}

View File

@@ -1,72 +0,0 @@
{{- range $input := .Inputs }}
{{- if not .HasUnmarshal }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{.Type | ref}}, error) {
var it {{.Type | ref}}
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
asMap[k] = v
}
{{ range $field := .Fields}}
{{- if notNil "Default" $field }}
if _, present := asMap[{{$field.Name|quote}}] ; !present {
asMap[{{$field.Name|quote}}] = {{ $field.Default | dump }}
}
{{- end}}
{{- end }}
for k, v := range asMap {
switch k {
{{- range $field := .Fields }}
case {{$field.Name|quote}}:
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField({{$field.Name|quote}}))
{{- if $field.ImplDirectives }}
directive0 := func(ctx context.Context) (interface{}, error) { return ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v) }
{{ template "implDirectives" $field }}
tmp, err := directive{{$field.ImplDirectives|len}}(ctx)
if err != nil {
return it, graphql.ErrorOnPath(ctx, err)
}
if data, ok := tmp.({{ $field.TypeReference.GO | ref }}) ; ok {
{{- if $field.IsResolver }}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
}
{{- else }}
it.{{$field.GoFieldName}} = data
{{- end }}
{{- if $field.TypeReference.IsNilable }}
{{- if not $field.IsResolver }}
} else if tmp == nil {
it.{{$field.GoFieldName}} = nil
{{- end }}
{{- end }}
} else {
err := fmt.Errorf(`unexpected type %T from directive, should be {{ $field.TypeReference.GO }}`, tmp)
return it, graphql.ErrorOnPath(ctx, err)
}
{{- else }}
{{- if $field.IsResolver }}
data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
}
{{- else }}
it.{{$field.GoFieldName}}, err = ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
}
{{- end }}
{{- end }}
{{- end }}
}
}
return it, nil
}
{{- end }}
{{ end }}

View File

@@ -1,87 +0,0 @@
package codegen
import (
"fmt"
"go/types"
"github.com/vektah/gqlparser/v2/ast"
"github.com/99designs/gqlgen/codegen/config"
)
type Interface struct {
*ast.Definition
Type types.Type
Implementors []InterfaceImplementor
InTypemap bool
}
type InterfaceImplementor struct {
*ast.Definition
Type types.Type
TakeRef bool
}
func (b *builder) buildInterface(typ *ast.Definition) (*Interface, error) {
obj, err := b.Binder.DefaultUserObject(typ.Name)
if err != nil {
panic(err)
}
i := &Interface{
Definition: typ,
Type: obj,
InTypemap: b.Config.Models.UserDefined(typ.Name),
}
interfaceType, err := findGoInterface(i.Type)
if interfaceType == nil || err != nil {
return nil, fmt.Errorf("%s is not an interface", i.Type)
}
for _, implementor := range b.Schema.GetPossibleTypes(typ) {
obj, err := b.Binder.DefaultUserObject(implementor.Name)
if err != nil {
return nil, fmt.Errorf("%s has no backing go type", implementor.Name)
}
implementorType, err := findGoNamedType(obj)
if err != nil {
return nil, fmt.Errorf("can not find backing go type %s: %w", obj.String(), err)
} else if implementorType == nil {
return nil, fmt.Errorf("can not find backing go type %s", obj.String())
}
anyValid := false
// first check if the value receiver can be nil, eg can we type switch on case Thing:
if types.Implements(implementorType, interfaceType) {
i.Implementors = append(i.Implementors, InterfaceImplementor{
Definition: implementor,
Type: obj,
TakeRef: !types.IsInterface(obj),
})
anyValid = true
}
// then check if the pointer receiver can be nil, eg can we type switch on case *Thing:
if types.Implements(types.NewPointer(implementorType), interfaceType) {
i.Implementors = append(i.Implementors, InterfaceImplementor{
Definition: implementor,
Type: types.NewPointer(obj),
})
anyValid = true
}
if !anyValid {
return nil, fmt.Errorf("%s does not satisfy the interface %s", implementorType.String(), i.Type.String())
}
}
return i, nil
}
func (i *InterfaceImplementor) CanBeNil() bool {
return config.IsNilable(i.Type)
}

View File

@@ -1,21 +0,0 @@
{{- range $interface := .Interfaces }}
func (ec *executionContext) _{{$interface.Name}}(ctx context.Context, sel ast.SelectionSet, obj {{$interface.Type | ref}}) graphql.Marshaler {
switch obj := (obj).(type) {
case nil:
return graphql.Null
{{- range $implementor := $interface.Implementors }}
case {{$implementor.Type | ref}}:
{{- if $implementor.CanBeNil }}
if obj == nil {
return graphql.Null
}
{{- end }}
return ec._{{$implementor.Name}}(ctx, sel, {{ if $implementor.TakeRef }}&{{ end }}obj)
{{- end }}
default:
panic(fmt.Errorf("unexpected type %T", obj))
}
}
{{- end }}

View File

@@ -1,169 +0,0 @@
package codegen
import (
"fmt"
"go/types"
"strconv"
"strings"
"unicode"
"github.com/99designs/gqlgen/codegen/config"
"github.com/vektah/gqlparser/v2/ast"
)
type GoFieldType int
const (
GoFieldUndefined GoFieldType = iota
GoFieldMethod
GoFieldVariable
GoFieldMap
)
type Object struct {
*ast.Definition
Type types.Type
ResolverInterface types.Type
Root bool
Fields []*Field
Implements []*ast.Definition
DisableConcurrency bool
Stream bool
Directives []*Directive
}
func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
dirs, err := b.getDirectives(typ.Directives)
if err != nil {
return nil, fmt.Errorf("%s: %w", typ.Name, err)
}
obj := &Object{
Definition: typ,
Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
ResolverInterface: types.NewNamed(
types.NewTypeName(0, b.Config.Exec.Pkg(), strings.Title(typ.Name)+"Resolver", nil),
nil,
nil,
),
}
if !obj.Root {
goObject, err := b.Binder.DefaultUserObject(typ.Name)
if err != nil {
return nil, err
}
obj.Type = goObject
}
for _, intf := range b.Schema.GetImplements(typ) {
obj.Implements = append(obj.Implements, b.Schema.Types[intf.Name])
}
for _, field := range typ.Fields {
if strings.HasPrefix(field.Name, "__") {
continue
}
var f *Field
f, err = b.buildField(obj, field)
if err != nil {
return nil, err
}
obj.Fields = append(obj.Fields, f)
}
return obj, nil
}
func (o *Object) Reference() types.Type {
if config.IsNilable(o.Type) {
return o.Type
}
return types.NewPointer(o.Type)
}
type Objects []*Object
func (o *Object) Implementors() string {
satisfiedBy := strconv.Quote(o.Name)
for _, s := range o.Implements {
satisfiedBy += ", " + strconv.Quote(s.Name)
}
return "[]string{" + satisfiedBy + "}"
}
func (o *Object) HasResolvers() bool {
for _, f := range o.Fields {
if f.IsResolver {
return true
}
}
return false
}
func (o *Object) HasUnmarshal() bool {
if o.Type == config.MapType {
return true
}
for i := 0; i < o.Type.(*types.Named).NumMethods(); i++ {
if o.Type.(*types.Named).Method(i).Name() == "UnmarshalGQL" {
return true
}
}
return false
}
func (o *Object) HasDirectives() bool {
if len(o.Directives) > 0 {
return true
}
for _, f := range o.Fields {
if f.HasDirectives() {
return true
}
}
return false
}
func (o *Object) IsConcurrent() bool {
for _, f := range o.Fields {
if f.IsConcurrent() {
return true
}
}
return false
}
func (o *Object) IsReserved() bool {
return strings.HasPrefix(o.Definition.Name, "__")
}
func (o *Object) Description() string {
return o.Definition.Description
}
func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.Definition.Name, name) {
return os[i]
}
}
return nil
}
func ucFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}

View File

@@ -1,113 +0,0 @@
{{- range $object := .Objects }}
var {{ $object.Name|lcFirst}}Implementors = {{$object.Implementors}}
{{- if .Stream }}
func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet) func() graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, {{$object.Name|lcFirst}}Implementors)
ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{
Object: {{$object.Name|quote}},
})
if len(fields) != 1 {
ec.Errorf(ctx, "must subscribe to exactly one stream")
return nil
}
switch fields[0].Name {
{{- range $field := $object.Fields }}
case "{{$field.Name}}":
return ec._{{$object.Name}}_{{$field.Name}}(ctx, fields[0])
{{- end }}
default:
panic("unknown field " + strconv.Quote(fields[0].Name))
}
}
{{- else }}
func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet{{ if not $object.Root }},obj {{$object.Reference | ref }}{{ end }}) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, {{$object.Name|lcFirst}}Implementors)
{{- if $object.Root }}
ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{
Object: {{$object.Name|quote}},
})
{{end}}
out := graphql.NewFieldSet(fields)
var invalids uint32
for i, field := range fields {
{{- if $object.Root }}
innerCtx := graphql.WithRootFieldContext(ctx, &graphql.RootFieldContext{
Object: field.Name,
Field: field,
})
{{end}}
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString({{$object.Name|quote}})
{{- range $field := $object.Fields }}
case "{{$field.Name}}":
{{- if $field.IsConcurrent }}
field := field
innerFunc := func(ctx context.Context) (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
}
}()
res = ec._{{$object.Name}}_{{$field.Name}}(ctx, field{{if not $object.Root}}, obj{{end}})
{{- if $field.TypeReference.GQL.NonNull }}
if res == graphql.Null {
{{- if $object.IsConcurrent }}
atomic.AddUint32(&invalids, 1)
{{- else }}
invalids++
{{- end }}
}
{{- end }}
return res
}
{{if $object.Root}}
rrm := func(ctx context.Context) graphql.Marshaler {
return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc)
}
{{end}}
out.Concurrently(i, func() graphql.Marshaler {
{{- if $object.Root -}}
return rrm(innerCtx)
{{- else -}}
return innerFunc(ctx)
{{end}}
})
{{- else }}
innerFunc := func(ctx context.Context) (res graphql.Marshaler) {
return ec._{{$object.Name}}_{{$field.Name}}(ctx, field{{if not $object.Root}}, obj{{end}})
}
{{if $object.Root}}
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, innerFunc)
{{else}}
out.Values[i] = innerFunc(ctx)
{{end}}
{{- if $field.TypeReference.GQL.NonNull }}
if out.Values[i] == graphql.Null {
{{- if $object.IsConcurrent }}
atomic.AddUint32(&invalids, 1)
{{- else }}
invalids++
{{- end }}
}
{{- end }}
{{- end }}
{{- end }}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
}
out.Dispatch()
if invalids > 0 { return graphql.Null }
return out
}
{{- end }}
{{- end }}

View File

@@ -1,201 +0,0 @@
{{ reserveImport "context" }}
{{ reserveImport "fmt" }}
{{ reserveImport "io" }}
{{ reserveImport "strconv" }}
{{ reserveImport "time" }}
{{ reserveImport "sync" }}
{{ reserveImport "sync/atomic" }}
{{ reserveImport "errors" }}
{{ reserveImport "bytes" }}
{{ reserveImport "github.com/vektah/gqlparser/v2" "gqlparser" }}
{{ reserveImport "github.com/vektah/gqlparser/v2/ast" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql/introspection" }}
// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.
func NewExecutableSchema(cfg Config) graphql.ExecutableSchema {
return &executableSchema{
resolvers: cfg.Resolvers,
directives: cfg.Directives,
complexity: cfg.Complexity,
}
}
type Config struct {
Resolvers ResolverRoot
Directives DirectiveRoot
Complexity ComplexityRoot
}
type ResolverRoot interface {
{{- range $object := .Objects -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
}
type DirectiveRoot struct {
{{ range $directive := .Directives }}
{{- $directive.Declaration }}
{{ end }}
}
type ComplexityRoot struct {
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
{{ range $_, $fields := $object.UniqueFields }}
{{- $field := index $fields 0 -}}
{{ if not $field.IsReserved -}}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
{{- end }}
}
{{- end }}
{{ end }}
}
type executableSchema struct {
resolvers ResolverRoot
directives DirectiveRoot
complexity ComplexityRoot
}
func (e *executableSchema) Schema() *ast.Schema {
return parsedSchema
}
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e}
_ = ec
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
{{ range $_, $fields := $object.UniqueFields }}
{{- $len := len $fields }}
{{- range $i, $field := $fields }}
{{- $last := eq (add $i 1) $len }}
{{- if not $field.IsReserved }}
{{- if eq $i 0 }}case {{ end }}"{{$object.Name}}.{{$field.Name}}"{{ if not $last }},{{ else }}:
if e.complexity.{{ucFirst $object.Name }}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
}
{{ end }}
return e.complexity.{{ucFirst $object.Name}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{ end }}), true
{{ end }}
{{- end }}
{{- end }}
{{ end }}
{{ end }}
{{ end }}
}
return 0, false
}
func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
rc := graphql.GetOperationContext(ctx)
ec := executionContext{rc, e}
first := true
switch rc.Operation.Operation {
{{- if .QueryRoot }} case ast.Query:
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
{{ if .Directives.LocationDirectives "QUERY" -}}
data := ec._queryMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data := ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
{{- if .MutationRoot }} case ast.Mutation:
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
{{ if .Directives.LocationDirectives "MUTATION" -}}
data := ec._mutationMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data := ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
{{- if .SubscriptionRoot }} case ast.Subscription:
{{ if .Directives.LocationDirectives "SUBSCRIPTION" -}}
next := ec._subscriptionMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet),nil
})
{{- else -}}
next := ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
return func(ctx context.Context) *graphql.Response {
buf.Reset()
data := next()
if data == nil {
return nil
}
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
default:
return graphql.OneShot(graphql.ErrorResponse(ctx, "unsupported GraphQL operation"))
}
}
type executionContext struct {
*graphql.OperationContext
*executableSchema
}
func (ec *executionContext) introspectSchema() (*introspection.Schema, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapSchema(parsedSchema), nil
}
func (ec *executionContext) introspectType(name string) (*introspection.Type, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name]), nil
}
var sources = []*ast.Source{
{{- range $source := .Config.Sources }}
{Name: {{$source.Name|quote}}, Input: {{$source.Input|rawQuote}}, BuiltIn: {{$source.BuiltIn}}},
{{- end }}
}
var parsedSchema = gqlparser.MustLoadSchema(sources...)

View File

@@ -1,139 +0,0 @@
package templates
import (
"fmt"
"go/types"
"strconv"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type Import struct {
Name string
Path string
Alias string
}
type Imports struct {
imports []*Import
destDir string
packages *code.Packages
}
func (i *Import) String() string {
if strings.HasSuffix(i.Path, i.Alias) {
return strconv.Quote(i.Path)
}
return i.Alias + " " + strconv.Quote(i.Path)
}
func (s *Imports) String() string {
res := ""
for i, imp := range s.imports {
if i != 0 {
res += "\n"
}
res += imp.String()
}
return res
}
func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
if path == "" {
panic("empty ambient import")
}
// if we are referencing our own package we dont need an import
if code.ImportPathForDir(s.destDir) == path {
return "", nil
}
name := s.packages.NameForPackage(path)
var alias string
if len(aliases) != 1 {
alias = name
} else {
alias = aliases[0]
}
if existing := s.findByPath(path); existing != nil {
if existing.Alias == alias {
return "", nil
}
return "", fmt.Errorf("ambient import already exists")
}
if alias := s.findByAlias(alias); alias != nil {
return "", fmt.Errorf("ambient import collides on an alias")
}
s.imports = append(s.imports, &Import{
Name: name,
Path: path,
Alias: alias,
})
return "", nil
}
func (s *Imports) Lookup(path string) string {
if path == "" {
return ""
}
path = code.NormalizeVendor(path)
// if we are referencing our own package we dont need an import
if code.ImportPathForDir(s.destDir) == path {
return ""
}
if existing := s.findByPath(path); existing != nil {
return existing.Alias
}
imp := &Import{
Name: s.packages.NameForPackage(path),
Path: path,
}
s.imports = append(s.imports, imp)
alias := imp.Name
i := 1
for s.findByAlias(alias) != nil {
alias = imp.Name + strconv.Itoa(i)
i++
if i > 1000 {
panic(fmt.Errorf("too many collisions, last attempt was %s", alias))
}
}
imp.Alias = alias
return imp.Alias
}
func (s *Imports) LookupType(t types.Type) string {
return types.TypeString(t, func(i *types.Package) string {
return s.Lookup(i.Path())
})
}
func (s Imports) findByPath(importPath string) *Import {
for _, imp := range s.imports {
if imp.Path == importPath {
return imp
}
}
return nil
}
func (s Imports) findByAlias(alias string) *Import {
for _, imp := range s.imports {
if imp.Alias == alias {
return imp
}
}
return nil
}

View File

@@ -1,611 +0,0 @@
package templates
import (
"bytes"
"fmt"
"go/types"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
"text/template"
"unicode"
"github.com/99designs/gqlgen/internal/code"
"github.com/99designs/gqlgen/internal/imports"
)
// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually.
var CurrentImports *Imports
// Options specify various parameters to rendering a template.
type Options struct {
// PackageName is a helper that specifies the package header declaration.
// In other words, when you write the template you don't need to specify `package X`
// at the top of the file. By providing PackageName in the Options, the Render
// function will do that for you.
PackageName string
// Template is a string of the entire template that
// will be parsed and rendered. If it's empty,
// the plugin processor will look for .gotpl files
// in the same directory of where you wrote the plugin.
Template string
// Filename is the name of the file that will be
// written to the system disk once the template is rendered.
Filename string
RegionTags bool
GeneratedHeader bool
// PackageDoc is documentation written above the package line
PackageDoc string
// FileNotice is notice written below the package line
FileNotice string
// Data will be passed to the template execution.
Data interface{}
Funcs template.FuncMap
// Packages cache, you can find me on config.Config
Packages *code.Packages
}
// Render renders a gql plugin template from the given Options. Render is an
// abstraction of the text/template package that makes it easier to write gqlgen
// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
// files inside the directory where you wrote the plugin.
func Render(cfg Options) error {
if CurrentImports != nil {
panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
}
CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)
funcs := Funcs()
for n, f := range cfg.Funcs {
funcs[n] = f
}
t := template.New("").Funcs(funcs)
var roots []string
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return fmt.Errorf("error with provided template: %w", err)
}
roots = append(roots, "template.gotpl")
} else {
// load all the templates in the directory
err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
if !strings.HasSuffix(info.Name(), ".gotpl") {
return nil
}
// omit any templates with "_" at the end of their name, which are meant for specific contexts only
if strings.HasSuffix(info.Name(), "_.gotpl") {
return nil
}
b, err := ioutil.ReadFile(path)
if err != nil {
return err
}
t, err = t.New(name).Parse(string(b))
if err != nil {
return fmt.Errorf("%s: %w", cfg.Filename, err)
}
roots = append(roots, name)
return nil
})
if err != nil {
return fmt.Errorf("locating templates: %w", err)
}
}
// then execute all the important looking ones in order, adding them to the same file
sort.Slice(roots, func(i, j int) bool {
// important files go first
if strings.HasSuffix(roots[i], "!.gotpl") {
return true
}
if strings.HasSuffix(roots[j], "!.gotpl") {
return false
}
return roots[i] < roots[j]
})
var buf bytes.Buffer
for _, root := range roots {
if cfg.RegionTags {
buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n")
}
err := t.Lookup(root).Execute(&buf, cfg.Data)
if err != nil {
return fmt.Errorf("%s: %w", root, err)
}
if cfg.RegionTags {
buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
}
}
var result bytes.Buffer
if cfg.GeneratedHeader {
result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
}
if cfg.PackageDoc != "" {
result.WriteString(cfg.PackageDoc + "\n")
}
result.WriteString("package ")
result.WriteString(cfg.PackageName)
result.WriteString("\n\n")
if cfg.FileNotice != "" {
result.WriteString(cfg.FileNotice)
result.WriteString("\n\n")
}
result.WriteString("import (\n")
result.WriteString(CurrentImports.String())
result.WriteString(")\n")
_, err := buf.WriteTo(&result)
if err != nil {
return err
}
CurrentImports = nil
err = write(cfg.Filename, result.Bytes(), cfg.Packages)
if err != nil {
return err
}
cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename)))
return nil
}
func center(width int, pad string, s string) string {
if len(s)+2 > width {
return s
}
lpad := (width - len(s)) / 2
rpad := width - (lpad + len(s))
return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad)
}
func Funcs() template.FuncMap {
return template.FuncMap{
"ucFirst": UcFirst,
"lcFirst": LcFirst,
"quote": strconv.Quote,
"rawQuote": rawQuote,
"dump": Dump,
"ref": ref,
"ts": TypeIdentifier,
"call": Call,
"prefixLines": prefixLines,
"notNil": notNil,
"reserveImport": CurrentImports.Reserve,
"lookupImport": CurrentImports.Lookup,
"go": ToGo,
"goPrivate": ToGoPrivate,
"add": func(a, b int) int {
return a + b
},
"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
return render(resolveName(filename, 0), tpldata)
},
}
}
func UcFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}
func LcFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToLower(r[0])
return string(r)
}
func isDelimiter(c rune) bool {
return c == '-' || c == '_' || unicode.IsSpace(c)
}
func ref(p types.Type) string {
return CurrentImports.LookupType(p)
}
var pkgReplacer = strings.NewReplacer(
"/", "ᚋ",
".", "ᚗ",
"-", "ᚑ",
"~", "א",
)
func TypeIdentifier(t types.Type) string {
res := ""
for {
switch it := t.(type) {
case *types.Pointer:
t.Underlying()
res += "ᚖ"
t = it.Elem()
case *types.Slice:
res += "ᚕ"
t = it.Elem()
case *types.Named:
res += pkgReplacer.Replace(it.Obj().Pkg().Path())
res += "ᚐ"
res += it.Obj().Name()
return res
case *types.Basic:
res += it.Name()
return res
case *types.Map:
res += "map"
return res
case *types.Interface:
res += "interface"
return res
default:
panic(fmt.Errorf("unexpected type %T", it))
}
}
}
func Call(p *types.Func) string {
pkg := CurrentImports.Lookup(p.Pkg().Path())
if pkg != "" {
pkg += "."
}
if p.Type() != nil {
// make sure the returned type is listed in our imports.
ref(p.Type().(*types.Signature).Results().At(0).Type())
}
return pkg + p.Name()
}
func ToGo(name string) string {
if name == "_" {
return "_"
}
runes := make([]rune, 0, len(name))
wordWalker(name, func(info *wordInfo) {
word := info.Word
if info.MatchCommonInitial {
word = strings.ToUpper(word)
} else if !info.HasCommonInitial {
if strings.ToUpper(word) == word || strings.ToLower(word) == word {
// FOO or foo → Foo
// FOo → FOo
word = UcFirst(strings.ToLower(word))
}
}
runes = append(runes, []rune(word)...)
})
return string(runes)
}
func ToGoPrivate(name string) string {
if name == "_" {
return "_"
}
runes := make([]rune, 0, len(name))
first := true
wordWalker(name, func(info *wordInfo) {
word := info.Word
switch {
case first:
if strings.ToUpper(word) == word || strings.ToLower(word) == word {
// ID → id, CAMEL → camel
word = strings.ToLower(info.Word)
} else {
// ITicket → iTicket
word = LcFirst(info.Word)
}
first = false
case info.MatchCommonInitial:
word = strings.ToUpper(word)
case !info.HasCommonInitial:
word = UcFirst(strings.ToLower(word))
}
runes = append(runes, []rune(word)...)
})
return sanitizeKeywords(string(runes))
}
type wordInfo struct {
Word string
MatchCommonInitial bool
HasCommonInitial bool
}
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
runes := []rune(strings.TrimFunc(str, isDelimiter))
w, i := 0, 0 // index of start of word, scan
hasCommonInitial := false
for i+1 <= len(runes) {
eow := false // whether we hit the end of a word
switch {
case i+1 == len(runes):
eow = true
case isDelimiter(runes[i+1]):
// underscore; shift the remainder forward over any run of underscores
eow = true
n := 1
for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
n++
}
// Leave at most one underscore if the underscore is between two digits
if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
n--
}
copy(runes[i+1:], runes[i+n+1:])
runes = runes[:len(runes)-n]
case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
// lower->non-lower
eow = true
}
i++
// [w,i) is a word.
word := string(runes[w:i])
if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
// through
// split IDFoo → ID, Foo
// but URLs → URLs
} else if !eow {
if commonInitialisms[word] {
hasCommonInitial = true
}
continue
}
matchCommonInitial := false
if commonInitialisms[strings.ToUpper(word)] {
hasCommonInitial = true
matchCommonInitial = true
}
f(&wordInfo{
Word: word,
MatchCommonInitial: matchCommonInitial,
HasCommonInitial: hasCommonInitial,
})
hasCommonInitial = false
w = i
}
}
var keywords = []string{
"break",
"default",
"func",
"interface",
"select",
"case",
"defer",
"go",
"map",
"struct",
"chan",
"else",
"goto",
"package",
"switch",
"const",
"fallthrough",
"if",
"range",
"type",
"continue",
"for",
"import",
"return",
"var",
"_",
}
// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions
func sanitizeKeywords(name string) string {
for _, k := range keywords {
if name == k {
return name + "Arg"
}
}
return name
}
// commonInitialisms is a set of common initialisms.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = map[string]bool{
"ACL": true,
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"CSV": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ICMP": true,
"ID": true,
"IP": true,
"JSON": true,
"KVK": true,
"LHS": true,
"PDF": true,
"PGP": true,
"QPS": true,
"QR": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SQL": true,
"SSH": true,
"SVG": true,
"TCP": true,
"TLS": true,
"TTL": true,
"UDP": true,
"UI": true,
"UID": true,
"URI": true,
"URL": true,
"UTF8": true,
"UUID": true,
"VM": true,
"XML": true,
"XMPP": true,
"XSRF": true,
"XSS": true,
}
func rawQuote(s string) string {
return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`"
}
func notNil(field string, data interface{}) bool {
v := reflect.ValueOf(data)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return false
}
val := v.FieldByName(field)
return val.IsValid() && !val.IsNil()
}
func Dump(val interface{}) string {
switch val := val.(type) {
case int:
return strconv.Itoa(val)
case int64:
return fmt.Sprintf("%d", val)
case float64:
return fmt.Sprintf("%f", val)
case string:
return strconv.Quote(val)
case bool:
return strconv.FormatBool(val)
case nil:
return "nil"
case []interface{}:
var parts []string
for _, part := range val {
parts = append(parts, Dump(part))
}
return "[]interface{}{" + strings.Join(parts, ",") + "}"
case map[string]interface{}:
buf := bytes.Buffer{}
buf.WriteString("map[string]interface{}{")
var keys []string
for key := range val {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
data := val[key]
buf.WriteString(strconv.Quote(key))
buf.WriteString(":")
buf.WriteString(Dump(data))
buf.WriteString(",")
}
buf.WriteString("}")
return buf.String()
default:
panic(fmt.Errorf("unsupported type %T", val))
}
}
func prefixLines(prefix, s string) string {
return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix)
}
func resolveName(name string, skip int) string {
if name[0] == '.' {
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(skip + 1)
return filepath.Join(filepath.Dir(callerFile), name[1:])
}
// load path relative to this directory
_, callerFile, _, _ := runtime.Caller(0)
return filepath.Join(filepath.Dir(callerFile), name)
}
func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
t := template.New("").Funcs(Funcs())
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
t, err = t.New(filepath.Base(filename)).Parse(string(b))
if err != nil {
panic(err)
}
buf := &bytes.Buffer{}
return buf, t.Execute(buf, tpldata)
}
func write(filename string, b []byte, packages *code.Packages) error {
err := os.MkdirAll(filepath.Dir(filename), 0o755)
if err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
formatted, err := imports.Prune(filename, b, packages)
if err != nil {
fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
formatted = b
}
err = ioutil.WriteFile(filename, formatted, 0o644)
if err != nil {
return fmt.Errorf("failed to write %s: %w", filename, err)
}
return nil
}

View File

@@ -1,32 +0,0 @@
package codegen
import (
"fmt"
"github.com/99designs/gqlgen/codegen/config"
)
func (b *builder) buildTypes() map[string]*config.TypeReference {
ret := map[string]*config.TypeReference{}
for _, ref := range b.Binder.References {
processType(ret, ref)
}
return ret
}
func processType(ret map[string]*config.TypeReference, ref *config.TypeReference) {
key := ref.UniquenessKey()
if existing, found := ret[key]; found {
// Simplistic check of content which is obviously different.
existingGQL := fmt.Sprintf("%v", existing.GQL)
newGQL := fmt.Sprintf("%v", ref.GQL)
if existingGQL != newGQL {
panic(fmt.Sprintf("non-unique key \"%s\", trying to replace %s with %s", key, existingGQL, newGQL))
}
}
ret[key] = ref
if ref.IsSlice() || ref.IsPtrToSlice() || ref.IsPtrToPtr() {
processType(ret, ref.Elem())
}
}

View File

@@ -1,192 +0,0 @@
{{- range $type := .ReferencedTypes }}
{{ with $type.UnmarshalFunc }}
func (ec *executionContext) {{ . }}(ctx context.Context, v interface{}) ({{ $type.GO | ref }}, error) {
{{- if and $type.IsNilable (not $type.GQL.NonNull) (not $type.IsPtrToPtr) }}
if v == nil { return nil, nil }
{{- end }}
{{- if $type.IsPtrToSlice }}
res, err := ec.{{ $type.Elem.UnmarshalFunc }}(ctx, v)
return &res, graphql.ErrorOnPath(ctx, err)
{{- else if $type.IsSlice }}
var vSlice []interface{}
if v != nil {
vSlice = graphql.CoerceList(v)
}
var err error
res := make([]{{$type.GO.Elem | ref}}, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.{{ $type.Elem.UnmarshalFunc }}(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
{{- else if and $type.IsPtrToPtr (not $type.Unmarshaler) (not $type.IsMarshaler) }}
var pres {{ $type.Elem.GO | ref }}
if v != nil {
res, err := ec.{{ $type.Elem.UnmarshalFunc }}(ctx, v)
if err != nil {
return nil, graphql.ErrorOnPath(ctx, err)
}
pres = res
}
return &pres, nil
{{- else }}
{{- if $type.Unmarshaler }}
{{- if $type.CastType }}
{{- if $type.IsContext }}
tmp, err := {{ $type.Unmarshaler | call }}(ctx, v)
{{- else }}
tmp, err := {{ $type.Unmarshaler | call }}(v)
{{- end }}
{{- if and $type.IsNilable $type.Elem }}
res := {{ $type.Elem.GO | ref }}(tmp)
{{- else}}
res := {{ $type.GO | ref }}(tmp)
{{- end }}
{{- else}}
{{- if $type.IsContext }}
res, err := {{ $type.Unmarshaler | call }}(ctx, v)
{{- else }}
res, err := {{ $type.Unmarshaler | call }}(v)
{{- end }}
{{- end }}
{{- if and $type.IsTargetNilable (not $type.IsNilable) }}
return *res, graphql.ErrorOnPath(ctx, err)
{{- else if and (not $type.IsTargetNilable) $type.IsNilable }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else}}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
{{- else if eq ($type.GO | ref) "map[string]interface{}" }}
return v.(map[string]interface{}), nil
{{- else if $type.IsMarshaler }}
{{- if and $type.IsNilable $type.Elem }}
var res = new({{ $type.Elem.GO | ref }})
{{- else}}
var res {{ $type.GO | ref }}
{{- end }}
{{- if $type.IsContext }}
err := res.UnmarshalGQLContext(ctx, v)
{{- else }}
err := res.UnmarshalGQL(v)
{{- end }}
return res, graphql.ErrorOnPath(ctx, err)
{{- else }}
res, err := ec.unmarshalInput{{ $type.GQL.Name }}(ctx, v)
{{- if $type.IsNilable }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else}}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
{{- end }}
{{- end }}
}
{{- end }}
{{ with $type.MarshalFunc }}
func (ec *executionContext) {{ . }}(ctx context.Context, sel ast.SelectionSet, v {{ $type.GO | ref }}) graphql.Marshaler {
{{- if $type.IsPtrToSlice }}
return ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, *v)
{{- else if $type.IsSlice }}
{{- if not $type.GQL.NonNull }}
if v == nil {
return graphql.Null
}
{{- end }}
ret := make(graphql.Array, len(v))
{{- if not $type.IsScalar }}
var wg sync.WaitGroup
isLen1 := len(v) == 1
if !isLen1 {
wg.Add(len(v))
}
{{- end }}
for i := range v {
{{- if not $type.IsScalar }}
i := i
fc := &graphql.FieldContext{
Index: &i,
Result: &v[i],
}
ctx := graphql.WithFieldContext(ctx, fc)
f := func(i int) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = nil
}
}()
if !isLen1 {
defer wg.Done()
}
ret[i] = ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, v[i])
}
if isLen1 {
f(i)
} else {
go f(i)
}
{{ else }}
ret[i] = ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, v[i])
{{- end }}
}
{{ if not $type.IsScalar }} wg.Wait() {{ end }}
{{ if $type.Elem.GQL.NonNull }}
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
{{ end }}
return ret
{{- else if and $type.IsPtrToPtr (not $type.Unmarshaler) (not $type.IsMarshaler) }}
if v == nil {
return graphql.Null
}
return ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, *v)
{{- else }}
{{- if $type.IsNilable }}
if v == nil {
{{- if $type.GQL.NonNull }}
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return graphql.Null
}
{{- end }}
{{- if $type.IsMarshaler }}
{{- if $type.IsContext }}
return graphql.WrapContextMarshaler(ctx, v)
{{- else }}
return v
{{- end }}
{{- else if $type.Marshaler }}
{{- $v := "v" }}
{{- if and $type.IsTargetNilable (not $type.IsNilable) }}
{{- $v = "&v" }}
{{- else if and (not $type.IsTargetNilable) $type.IsNilable }}
{{- $v = "*v" }}
{{- end }}
res := {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}({{ $v }}){{else}}{{ $v }}{{- end }})
{{- if $type.GQL.NonNull }}
if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")
}
}
{{- end }}
{{- if $type.IsContext }}
return graphql.WrapContextMarshaler(ctx, res)
{{- else }}
return res
{{- end }}
{{- else }}
return ec._{{$type.Definition.Name}}(ctx, sel, {{ if not $type.IsNilable}}&{{end}} v)
{{- end }}
{{- end }}
}
{{- end }}
{{- end }}

View File

@@ -1,46 +0,0 @@
package codegen
import (
"fmt"
"go/types"
"strings"
)
func findGoNamedType(def types.Type) (*types.Named, error) {
if def == nil {
return nil, nil
}
namedType, ok := def.(*types.Named)
if !ok {
return nil, fmt.Errorf("expected %s to be a named type, instead found %T\n", def.String(), def)
}
return namedType, nil
}
func findGoInterface(def types.Type) (*types.Interface, error) {
if def == nil {
return nil, nil
}
namedType, err := findGoNamedType(def)
if err != nil {
return nil, err
}
if namedType == nil {
return nil, nil
}
underlying, ok := namedType.Underlying().(*types.Interface)
if !ok {
return nil, fmt.Errorf("expected %s to be a named interface, instead found %s", def.String(), namedType.String())
}
return underlying, nil
}
func equalFieldName(source, target string) bool {
source = strings.ReplaceAll(source, "_", "")
target = strings.ReplaceAll(target, "_", "")
return strings.EqualFold(source, target)
}

View File

@@ -1,109 +0,0 @@
package complexity
import (
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/ast"
)
func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
walker := complexityWalker{
es: es,
schema: es.Schema(),
vars: vars,
}
return walker.selectionSetComplexity(op.SelectionSet)
}
type complexityWalker struct {
es graphql.ExecutableSchema
schema *ast.Schema
vars map[string]interface{}
}
func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
var complexity int
for _, selection := range selectionSet {
switch s := selection.(type) {
case *ast.Field:
fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
if fieldDefinition.Name == "__Schema" {
continue
}
var childComplexity int
switch fieldDefinition.Kind {
case ast.Object, ast.Interface, ast.Union:
childComplexity = cw.selectionSetComplexity(s.SelectionSet)
}
args := s.ArgumentMap(cw.vars)
var fieldComplexity int
if s.ObjectDefinition.Kind == ast.Interface {
fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
} else {
fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
}
complexity = safeAdd(complexity, fieldComplexity)
case *ast.FragmentSpread:
complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
case *ast.InlineFragment:
complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
}
}
return complexity
}
func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
// Interfaces don't have their own separate field costs, so they have to assume the worst case.
// We iterate over all implementors and choose the most expensive one.
maxComplexity := 0
implementors := cw.schema.GetPossibleTypes(def)
for _, t := range implementors {
fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
if fieldComplexity > maxComplexity {
maxComplexity = fieldComplexity
}
}
return maxComplexity
}
func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
return customComplexity
}
// default complexity calculation
return safeAdd(1, childComplexity)
}
const maxInt = int(^uint(0) >> 1)
// safeAdd is a saturating add of a and b that ignores negative operands.
// If a + b would overflow through normal Go addition,
// it returns the maximum integer value instead.
//
// Adding complexities with this function prevents attackers from intentionally
// overflowing the complexity calculation to allow overly-complex queries.
//
// It also helps mitigate the impact of custom complexities that accidentally
// return negative values.
func safeAdd(a, b int) int {
// Ignore negative operands.
if a < 0 {
if b < 0 {
return 1
}
return b
} else if b < 0 {
return a
}
c := a + b
if c < a {
// Set c to maximum integer instead of overflowing.
c = maxInt
}
return c
}

View File

@@ -1,19 +0,0 @@
package graphql
import (
"encoding/json"
"io"
)
func MarshalAny(v interface{}) Marshaler {
return WriterFunc(func(w io.Writer) {
err := json.NewEncoder(w).Encode(v)
if err != nil {
panic(err)
}
})
}
func UnmarshalAny(v interface{}) (interface{}, error) {
return v, nil
}

View File

@@ -1,27 +0,0 @@
package graphql
import (
"fmt"
"io"
"strings"
)
func MarshalBoolean(b bool) Marshaler {
if b {
return WriterFunc(func(w io.Writer) { w.Write(trueLit) })
}
return WriterFunc(func(w io.Writer) { w.Write(falseLit) })
}
func UnmarshalBoolean(v interface{}) (bool, error) {
switch v := v.(type) {
case string:
return strings.ToLower(v) == "true", nil
case int:
return v != 0, nil
case bool:
return v, nil
default:
return false, fmt.Errorf("%T is not a bool", v)
}
}

View File

@@ -1,29 +0,0 @@
package graphql
import "context"
// Cache is a shared store for APQ and query AST caching
type Cache interface {
// Get looks up a key's value from the cache.
Get(ctx context.Context, key string) (value interface{}, ok bool)
// Add adds a value to the cache.
Add(ctx context.Context, key string, value interface{})
}
// MapCache is the simplest implementation of a cache, because it can not evict it should only be used in tests
type MapCache map[string]interface{}
// Get looks up a key's value from the cache.
func (m MapCache) Get(ctx context.Context, key string) (value interface{}, ok bool) {
v, ok := m[key]
return v, ok
}
// Add adds a value to the cache.
func (m MapCache) Add(ctx context.Context, key string, value interface{}) { m[key] = value }
type NoCache struct{}
func (n NoCache) Get(ctx context.Context, key string) (value interface{}, ok bool) { return nil, false }
func (n NoCache) Add(ctx context.Context, key string, value interface{}) {}

View File

@@ -1,56 +0,0 @@
package graphql
import (
"encoding/json"
)
// CoerceList applies coercion from a single value to a list.
func CoerceList(v interface{}) []interface{} {
var vSlice []interface{}
if v != nil {
switch v := v.(type) {
case []interface{}:
// already a slice no coercion required
vSlice = v
case []string:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []json.Number:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []bool:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []map[string]interface{}:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []float64:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []float32:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []int:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []int32:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []int64:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
default:
vSlice = []interface{}{v}
}
}
return vSlice
}

View File

@@ -1,94 +0,0 @@
package graphql
import (
"context"
"time"
"github.com/vektah/gqlparser/v2/ast"
)
type key string
const resolverCtx key = "resolver_context"
// Deprecated: Use FieldContext instead
type ResolverContext = FieldContext
type FieldContext struct {
Parent *FieldContext
// The name of the type this field belongs to
Object string
// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
Args map[string]interface{}
// The raw field
Field CollectedField
// The index of array in path.
Index *int
// The result object of resolver
Result interface{}
// IsMethod indicates if the resolver is a method
IsMethod bool
// IsResolver indicates if the field has a user-specified resolver
IsResolver bool
}
type FieldStats struct {
// When field execution started
Started time.Time
// When argument marshaling finished
ArgumentsCompleted time.Time
// When the field completed running all middleware. Not available inside field middleware!
Completed time.Time
}
func (r *FieldContext) Path() ast.Path {
var path ast.Path
for it := r; it != nil; it = it.Parent {
if it.Index != nil {
path = append(path, ast.PathIndex(*it.Index))
} else if it.Field.Field != nil {
path = append(path, ast.PathName(it.Field.Alias))
}
}
// because we are walking up the chain, all the elements are backwards, do an inplace flip.
for i := len(path)/2 - 1; i >= 0; i-- {
opp := len(path) - 1 - i
path[i], path[opp] = path[opp], path[i]
}
return path
}
// Deprecated: Use GetFieldContext instead
func GetResolverContext(ctx context.Context) *ResolverContext {
return GetFieldContext(ctx)
}
func GetFieldContext(ctx context.Context) *FieldContext {
if val, ok := ctx.Value(resolverCtx).(*FieldContext); ok {
return val
}
return nil
}
func WithFieldContext(ctx context.Context, rc *FieldContext) context.Context {
rc.Parent = GetFieldContext(ctx)
return context.WithValue(ctx, resolverCtx, rc)
}
func equalPath(a ast.Path, b ast.Path) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}
return true
}

View File

@@ -1,115 +0,0 @@
package graphql
import (
"context"
"errors"
"github.com/vektah/gqlparser/v2/ast"
)
// Deprecated: Please update all references to OperationContext instead
type RequestContext = OperationContext
type OperationContext struct {
RawQuery string
Variables map[string]interface{}
OperationName string
Doc *ast.QueryDocument
Operation *ast.OperationDefinition
DisableIntrospection bool
RecoverFunc RecoverFunc
ResolverMiddleware FieldMiddleware
RootResolverMiddleware RootFieldMiddleware
Stats Stats
}
func (c *OperationContext) Validate(ctx context.Context) error {
if c.Doc == nil {
return errors.New("field 'Doc'is required")
}
if c.RawQuery == "" {
return errors.New("field 'RawQuery' is required")
}
if c.Variables == nil {
c.Variables = make(map[string]interface{})
}
if c.ResolverMiddleware == nil {
return errors.New("field 'ResolverMiddleware' is required")
}
if c.RootResolverMiddleware == nil {
return errors.New("field 'RootResolverMiddleware' is required")
}
if c.RecoverFunc == nil {
c.RecoverFunc = DefaultRecover
}
return nil
}
const operationCtx key = "operation_context"
// Deprecated: Please update all references to GetOperationContext instead
func GetRequestContext(ctx context.Context) *RequestContext {
return GetOperationContext(ctx)
}
func GetOperationContext(ctx context.Context) *OperationContext {
if val, ok := ctx.Value(operationCtx).(*OperationContext); ok && val != nil {
return val
}
panic("missing operation context")
}
func WithOperationContext(ctx context.Context, rc *OperationContext) context.Context {
return context.WithValue(ctx, operationCtx, rc)
}
// HasOperationContext checks if the given context is part of an ongoing operation
//
// Some errors can happen outside of an operation, eg json unmarshal errors.
func HasOperationContext(ctx context.Context) bool {
_, ok := ctx.Value(operationCtx).(*OperationContext)
return ok
}
// This is just a convenient wrapper method for CollectFields
func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
resctx := GetFieldContext(ctx)
return CollectFields(GetOperationContext(ctx), resctx.Field.Selections, satisfies)
}
// CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context.
// The slice will contain the unique set of all field names requested regardless of fragment type conditions.
func CollectAllFields(ctx context.Context) []string {
resctx := GetFieldContext(ctx)
collected := CollectFields(GetOperationContext(ctx), resctx.Field.Selections, nil)
uniq := make([]string, 0, len(collected))
Next:
for _, f := range collected {
for _, name := range uniq {
if name == f.Name {
continue Next
}
}
uniq = append(uniq, f.Name)
}
return uniq
}
// Errorf sends an error string to the client, passing it through the formatter.
// Deprecated: use graphql.AddErrorf(ctx, err) instead
func (c *OperationContext) Errorf(ctx context.Context, format string, args ...interface{}) {
AddErrorf(ctx, format, args...)
}
// Error sends an error to the client, passing it through the formatter.
// Deprecated: use graphql.AddError(ctx, err) instead
func (c *OperationContext) Error(ctx context.Context, err error) {
AddError(ctx, err)
}
func (c *OperationContext) Recover(ctx context.Context, err interface{}) error {
return ErrorOnPath(ctx, c.RecoverFunc(ctx, err))
}

View File

@@ -1,77 +0,0 @@
package graphql
import (
"context"
"github.com/vektah/gqlparser/v2/ast"
)
const fieldInputCtx key = "path_context"
type PathContext struct {
ParentField *FieldContext
Parent *PathContext
Field *string
Index *int
}
func (fic *PathContext) Path() ast.Path {
var path ast.Path
for it := fic; it != nil; it = it.Parent {
if it.Index != nil {
path = append(path, ast.PathIndex(*it.Index))
} else if it.Field != nil {
path = append(path, ast.PathName(*it.Field))
}
}
// because we are walking up the chain, all the elements are backwards, do an inplace flip.
for i := len(path)/2 - 1; i >= 0; i-- {
opp := len(path) - 1 - i
path[i], path[opp] = path[opp], path[i]
}
if fic.ParentField != nil {
fieldPath := fic.ParentField.Path()
return append(fieldPath, path...)
}
return path
}
func NewPathWithField(field string) *PathContext {
return &PathContext{Field: &field}
}
func NewPathWithIndex(index int) *PathContext {
return &PathContext{Index: &index}
}
func WithPathContext(ctx context.Context, fic *PathContext) context.Context {
if fieldContext := GetFieldContext(ctx); fieldContext != nil {
fic.ParentField = fieldContext
}
if fieldInputContext := GetPathContext(ctx); fieldInputContext != nil {
fic.Parent = fieldInputContext
}
return context.WithValue(ctx, fieldInputCtx, fic)
}
func GetPathContext(ctx context.Context) *PathContext {
if val, ok := ctx.Value(fieldInputCtx).(*PathContext); ok {
return val
}
return nil
}
func GetPath(ctx context.Context) ast.Path {
if pc := GetPathContext(ctx); pc != nil {
return pc.Path()
}
if fc := GetFieldContext(ctx); fc != nil {
return fc.Path()
}
return nil
}

View File

@@ -1,153 +0,0 @@
package graphql
import (
"context"
"fmt"
"sync"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type responseContext struct {
errorPresenter ErrorPresenterFunc
recover RecoverFunc
errors gqlerror.List
errorsMu sync.Mutex
extensions map[string]interface{}
extensionsMu sync.Mutex
}
const resultCtx key = "result_context"
func getResponseContext(ctx context.Context) *responseContext {
val, ok := ctx.Value(resultCtx).(*responseContext)
if !ok {
panic("missing response context")
}
return val
}
func WithResponseContext(ctx context.Context, presenterFunc ErrorPresenterFunc, recoverFunc RecoverFunc) context.Context {
return context.WithValue(ctx, resultCtx, &responseContext{
errorPresenter: presenterFunc,
recover: recoverFunc,
})
}
// AddErrorf writes a formatted error to the client, first passing it through the error presenter.
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
AddError(ctx, fmt.Errorf(format, args...))
}
// AddError sends an error to the client, first passing it through the error presenter.
func AddError(ctx context.Context, err error) {
c := getResponseContext(ctx)
presentedError := c.errorPresenter(ctx, ErrorOnPath(ctx, err))
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
c.errors = append(c.errors, presentedError)
}
func Recover(ctx context.Context, err interface{}) (userMessage error) {
c := getResponseContext(ctx)
return ErrorOnPath(ctx, c.recover(ctx, err))
}
// HasFieldError returns true if the given field has already errored
func HasFieldError(ctx context.Context, rctx *FieldContext) bool {
c := getResponseContext(ctx)
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
if len(c.errors) == 0 {
return false
}
path := rctx.Path()
for _, err := range c.errors {
if equalPath(err.Path, path) {
return true
}
}
return false
}
// GetFieldErrors returns a list of errors that occurred in the given field
func GetFieldErrors(ctx context.Context, rctx *FieldContext) gqlerror.List {
c := getResponseContext(ctx)
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
if len(c.errors) == 0 {
return nil
}
path := rctx.Path()
var errs gqlerror.List
for _, err := range c.errors {
if equalPath(err.Path, path) {
errs = append(errs, err)
}
}
return errs
}
func GetErrors(ctx context.Context) gqlerror.List {
resCtx := getResponseContext(ctx)
resCtx.errorsMu.Lock()
defer resCtx.errorsMu.Unlock()
if len(resCtx.errors) == 0 {
return nil
}
errs := resCtx.errors
cpy := make(gqlerror.List, len(errs))
for i := range errs {
errCpy := *errs[i]
cpy[i] = &errCpy
}
return cpy
}
// RegisterExtension allows you to add a new extension into the graphql response
func RegisterExtension(ctx context.Context, key string, value interface{}) {
c := getResponseContext(ctx)
c.extensionsMu.Lock()
defer c.extensionsMu.Unlock()
if c.extensions == nil {
c.extensions = make(map[string]interface{})
}
if _, ok := c.extensions[key]; ok {
panic(fmt.Errorf("extension already registered for key %s", key))
}
c.extensions[key] = value
}
// GetExtensions returns any extensions registered in the current result context
func GetExtensions(ctx context.Context) map[string]interface{} {
ext := getResponseContext(ctx).extensions
if ext == nil {
return map[string]interface{}{}
}
return ext
}
func GetExtension(ctx context.Context, name string) interface{} {
ext := getResponseContext(ctx).extensions
if ext == nil {
return nil
}
return ext[name]
}

View File

@@ -1,25 +0,0 @@
package graphql
import (
"context"
)
const rootResolverCtx key = "root_resolver_context"
type RootFieldContext struct {
// The name of the type this field belongs to
Object string
// The raw field
Field CollectedField
}
func GetRootFieldContext(ctx context.Context) *RootFieldContext {
if val, ok := ctx.Value(rootResolverCtx).(*RootFieldContext); ok {
return val
}
return nil
}
func WithRootFieldContext(ctx context.Context, rc *RootFieldContext) context.Context {
return context.WithValue(ctx, rootResolverCtx, rc)
}

View File

@@ -1,51 +0,0 @@
package errcode
import (
"github.com/vektah/gqlparser/v2/gqlerror"
)
const (
ValidationFailed = "GRAPHQL_VALIDATION_FAILED"
ParseFailed = "GRAPHQL_PARSE_FAILED"
)
type ErrorKind int
const (
// issues with graphql (validation, parsing). 422s in http, GQL_ERROR in websocket
KindProtocol ErrorKind = iota
// user errors, 200s in http, GQL_DATA in websocket
KindUser
)
var codeType = map[string]ErrorKind{
ValidationFailed: KindProtocol,
ParseFailed: KindProtocol,
}
// RegisterErrorType should be called by extensions that want to customize the http status codes for errors they return
func RegisterErrorType(code string, kind ErrorKind) {
codeType[code] = kind
}
// Set the error code on a given graphql error extension
func Set(err *gqlerror.Error, value string) {
if err.Extensions == nil {
err.Extensions = map[string]interface{}{}
}
err.Extensions["code"] = value
}
// get the kind of the first non User error, defaults to User if no errors have a custom extension
func GetErrorKind(errs gqlerror.List) ErrorKind {
for _, err := range errs {
if code, ok := err.Extensions["code"].(string); ok {
if kind, ok := codeType[code]; ok && kind != KindUser {
return kind
}
}
}
return KindUser
}

View File

@@ -1,32 +0,0 @@
package graphql
import (
"context"
"errors"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type ErrorPresenterFunc func(ctx context.Context, err error) *gqlerror.Error
func DefaultErrorPresenter(ctx context.Context, err error) *gqlerror.Error {
var gqlErr *gqlerror.Error
if errors.As(err, &gqlErr) {
return gqlErr
}
return gqlerror.WrapPath(GetPath(ctx), err)
}
func ErrorOnPath(ctx context.Context, err error) error {
if err == nil {
return nil
}
var gqlErr *gqlerror.Error
if errors.As(err, &gqlErr) {
if gqlErr.Path == nil {
gqlErr.Path = GetPath(ctx)
}
return gqlErr
}
return gqlerror.WrapPath(GetPath(ctx), err)
}

View File

@@ -1,162 +0,0 @@
//go:generate go run github.com/matryer/moq -out executable_schema_mock.go . ExecutableSchema
package graphql
import (
"context"
"fmt"
"github.com/vektah/gqlparser/v2/ast"
)
type ExecutableSchema interface {
Schema() *ast.Schema
Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
Exec(ctx context.Context) ResponseHandler
}
// CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
// passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
// type conditions.
func CollectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string) []CollectedField {
return collectFields(reqCtx, selSet, satisfies, map[string]bool{})
}
func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
groupedFields := make([]CollectedField, 0, len(selSet))
for _, sel := range selSet {
switch sel := sel.(type) {
case *ast.Field:
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
continue
}
f := getOrCreateAndAppendField(&groupedFields, sel.Name, sel.Alias, sel.ObjectDefinition, func() CollectedField {
return CollectedField{Field: sel}
})
f.Selections = append(f.Selections, sel.SelectionSet...)
case *ast.InlineFragment:
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
continue
}
if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
continue
}
for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField })
f.Selections = append(f.Selections, childField.Selections...)
}
case *ast.FragmentSpread:
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
continue
}
fragmentName := sel.Name
if _, seen := visited[fragmentName]; seen {
continue
}
visited[fragmentName] = true
fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
if fragment == nil {
// should never happen, validator has already run
panic(fmt.Errorf("missing fragment %s", fragmentName))
}
if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
continue
}
for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField })
f.Selections = append(f.Selections, childField.Selections...)
}
default:
panic(fmt.Errorf("unsupported %T", sel))
}
}
return groupedFields
}
type CollectedField struct {
*ast.Field
Selections ast.SelectionSet
}
func instanceOf(val string, satisfies []string) bool {
for _, s := range satisfies {
if val == s {
return true
}
}
return false
}
func getOrCreateAndAppendField(c *[]CollectedField, name string, alias string, objectDefinition *ast.Definition, creator func() CollectedField) *CollectedField {
for i, cf := range *c {
if cf.Name == name && cf.Alias == alias {
if cf.ObjectDefinition == objectDefinition {
return &(*c)[i]
}
if cf.ObjectDefinition == nil || objectDefinition == nil {
continue
}
if cf.ObjectDefinition.Name == objectDefinition.Name {
return &(*c)[i]
}
for _, ifc := range objectDefinition.Interfaces {
if ifc == cf.ObjectDefinition.Name {
return &(*c)[i]
}
}
}
}
f := creator()
*c = append(*c, f)
return &(*c)[len(*c)-1]
}
func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
if len(directives) == 0 {
return true
}
skip, include := false, true
if d := directives.ForName("skip"); d != nil {
skip = resolveIfArgument(d, variables)
}
if d := directives.ForName("include"); d != nil {
include = resolveIfArgument(d, variables)
}
return !skip && include
}
func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
arg := d.Arguments.ForName("if")
if arg == nil {
panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
}
value, err := arg.Value.Value(variables)
if err != nil {
panic(err)
}
ret, ok := value.(bool)
if !ok {
panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
}
return ret
}

View File

@@ -1,172 +0,0 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package graphql
import (
"context"
"github.com/vektah/gqlparser/v2/ast"
"sync"
)
// Ensure, that ExecutableSchemaMock does implement ExecutableSchema.
// If this is not the case, regenerate this file with moq.
var _ ExecutableSchema = &ExecutableSchemaMock{}
// ExecutableSchemaMock is a mock implementation of ExecutableSchema.
//
// func TestSomethingThatUsesExecutableSchema(t *testing.T) {
//
// // make and configure a mocked ExecutableSchema
// mockedExecutableSchema := &ExecutableSchemaMock{
// ComplexityFunc: func(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) {
// panic("mock out the Complexity method")
// },
// ExecFunc: func(ctx context.Context) ResponseHandler {
// panic("mock out the Exec method")
// },
// SchemaFunc: func() *ast.Schema {
// panic("mock out the Schema method")
// },
// }
//
// // use mockedExecutableSchema in code that requires ExecutableSchema
// // and then make assertions.
//
// }
type ExecutableSchemaMock struct {
// ComplexityFunc mocks the Complexity method.
ComplexityFunc func(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
// ExecFunc mocks the Exec method.
ExecFunc func(ctx context.Context) ResponseHandler
// SchemaFunc mocks the Schema method.
SchemaFunc func() *ast.Schema
// calls tracks calls to the methods.
calls struct {
// Complexity holds details about calls to the Complexity method.
Complexity []struct {
// TypeName is the typeName argument value.
TypeName string
// FieldName is the fieldName argument value.
FieldName string
// ChildComplexity is the childComplexity argument value.
ChildComplexity int
// Args is the args argument value.
Args map[string]interface{}
}
// Exec holds details about calls to the Exec method.
Exec []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// Schema holds details about calls to the Schema method.
Schema []struct {
}
}
lockComplexity sync.RWMutex
lockExec sync.RWMutex
lockSchema sync.RWMutex
}
// Complexity calls ComplexityFunc.
func (mock *ExecutableSchemaMock) Complexity(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) {
if mock.ComplexityFunc == nil {
panic("ExecutableSchemaMock.ComplexityFunc: method is nil but ExecutableSchema.Complexity was just called")
}
callInfo := struct {
TypeName string
FieldName string
ChildComplexity int
Args map[string]interface{}
}{
TypeName: typeName,
FieldName: fieldName,
ChildComplexity: childComplexity,
Args: args,
}
mock.lockComplexity.Lock()
mock.calls.Complexity = append(mock.calls.Complexity, callInfo)
mock.lockComplexity.Unlock()
return mock.ComplexityFunc(typeName, fieldName, childComplexity, args)
}
// ComplexityCalls gets all the calls that were made to Complexity.
// Check the length with:
// len(mockedExecutableSchema.ComplexityCalls())
func (mock *ExecutableSchemaMock) ComplexityCalls() []struct {
TypeName string
FieldName string
ChildComplexity int
Args map[string]interface{}
} {
var calls []struct {
TypeName string
FieldName string
ChildComplexity int
Args map[string]interface{}
}
mock.lockComplexity.RLock()
calls = mock.calls.Complexity
mock.lockComplexity.RUnlock()
return calls
}
// Exec calls ExecFunc.
func (mock *ExecutableSchemaMock) Exec(ctx context.Context) ResponseHandler {
if mock.ExecFunc == nil {
panic("ExecutableSchemaMock.ExecFunc: method is nil but ExecutableSchema.Exec was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockExec.Lock()
mock.calls.Exec = append(mock.calls.Exec, callInfo)
mock.lockExec.Unlock()
return mock.ExecFunc(ctx)
}
// ExecCalls gets all the calls that were made to Exec.
// Check the length with:
// len(mockedExecutableSchema.ExecCalls())
func (mock *ExecutableSchemaMock) ExecCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockExec.RLock()
calls = mock.calls.Exec
mock.lockExec.RUnlock()
return calls
}
// Schema calls SchemaFunc.
func (mock *ExecutableSchemaMock) Schema() *ast.Schema {
if mock.SchemaFunc == nil {
panic("ExecutableSchemaMock.SchemaFunc: method is nil but ExecutableSchema.Schema was just called")
}
callInfo := struct {
}{}
mock.lockSchema.Lock()
mock.calls.Schema = append(mock.calls.Schema, callInfo)
mock.lockSchema.Unlock()
return mock.SchemaFunc()
}
// SchemaCalls gets all the calls that were made to Schema.
// Check the length with:
// len(mockedExecutableSchema.SchemaCalls())
func (mock *ExecutableSchemaMock) SchemaCalls() []struct {
} {
var calls []struct {
}
mock.lockSchema.RLock()
calls = mock.calls.Schema
mock.lockSchema.RUnlock()
return calls
}

View File

@@ -1,192 +0,0 @@
package executor
import (
"context"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/vektah/gqlparser/v2/parser"
"github.com/vektah/gqlparser/v2/validator"
)
// Executor executes graphql queries against a schema.
type Executor struct {
es graphql.ExecutableSchema
extensions []graphql.HandlerExtension
ext extensions
errorPresenter graphql.ErrorPresenterFunc
recoverFunc graphql.RecoverFunc
queryCache graphql.Cache
}
var _ graphql.GraphExecutor = &Executor{}
// New creates a new Executor with the given schema, and a default error and
// recovery callbacks, and no query cache or extensions.
func New(es graphql.ExecutableSchema) *Executor {
e := &Executor{
es: es,
errorPresenter: graphql.DefaultErrorPresenter,
recoverFunc: graphql.DefaultRecover,
queryCache: graphql.NoCache{},
ext: processExtensions(nil),
}
return e
}
func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.RawParams) (*graphql.OperationContext, gqlerror.List) {
rc := &graphql.OperationContext{
DisableIntrospection: true,
RecoverFunc: e.recoverFunc,
ResolverMiddleware: e.ext.fieldMiddleware,
RootResolverMiddleware: e.ext.rootFieldMiddleware,
Stats: graphql.Stats{
Read: params.ReadTime,
OperationStart: graphql.GetStartTime(ctx),
},
}
ctx = graphql.WithOperationContext(ctx, rc)
for _, p := range e.ext.operationParameterMutators {
if err := p.MutateOperationParameters(ctx, params); err != nil {
return rc, gqlerror.List{err}
}
}
rc.RawQuery = params.Query
rc.OperationName = params.OperationName
var listErr gqlerror.List
rc.Doc, listErr = e.parseQuery(ctx, &rc.Stats, params.Query)
if len(listErr) != 0 {
return rc, listErr
}
rc.Operation = rc.Doc.Operations.ForName(params.OperationName)
if rc.Operation == nil {
return rc, gqlerror.List{gqlerror.Errorf("operation %s not found", params.OperationName)}
}
var err *gqlerror.Error
rc.Variables, err = validator.VariableValues(e.es.Schema(), rc.Operation, params.Variables)
if err != nil {
errcode.Set(err, errcode.ValidationFailed)
return rc, gqlerror.List{err}
}
rc.Stats.Validation.End = graphql.Now()
for _, p := range e.ext.operationContextMutators {
if err := p.MutateOperationContext(ctx, rc); err != nil {
return rc, gqlerror.List{err}
}
}
return rc, nil
}
func (e *Executor) DispatchOperation(ctx context.Context, rc *graphql.OperationContext) (graphql.ResponseHandler, context.Context) {
ctx = graphql.WithOperationContext(ctx, rc)
var innerCtx context.Context
res := e.ext.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler {
innerCtx = ctx
tmpResponseContext := graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
responses := e.es.Exec(tmpResponseContext)
if errs := graphql.GetErrors(tmpResponseContext); errs != nil {
return graphql.OneShot(&graphql.Response{Errors: errs})
}
return func(ctx context.Context) *graphql.Response {
ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := responses(ctx)
if resp == nil {
return nil
}
resp.Errors = append(resp.Errors, graphql.GetErrors(ctx)...)
resp.Extensions = graphql.GetExtensions(ctx)
return resp
})
if resp == nil {
return nil
}
return resp
}
})
return res, innerCtx
}
func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graphql.Response {
ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
for _, gErr := range list {
graphql.AddError(ctx, gErr)
}
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := &graphql.Response{
Errors: list,
}
resp.Extensions = graphql.GetExtensions(ctx)
return resp
})
return resp
}
func (e *Executor) PresentRecoveredError(ctx context.Context, err interface{}) *gqlerror.Error {
return e.errorPresenter(ctx, e.recoverFunc(ctx, err))
}
func (e *Executor) SetQueryCache(cache graphql.Cache) {
e.queryCache = cache
}
func (e *Executor) SetErrorPresenter(f graphql.ErrorPresenterFunc) {
e.errorPresenter = f
}
func (e *Executor) SetRecoverFunc(f graphql.RecoverFunc) {
e.recoverFunc = f
}
// parseQuery decodes the incoming query and validates it, pulling from cache if present.
//
// NOTE: This should NOT look at variables, they will change per request. It should only parse and validate
// the raw query string.
func (e *Executor) parseQuery(ctx context.Context, stats *graphql.Stats, query string) (*ast.QueryDocument, gqlerror.List) {
stats.Parsing.Start = graphql.Now()
if doc, ok := e.queryCache.Get(ctx, query); ok {
now := graphql.Now()
stats.Parsing.End = now
stats.Validation.Start = now
return doc.(*ast.QueryDocument), nil
}
doc, err := parser.ParseQuery(&ast.Source{Input: query})
if err != nil {
errcode.Set(err, errcode.ParseFailed)
return nil, gqlerror.List{err}
}
stats.Parsing.End = graphql.Now()
stats.Validation.Start = graphql.Now()
listErr := validator.Validate(e.es.Schema(), doc)
if len(listErr) != 0 {
for _, e := range listErr {
errcode.Set(e, errcode.ValidationFailed)
}
return nil, listErr
}
e.queryCache.Add(ctx, query, doc)
return doc, nil
}

View File

@@ -1,195 +0,0 @@
package executor
import (
"context"
"fmt"
"github.com/99designs/gqlgen/graphql"
)
// Use adds the given extension to this Executor.
func (e *Executor) Use(extension graphql.HandlerExtension) {
if err := extension.Validate(e.es); err != nil {
panic(err)
}
switch extension.(type) {
case graphql.OperationParameterMutator,
graphql.OperationContextMutator,
graphql.OperationInterceptor,
graphql.RootFieldInterceptor,
graphql.FieldInterceptor,
graphql.ResponseInterceptor:
e.extensions = append(e.extensions, extension)
e.ext = processExtensions(e.extensions)
default:
panic(fmt.Errorf("cannot Use %T as a gqlgen handler extension because it does not implement any extension hooks", extension))
}
}
// AroundFields is a convenience method for creating an extension that only implements field middleware
func (e *Executor) AroundFields(f graphql.FieldMiddleware) {
e.Use(aroundFieldFunc(f))
}
// AroundRootFields is a convenience method for creating an extension that only implements root field middleware
func (e *Executor) AroundRootFields(f graphql.RootFieldMiddleware) {
e.Use(aroundRootFieldFunc(f))
}
// AroundOperations is a convenience method for creating an extension that only implements operation middleware
func (e *Executor) AroundOperations(f graphql.OperationMiddleware) {
e.Use(aroundOpFunc(f))
}
// AroundResponses is a convenience method for creating an extension that only implements response middleware
func (e *Executor) AroundResponses(f graphql.ResponseMiddleware) {
e.Use(aroundRespFunc(f))
}
type extensions struct {
operationMiddleware graphql.OperationMiddleware
responseMiddleware graphql.ResponseMiddleware
rootFieldMiddleware graphql.RootFieldMiddleware
fieldMiddleware graphql.FieldMiddleware
operationParameterMutators []graphql.OperationParameterMutator
operationContextMutators []graphql.OperationContextMutator
}
func processExtensions(exts []graphql.HandlerExtension) extensions {
e := extensions{
operationMiddleware: func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return next(ctx)
},
responseMiddleware: func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return next(ctx)
},
rootFieldMiddleware: func(ctx context.Context, next graphql.RootResolver) graphql.Marshaler {
return next(ctx)
},
fieldMiddleware: func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return next(ctx)
},
}
// this loop goes backwards so the first extension is the outer most middleware and runs first.
for i := len(exts) - 1; i >= 0; i-- {
p := exts[i]
if p, ok := p.(graphql.OperationInterceptor); ok {
previous := e.operationMiddleware
e.operationMiddleware = func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return p.InterceptOperation(ctx, func(ctx context.Context) graphql.ResponseHandler {
return previous(ctx, next)
})
}
}
if p, ok := p.(graphql.ResponseInterceptor); ok {
previous := e.responseMiddleware
e.responseMiddleware = func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return p.InterceptResponse(ctx, func(ctx context.Context) *graphql.Response {
return previous(ctx, next)
})
}
}
if p, ok := p.(graphql.RootFieldInterceptor); ok {
previous := e.rootFieldMiddleware
e.rootFieldMiddleware = func(ctx context.Context, next graphql.RootResolver) graphql.Marshaler {
return p.InterceptRootField(ctx, func(ctx context.Context) graphql.Marshaler {
return previous(ctx, next)
})
}
}
if p, ok := p.(graphql.FieldInterceptor); ok {
previous := e.fieldMiddleware
e.fieldMiddleware = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return p.InterceptField(ctx, func(ctx context.Context) (res interface{}, err error) {
return previous(ctx, next)
})
}
}
}
for _, p := range exts {
if p, ok := p.(graphql.OperationParameterMutator); ok {
e.operationParameterMutators = append(e.operationParameterMutators, p)
}
if p, ok := p.(graphql.OperationContextMutator); ok {
e.operationContextMutators = append(e.operationContextMutators, p)
}
}
return e
}
type aroundOpFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler
func (r aroundOpFunc) ExtensionName() string {
return "InlineOperationFunc"
}
func (r aroundOpFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("OperationFunc can not be nil")
}
return nil
}
func (r aroundOpFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return r(ctx, next)
}
type aroundRespFunc func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response
func (r aroundRespFunc) ExtensionName() string {
return "InlineResponseFunc"
}
func (r aroundRespFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("ResponseFunc can not be nil")
}
return nil
}
func (r aroundRespFunc) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return r(ctx, next)
}
type aroundFieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)
func (f aroundFieldFunc) ExtensionName() string {
return "InlineFieldFunc"
}
func (f aroundFieldFunc) Validate(schema graphql.ExecutableSchema) error {
if f == nil {
return fmt.Errorf("FieldFunc can not be nil")
}
return nil
}
func (f aroundFieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return f(ctx, next)
}
type aroundRootFieldFunc func(ctx context.Context, next graphql.RootResolver) graphql.Marshaler
func (f aroundRootFieldFunc) ExtensionName() string {
return "InlineRootFieldFunc"
}
func (f aroundRootFieldFunc) Validate(schema graphql.ExecutableSchema) error {
if f == nil {
return fmt.Errorf("RootFieldFunc can not be nil")
}
return nil
}
func (f aroundRootFieldFunc) InterceptRootField(ctx context.Context, next graphql.RootResolver) graphql.Marshaler {
return f(ctx, next)
}

View File

@@ -1,63 +0,0 @@
package graphql
import (
"io"
"sync"
)
type FieldSet struct {
fields []CollectedField
Values []Marshaler
delayed []delayedResult
}
type delayedResult struct {
i int
f func() Marshaler
}
func NewFieldSet(fields []CollectedField) *FieldSet {
return &FieldSet{
fields: fields,
Values: make([]Marshaler, len(fields)),
}
}
func (m *FieldSet) Concurrently(i int, f func() Marshaler) {
m.delayed = append(m.delayed, delayedResult{i: i, f: f})
}
func (m *FieldSet) Dispatch() {
if len(m.delayed) == 1 {
// only one concurrent task, no need to spawn a goroutine or deal create waitgroups
d := m.delayed[0]
m.Values[d.i] = d.f()
} else if len(m.delayed) > 1 {
// more than one concurrent task, use the main goroutine to do one, only spawn goroutines for the others
var wg sync.WaitGroup
for _, d := range m.delayed[1:] {
wg.Add(1)
go func(d delayedResult) {
m.Values[d.i] = d.f()
wg.Done()
}(d)
}
m.Values[m.delayed[0].i] = m.delayed[0].f()
wg.Wait()
}
}
func (m *FieldSet) MarshalGQL(writer io.Writer) {
writer.Write(openBrace)
for i, field := range m.fields {
if i != 0 {
writer.Write(comma)
}
writeQuotedString(writer, field.Alias)
writer.Write(colon)
m.Values[i].MarshalGQL(writer)
}
writer.Write(closeBrace)
}

View File

@@ -1,47 +0,0 @@
package graphql
import (
"context"
"encoding/json"
"fmt"
"io"
"math"
"strconv"
)
func MarshalFloat(f float64) Marshaler {
return WriterFunc(func(w io.Writer) {
io.WriteString(w, fmt.Sprintf("%g", f))
})
}
func UnmarshalFloat(v interface{}) (float64, error) {
switch v := v.(type) {
case string:
return strconv.ParseFloat(v, 64)
case int:
return float64(v), nil
case int64:
return float64(v), nil
case float64:
return v, nil
case json.Number:
return strconv.ParseFloat(string(v), 64)
default:
return 0, fmt.Errorf("%T is not an float", v)
}
}
func MarshalFloatContext(f float64) ContextMarshaler {
return ContextWriterFunc(func(ctx context.Context, w io.Writer) error {
if math.IsInf(f, 0) || math.IsNaN(f) {
return fmt.Errorf("cannot marshal infinite no NaN float values")
}
io.WriteString(w, fmt.Sprintf("%g", f))
return nil
})
}
func UnmarshalFloatContext(ctx context.Context, v interface{}) (float64, error) {
return UnmarshalFloat(v)
}

View File

@@ -1,130 +0,0 @@
package graphql
import (
"context"
"net/http"
"strconv"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type (
OperationMiddleware func(ctx context.Context, next OperationHandler) ResponseHandler
OperationHandler func(ctx context.Context) ResponseHandler
ResponseHandler func(ctx context.Context) *Response
ResponseMiddleware func(ctx context.Context, next ResponseHandler) *Response
Resolver func(ctx context.Context) (res interface{}, err error)
FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
RootResolver func(ctx context.Context) Marshaler
RootFieldMiddleware func(ctx context.Context, next RootResolver) Marshaler
RawParams struct {
Query string `json:"query"`
OperationName string `json:"operationName"`
Variables map[string]interface{} `json:"variables"`
Extensions map[string]interface{} `json:"extensions"`
ReadTime TraceTiming `json:"-"`
}
GraphExecutor interface {
CreateOperationContext(ctx context.Context, params *RawParams) (*OperationContext, gqlerror.List)
DispatchOperation(ctx context.Context, rc *OperationContext) (ResponseHandler, context.Context)
DispatchError(ctx context.Context, list gqlerror.List) *Response
}
// HandlerExtension adds functionality to the http handler. See the list of possible hook points below
// Its important to understand the lifecycle of a graphql request and the terminology we use in gqlgen
// before working with these
//
// +--- REQUEST POST /graphql --------------------------------------------+
// | +- OPERATION query OpName { viewer { name } } -----------------------+ |
// | | RESPONSE { "data": { "viewer": { "name": "bob" } } } | |
// | +- OPERATION subscription OpName2 { chat { message } } --------------+ |
// | | RESPONSE { "data": { "chat": { "message": "hello" } } } | |
// | | RESPONSE { "data": { "chat": { "message": "byee" } } } | |
// | +--------------------------------------------------------------------+ |
// +------------------------------------------------------------------------+
HandlerExtension interface {
// ExtensionName should be a CamelCase string version of the extension which may be shown in stats and logging.
ExtensionName() string
// Validate is called when adding an extension to the server, it allows validation against the servers schema.
Validate(schema ExecutableSchema) error
}
// OperationParameterMutator is called before creating a request context. allows manipulating the raw query
// on the way in.
OperationParameterMutator interface {
MutateOperationParameters(ctx context.Context, request *RawParams) *gqlerror.Error
}
// OperationContextMutator is called after creating the request context, but before executing the root resolver.
OperationContextMutator interface {
MutateOperationContext(ctx context.Context, rc *OperationContext) *gqlerror.Error
}
// OperationInterceptor is called for each incoming query, for basic requests the writer will be invoked once,
// for subscriptions it will be invoked multiple times.
OperationInterceptor interface {
InterceptOperation(ctx context.Context, next OperationHandler) ResponseHandler
}
// ResponseInterceptor is called around each graphql operation response. This can be called many times for a single
// operation the case of subscriptions.
ResponseInterceptor interface {
InterceptResponse(ctx context.Context, next ResponseHandler) *Response
}
RootFieldInterceptor interface {
InterceptRootField(ctx context.Context, next RootResolver) Marshaler
}
// FieldInterceptor called around each field
FieldInterceptor interface {
InterceptField(ctx context.Context, next Resolver) (res interface{}, err error)
}
// Transport provides support for different wire level encodings of graphql requests, eg Form, Get, Post, Websocket
Transport interface {
Supports(r *http.Request) bool
Do(w http.ResponseWriter, r *http.Request, exec GraphExecutor)
}
)
type Status int
func (p *RawParams) AddUpload(upload Upload, key, path string) *gqlerror.Error {
if !strings.HasPrefix(path, "variables.") {
return gqlerror.Errorf("invalid operations paths for key %s", key)
}
var ptr interface{} = p.Variables
parts := strings.Split(path, ".")
// skip the first part (variables) because we started there
for i, p := range parts[1:] {
last := i == len(parts)-2
if ptr == nil {
return gqlerror.Errorf("path is missing \"variables.\" prefix, key: %s, path: %s", key, path)
}
if index, parseNbrErr := strconv.Atoi(p); parseNbrErr == nil {
if last {
ptr.([]interface{})[index] = upload
} else {
ptr = ptr.([]interface{})[index]
}
} else {
if last {
ptr.(map[string]interface{})[p] = upload
} else {
ptr = ptr.(map[string]interface{})[p]
}
}
}
return nil
}

View File

@@ -1,114 +0,0 @@
package extension
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
"github.com/mitchellh/mapstructure"
)
const (
errPersistedQueryNotFound = "PersistedQueryNotFound"
errPersistedQueryNotFoundCode = "PERSISTED_QUERY_NOT_FOUND"
)
// AutomaticPersistedQuery saves client upload by optimistically sending only the hashes of queries, if the server
// does not yet know what the query is for the hash it will respond telling the client to send the query along with the
// hash in the next request.
// see https://github.com/apollographql/apollo-link-persisted-queries
type AutomaticPersistedQuery struct {
Cache graphql.Cache
}
type ApqStats struct {
// The hash of the incoming query
Hash string
// SentQuery is true if the incoming request sent the full query
SentQuery bool
}
const apqExtension = "APQ"
var _ interface {
graphql.OperationParameterMutator
graphql.HandlerExtension
} = AutomaticPersistedQuery{}
func (a AutomaticPersistedQuery) ExtensionName() string {
return "AutomaticPersistedQuery"
}
func (a AutomaticPersistedQuery) Validate(schema graphql.ExecutableSchema) error {
if a.Cache == nil {
return fmt.Errorf("AutomaticPersistedQuery.Cache can not be nil")
}
return nil
}
func (a AutomaticPersistedQuery) MutateOperationParameters(ctx context.Context, rawParams *graphql.RawParams) *gqlerror.Error {
if rawParams.Extensions["persistedQuery"] == nil {
return nil
}
var extension struct {
Sha256 string `mapstructure:"sha256Hash"`
Version int64 `mapstructure:"version"`
}
if err := mapstructure.Decode(rawParams.Extensions["persistedQuery"], &extension); err != nil {
return gqlerror.Errorf("invalid APQ extension data")
}
if extension.Version != 1 {
return gqlerror.Errorf("unsupported APQ version")
}
fullQuery := false
if rawParams.Query == "" {
// client sent optimistic query hash without query string, get it from the cache
query, ok := a.Cache.Get(ctx, extension.Sha256)
if !ok {
err := gqlerror.Errorf(errPersistedQueryNotFound)
errcode.Set(err, errPersistedQueryNotFoundCode)
return err
}
rawParams.Query = query.(string)
} else {
// client sent optimistic query hash with query string, verify and store it
if computeQueryHash(rawParams.Query) != extension.Sha256 {
return gqlerror.Errorf("provided APQ hash does not match query")
}
a.Cache.Add(ctx, extension.Sha256, rawParams.Query)
fullQuery = true
}
graphql.GetOperationContext(ctx).Stats.SetExtension(apqExtension, &ApqStats{
Hash: extension.Sha256,
SentQuery: fullQuery,
})
return nil
}
func GetApqStats(ctx context.Context) *ApqStats {
rc := graphql.GetOperationContext(ctx)
if rc == nil {
return nil
}
s, _ := rc.Stats.GetExtension(apqExtension).(*ApqStats)
return s
}
func computeQueryHash(query string) string {
b := sha256.Sum256([]byte(query))
return hex.EncodeToString(b[:])
}

View File

@@ -1,88 +0,0 @@
package extension
import (
"context"
"fmt"
"github.com/99designs/gqlgen/complexity"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/gqlerror"
)
const errComplexityLimit = "COMPLEXITY_LIMIT_EXCEEDED"
// ComplexityLimit allows you to define a limit on query complexity
//
// If a query is submitted that exceeds the limit, a 422 status code will be returned.
type ComplexityLimit struct {
Func func(ctx context.Context, rc *graphql.OperationContext) int
es graphql.ExecutableSchema
}
var _ interface {
graphql.OperationContextMutator
graphql.HandlerExtension
} = &ComplexityLimit{}
const complexityExtension = "ComplexityLimit"
type ComplexityStats struct {
// The calculated complexity for this request
Complexity int
// The complexity limit for this request returned by the extension func
ComplexityLimit int
}
// FixedComplexityLimit sets a complexity limit that does not change
func FixedComplexityLimit(limit int) *ComplexityLimit {
return &ComplexityLimit{
Func: func(ctx context.Context, rc *graphql.OperationContext) int {
return limit
},
}
}
func (c ComplexityLimit) ExtensionName() string {
return complexityExtension
}
func (c *ComplexityLimit) Validate(schema graphql.ExecutableSchema) error {
if c.Func == nil {
return fmt.Errorf("ComplexityLimit func can not be nil")
}
c.es = schema
return nil
}
func (c ComplexityLimit) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
op := rc.Doc.Operations.ForName(rc.OperationName)
complexity := complexity.Calculate(c.es, op, rc.Variables)
limit := c.Func(ctx, rc)
rc.Stats.SetExtension(complexityExtension, &ComplexityStats{
Complexity: complexity,
ComplexityLimit: limit,
})
if complexity > limit {
err := gqlerror.Errorf("operation has complexity %d, which exceeds the limit of %d", complexity, limit)
errcode.Set(err, errComplexityLimit)
return err
}
return nil
}
func GetComplexityStats(ctx context.Context) *ComplexityStats {
rc := graphql.GetOperationContext(ctx)
if rc == nil {
return nil
}
s, _ := rc.Stats.GetExtension(complexityExtension).(*ComplexityStats)
return s
}

View File

@@ -1,29 +0,0 @@
package extension
import (
"context"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// EnableIntrospection enables clients to reflect all of the types available on the graph.
type Introspection struct{}
var _ interface {
graphql.OperationContextMutator
graphql.HandlerExtension
} = Introspection{}
func (c Introspection) ExtensionName() string {
return "Introspection"
}
func (c Introspection) Validate(schema graphql.ExecutableSchema) error {
return nil
}
func (c Introspection) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
rc.DisableIntrospection = false
return nil
}

View File

@@ -1,32 +0,0 @@
package lru
import (
"context"
"github.com/99designs/gqlgen/graphql"
lru "github.com/hashicorp/golang-lru"
)
type LRU struct {
lru *lru.Cache
}
var _ graphql.Cache = &LRU{}
func New(size int) *LRU {
cache, err := lru.New(size)
if err != nil {
// An error is only returned for non-positive cache size
// and we already checked for that.
panic("unexpected error creating cache: " + err.Error())
}
return &LRU{cache}
}
func (l LRU) Get(ctx context.Context, key string) (value interface{}, ok bool) {
return l.lru.Get(key)
}
func (l LRU) Add(ctx context.Context, key string, value interface{}) {
l.lru.Add(key, value)
}

View File

@@ -1,185 +0,0 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/executor"
"github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/99designs/gqlgen/graphql/handler/lru"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type (
Server struct {
transports []graphql.Transport
exec *executor.Executor
}
)
func New(es graphql.ExecutableSchema) *Server {
return &Server{
exec: executor.New(es),
}
}
func NewDefaultServer(es graphql.ExecutableSchema) *Server {
srv := New(es)
srv.AddTransport(transport.Websocket{
KeepAlivePingInterval: 10 * time.Second,
})
srv.AddTransport(transport.Options{})
srv.AddTransport(transport.GET{})
srv.AddTransport(transport.POST{})
srv.AddTransport(transport.MultipartForm{})
srv.SetQueryCache(lru.New(1000))
srv.Use(extension.Introspection{})
srv.Use(extension.AutomaticPersistedQuery{
Cache: lru.New(100),
})
return srv
}
func (s *Server) AddTransport(transport graphql.Transport) {
s.transports = append(s.transports, transport)
}
func (s *Server) SetErrorPresenter(f graphql.ErrorPresenterFunc) {
s.exec.SetErrorPresenter(f)
}
func (s *Server) SetRecoverFunc(f graphql.RecoverFunc) {
s.exec.SetRecoverFunc(f)
}
func (s *Server) SetQueryCache(cache graphql.Cache) {
s.exec.SetQueryCache(cache)
}
func (s *Server) Use(extension graphql.HandlerExtension) {
s.exec.Use(extension)
}
// AroundFields is a convenience method for creating an extension that only implements field middleware
func (s *Server) AroundFields(f graphql.FieldMiddleware) {
s.exec.AroundFields(f)
}
// AroundRootFields is a convenience method for creating an extension that only implements field middleware
func (s *Server) AroundRootFields(f graphql.RootFieldMiddleware) {
s.exec.AroundRootFields(f)
}
// AroundOperations is a convenience method for creating an extension that only implements operation middleware
func (s *Server) AroundOperations(f graphql.OperationMiddleware) {
s.exec.AroundOperations(f)
}
// AroundResponses is a convenience method for creating an extension that only implements response middleware
func (s *Server) AroundResponses(f graphql.ResponseMiddleware) {
s.exec.AroundResponses(f)
}
func (s *Server) getTransport(r *http.Request) graphql.Transport {
for _, t := range s.transports {
if t.Supports(r) {
return t
}
}
return nil
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
err := s.exec.PresentRecoveredError(r.Context(), err)
resp := &graphql.Response{Errors: []*gqlerror.Error{err}}
b, _ := json.Marshal(resp)
w.WriteHeader(http.StatusUnprocessableEntity)
w.Write(b)
}
}()
r = r.WithContext(graphql.StartOperationTrace(r.Context()))
transport := s.getTransport(r)
if transport == nil {
sendErrorf(w, http.StatusBadRequest, "transport not supported")
return
}
transport.Do(w, r, s.exec)
}
func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
w.WriteHeader(code)
b, err := json.Marshal(&graphql.Response{Errors: errors})
if err != nil {
panic(err)
}
w.Write(b)
}
func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
}
type OperationFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler
func (r OperationFunc) ExtensionName() string {
return "InlineOperationFunc"
}
func (r OperationFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("OperationFunc can not be nil")
}
return nil
}
func (r OperationFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return r(ctx, next)
}
type ResponseFunc func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response
func (r ResponseFunc) ExtensionName() string {
return "InlineResponseFunc"
}
func (r ResponseFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("ResponseFunc can not be nil")
}
return nil
}
func (r ResponseFunc) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return r(ctx, next)
}
type FieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)
func (f FieldFunc) ExtensionName() string {
return "InlineFieldFunc"
}
func (f FieldFunc) Validate(schema graphql.ExecutableSchema) error {
if f == nil {
return fmt.Errorf("FieldFunc can not be nil")
}
return nil
}
func (f FieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return f(ctx, next)
}

View File

@@ -1,26 +0,0 @@
package transport
import (
"encoding/json"
"fmt"
"net/http"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// SendError sends a best effort error to a raw response writer. It assumes the client can understand the standard
// json error response
func SendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
w.WriteHeader(code)
b, err := json.Marshal(&graphql.Response{Errors: errors})
if err != nil {
panic(err)
}
w.Write(b)
}
// SendErrorf wraps SendError to add formatted messages
func SendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
SendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
}

View File

@@ -1,208 +0,0 @@
package transport
import (
"encoding/json"
"io"
"io/ioutil"
"mime"
"net/http"
"os"
"strings"
"github.com/99designs/gqlgen/graphql"
)
// MultipartForm the Multipart request spec https://github.com/jaydenseric/graphql-multipart-request-spec
type MultipartForm struct {
// MaxUploadSize sets the maximum number of bytes used to parse a request body
// as multipart/form-data.
MaxUploadSize int64
// MaxMemory defines the maximum number of bytes used to parse a request body
// as multipart/form-data in memory, with the remainder stored on disk in
// temporary files.
MaxMemory int64
}
var _ graphql.Transport = MultipartForm{}
func (f MultipartForm) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == "POST" && mediaType == "multipart/form-data"
}
func (f MultipartForm) maxUploadSize() int64 {
if f.MaxUploadSize == 0 {
return 32 << 20
}
return f.MaxUploadSize
}
func (f MultipartForm) maxMemory() int64 {
if f.MaxMemory == 0 {
return 32 << 20
}
return f.MaxMemory
}
func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
w.Header().Set("Content-Type", "application/json")
start := graphql.Now()
var err error
if r.ContentLength > f.maxUploadSize() {
writeJsonError(w, "failed to parse multipart form, request body too large")
return
}
r.Body = http.MaxBytesReader(w, r.Body, f.maxUploadSize())
if err = r.ParseMultipartForm(f.maxMemory()); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
if strings.Contains(err.Error(), "request body too large") {
writeJsonError(w, "failed to parse multipart form, request body too large")
return
}
writeJsonError(w, "failed to parse multipart form")
return
}
defer r.Body.Close()
var params graphql.RawParams
if err = jsonDecode(strings.NewReader(r.Form.Get("operations")), &params); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "operations form field could not be decoded")
return
}
uploadsMap := map[string][]string{}
if err = json.Unmarshal([]byte(r.Form.Get("map")), &uploadsMap); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "map form field could not be decoded")
return
}
var upload graphql.Upload
for key, paths := range uploadsMap {
if len(paths) == 0 {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "invalid empty operations paths list for key %s", key)
return
}
file, header, err := r.FormFile(key)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to get key %s from form", key)
return
}
defer file.Close()
if len(paths) == 1 {
upload = graphql.Upload{
File: file,
Size: header.Size,
Filename: header.Filename,
ContentType: header.Header.Get("Content-Type"),
}
if err := params.AddUpload(upload, key, paths[0]); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonGraphqlError(w, err)
return
}
} else {
if r.ContentLength < f.maxMemory() {
fileBytes, err := ioutil.ReadAll(file)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to read file for key %s", key)
return
}
for _, path := range paths {
upload = graphql.Upload{
File: &bytesReader{s: &fileBytes, i: 0, prevRune: -1},
Size: header.Size,
Filename: header.Filename,
ContentType: header.Header.Get("Content-Type"),
}
if err := params.AddUpload(upload, key, path); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonGraphqlError(w, err)
return
}
}
} else {
tmpFile, err := ioutil.TempFile(os.TempDir(), "gqlgen-")
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to create temp file for key %s", key)
return
}
tmpName := tmpFile.Name()
defer func() {
_ = os.Remove(tmpName)
}()
_, err = io.Copy(tmpFile, file)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
if err := tmpFile.Close(); err != nil {
writeJsonErrorf(w, "failed to copy to temp file and close temp file for key %s", key)
return
}
writeJsonErrorf(w, "failed to copy to temp file for key %s", key)
return
}
if err := tmpFile.Close(); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to close temp file for key %s", key)
return
}
for _, path := range paths {
pathTmpFile, err := os.Open(tmpName)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to open temp file for key %s", key)
return
}
defer pathTmpFile.Close()
upload = graphql.Upload{
File: pathTmpFile,
Size: header.Size,
Filename: header.Filename,
ContentType: header.Header.Get("Content-Type"),
}
if err := params.AddUpload(upload, key, path); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonGraphqlError(w, err)
return
}
}
}
}
}
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
rc, gerr := exec.CreateOperationContext(r.Context(), &params)
if gerr != nil {
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), gerr)
w.WriteHeader(statusFor(gerr))
writeJson(w, resp)
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
}

View File

@@ -1,87 +0,0 @@
package transport
import (
"encoding/json"
"io"
"net/http"
"strings"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// GET implements the GET side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#get
type GET struct{}
var _ graphql.Transport = GET{}
func (h GET) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
return r.Method == "GET"
}
func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
w.Header().Set("Content-Type", "application/json")
raw := &graphql.RawParams{
Query: r.URL.Query().Get("query"),
OperationName: r.URL.Query().Get("operationName"),
}
raw.ReadTime.Start = graphql.Now()
if variables := r.URL.Query().Get("variables"); variables != "" {
if err := jsonDecode(strings.NewReader(variables), &raw.Variables); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, "variables could not be decoded")
return
}
}
if extensions := r.URL.Query().Get("extensions"); extensions != "" {
if err := jsonDecode(strings.NewReader(extensions), &raw.Extensions); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, "extensions could not be decoded")
return
}
}
raw.ReadTime.End = graphql.Now()
rc, err := exec.CreateOperationContext(r.Context(), raw)
if err != nil {
w.WriteHeader(statusFor(err))
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), err)
writeJson(w, resp)
return
}
op := rc.Doc.Operations.ForName(rc.OperationName)
if op.Operation != ast.Query {
w.WriteHeader(http.StatusNotAcceptable)
writeJsonError(w, "GET requests only allow query operations")
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
}
func jsonDecode(r io.Reader, val interface{}) error {
dec := json.NewDecoder(r)
dec.UseNumber()
return dec.Decode(val)
}
func statusFor(errs gqlerror.List) int {
switch errcode.GetErrorKind(errs) {
case errcode.KindProtocol:
return http.StatusUnprocessableEntity
default:
return http.StatusOK
}
}

View File

@@ -1,53 +0,0 @@
package transport
import (
"mime"
"net/http"
"github.com/99designs/gqlgen/graphql"
)
// POST implements the POST side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#post
type POST struct{}
var _ graphql.Transport = POST{}
func (h POST) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == "POST" && mediaType == "application/json"
}
func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
w.Header().Set("Content-Type", "application/json")
var params *graphql.RawParams
start := graphql.Now()
if err := jsonDecode(r.Body, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonErrorf(w, "json body could not be decoded: "+err.Error())
return
}
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
rc, err := exec.CreateOperationContext(r.Context(), params)
if err != nil {
w.WriteHeader(statusFor(err))
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), err)
writeJson(w, resp)
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
}

View File

@@ -1,26 +0,0 @@
package transport
import (
"net/http"
"github.com/99designs/gqlgen/graphql"
)
// Options responds to http OPTIONS and HEAD requests
type Options struct{}
var _ graphql.Transport = Options{}
func (o Options) Supports(r *http.Request) bool {
return r.Method == "HEAD" || r.Method == "OPTIONS"
}
func (o Options) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
switch r.Method {
case http.MethodOptions:
w.Header().Set("Allow", "OPTIONS, GET, POST")
w.WriteHeader(http.StatusOK)
case http.MethodHead:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}

View File

@@ -1,25 +0,0 @@
package transport
import (
"errors"
"io"
)
type bytesReader struct {
s *[]byte
i int64 // current reading index
prevRune int // index of previous rune; or < 0
}
func (r *bytesReader) Read(b []byte) (n int, err error) {
if r.s == nil {
return 0, errors.New("byte slice pointer is nil")
}
if r.i >= int64(len(*r.s)) {
return 0, io.EOF
}
r.prevRune = -1
n = copy(b, (*r.s)[r.i:])
r.i += int64(n)
return
}

View File

@@ -1,30 +0,0 @@
package transport
import (
"encoding/json"
"fmt"
"io"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)
func writeJson(w io.Writer, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
w.Write(b)
}
func writeJsonError(w io.Writer, msg string) {
writeJson(w, &graphql.Response{Errors: gqlerror.List{{Message: msg}}})
}
func writeJsonErrorf(w io.Writer, format string, args ...interface{}) {
writeJson(w, &graphql.Response{Errors: gqlerror.List{{Message: fmt.Sprintf(format, args...)}}})
}
func writeJsonGraphqlError(w io.Writer, err ...*gqlerror.Error) {
writeJson(w, &graphql.Response{Errors: err})
}

View File

@@ -1,415 +0,0 @@
package transport
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/gorilla/websocket"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type (
Websocket struct {
Upgrader websocket.Upgrader
InitFunc WebsocketInitFunc
InitTimeout time.Duration
ErrorFunc WebsocketErrorFunc
KeepAlivePingInterval time.Duration
PingPongInterval time.Duration
didInjectSubprotocols bool
}
wsConnection struct {
Websocket
ctx context.Context
conn *websocket.Conn
me messageExchanger
active map[string]context.CancelFunc
mu sync.Mutex
keepAliveTicker *time.Ticker
pingPongTicker *time.Ticker
exec graphql.GraphExecutor
initPayload InitPayload
}
WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketErrorFunc func(ctx context.Context, err error)
)
var errReadTimeout = errors.New("read timeout")
var _ graphql.Transport = Websocket{}
func (t Websocket) Supports(r *http.Request) bool {
return r.Header.Get("Upgrade") != ""
}
func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
t.injectGraphQLWSSubprotocols()
ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
if err != nil {
log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
return
}
var me messageExchanger
switch ws.Subprotocol() {
default:
msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
ws.WriteMessage(websocket.CloseMessage, msg)
return
case graphqlwsSubprotocol, "":
// clients are required to send a subprotocol, to be backward compatible with the previous implementation we select
// "graphql-ws" by default
me = graphqlwsMessageExchanger{c: ws}
case graphqltransportwsSubprotocol:
me = graphqltransportwsMessageExchanger{c: ws}
}
conn := wsConnection{
active: map[string]context.CancelFunc{},
conn: ws,
ctx: r.Context(),
exec: exec,
me: me,
Websocket: t,
}
if !conn.init() {
return
}
conn.run()
}
func (c *wsConnection) handlePossibleError(err error) {
if c.ErrorFunc != nil && err != nil {
c.ErrorFunc(c.ctx, err)
}
}
func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) {
messages, errs := make(chan message, 1), make(chan error, 1)
go func() {
if m, err := c.me.NextMessage(); err != nil {
errs <- err
} else {
messages <- m
}
}()
select {
case m := <-messages:
return m, nil
case err := <-errs:
return message{}, err
case <-time.After(timeout):
return message{}, errReadTimeout
}
}
func (c *wsConnection) init() bool {
var m message
var err error
if c.InitTimeout != 0 {
m, err = c.nextMessageWithTimeout(c.InitTimeout)
} else {
m, err = c.me.NextMessage()
}
if err != nil {
if err == errReadTimeout {
c.close(websocket.CloseProtocolError, "connection initialisation timeout")
return false
}
if err == errInvalidMsg {
c.sendConnectionError("invalid json")
}
c.close(websocket.CloseProtocolError, "decoding error")
return false
}
switch m.t {
case initMessageType:
if len(m.payload) > 0 {
c.initPayload = make(InitPayload)
err := json.Unmarshal(m.payload, &c.initPayload)
if err != nil {
return false
}
}
if c.InitFunc != nil {
ctx, err := c.InitFunc(c.ctx, c.initPayload)
if err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
return false
}
c.ctx = ctx
}
c.write(&message{t: connectionAckMessageType})
c.write(&message{t: keepAliveMessageType})
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
return false
default:
c.sendConnectionError("unexpected message %s", m.t)
c.close(websocket.CloseProtocolError, "unexpected message")
return false
}
return true
}
func (c *wsConnection) write(msg *message) {
c.mu.Lock()
c.handlePossibleError(c.me.Send(msg))
c.mu.Unlock()
}
func (c *wsConnection) run() {
// We create a cancellation that will shutdown the keep-alive when we leave
// this function.
ctx, cancel := context.WithCancel(c.ctx)
defer func() {
cancel()
c.close(websocket.CloseAbnormalClosure, "unexpected closure")
}()
// If we're running in graphql-ws mode, create a timer that will trigger a
// keep alive message every interval
if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 {
c.mu.Lock()
c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
c.mu.Unlock()
go c.keepAlive(ctx)
}
// If we're running in graphql-transport-ws mode, create a timer that will
// trigger a ping message every interval
if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 {
c.mu.Lock()
c.pingPongTicker = time.NewTicker(c.PingPongInterval)
c.mu.Unlock()
// Note: when the connection is closed by this deadline, the client
// will receive an "invalid close code"
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
go c.ping(ctx)
}
// Close the connection when the context is cancelled.
// Will optionally send a "close reason" that is retrieved from the context.
go c.closeOnCancel(ctx)
for {
start := graphql.Now()
m, err := c.me.NextMessage()
if err != nil {
// If the connection got closed by us, don't report the error
if !errors.Is(err, net.ErrClosed) {
c.handlePossibleError(err)
}
return
}
switch m.t {
case startMessageType:
c.subscribe(start, &m)
case stopMessageType:
c.mu.Lock()
closer := c.active[m.id]
c.mu.Unlock()
if closer != nil {
closer()
}
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
return
case pingMessageType:
c.write(&message{t: pongMessageType, payload: m.payload})
case pongMessageType:
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
default:
c.sendConnectionError("unexpected message %s", m.t)
c.close(websocket.CloseProtocolError, "unexpected message")
return
}
}
}
func (c *wsConnection) keepAlive(ctx context.Context) {
for {
select {
case <-ctx.Done():
c.keepAliveTicker.Stop()
return
case <-c.keepAliveTicker.C:
c.write(&message{t: keepAliveMessageType})
}
}
}
func (c *wsConnection) ping(ctx context.Context) {
for {
select {
case <-ctx.Done():
c.pingPongTicker.Stop()
return
case <-c.pingPongTicker.C:
c.write(&message{t: pingMessageType, payload: json.RawMessage{}})
}
}
}
func (c *wsConnection) closeOnCancel(ctx context.Context) {
<-ctx.Done()
if r := closeReasonForContext(ctx); r != "" {
c.sendConnectionError(r)
}
c.close(websocket.CloseNormalClosure, "terminated")
}
func (c *wsConnection) subscribe(start time.Time, msg *message) {
ctx := graphql.StartOperationTrace(c.ctx)
var params *graphql.RawParams
if err := jsonDecode(bytes.NewReader(msg.payload), &params); err != nil {
c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
c.complete(msg.id)
return
}
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
rc, err := c.exec.CreateOperationContext(ctx, params)
if err != nil {
resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
switch errcode.GetErrorKind(err) {
case errcode.KindProtocol:
c.sendError(msg.id, resp.Errors...)
default:
c.sendResponse(msg.id, &graphql.Response{Errors: err})
}
c.complete(msg.id)
return
}
ctx = graphql.WithOperationContext(ctx, rc)
if c.initPayload != nil {
ctx = withInitPayload(ctx, c.initPayload)
}
ctx, cancel := context.WithCancel(ctx)
c.mu.Lock()
c.active[msg.id] = cancel
c.mu.Unlock()
go func() {
defer func() {
if r := recover(); r != nil {
err := rc.Recover(ctx, r)
var gqlerr *gqlerror.Error
if !errors.As(err, &gqlerr) {
gqlerr = &gqlerror.Error{}
if err != nil {
gqlerr.Message = err.Error()
}
}
c.sendError(msg.id, gqlerr)
}
c.complete(msg.id)
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
cancel()
}()
responses, ctx := c.exec.DispatchOperation(ctx, rc)
for {
response := responses(ctx)
if response == nil {
break
}
c.sendResponse(msg.id, response)
}
c.complete(msg.id)
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
cancel()
}()
}
func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
c.write(&message{
payload: b,
id: id,
t: dataMessageType,
})
}
func (c *wsConnection) complete(id string) {
c.write(&message{id: id, t: completeMessageType})
}
func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
errs := make([]error, len(errors))
for i, err := range errors {
errs[i] = err
}
b, err := json.Marshal(errs)
if err != nil {
panic(err)
}
c.write(&message{t: errorMessageType, id: id, payload: b})
}
func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
if err != nil {
panic(err)
}
c.write(&message{t: connectionErrorMessageType, payload: b})
}
func (c *wsConnection) close(closeCode int, message string) {
c.mu.Lock()
_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
for _, closer := range c.active {
closer()
}
c.mu.Unlock()
_ = c.conn.Close()
}

View File

@@ -1,22 +0,0 @@
package transport
import (
"context"
)
// A private key for context that only this package can access. This is important
// to prevent collisions between different context uses
var closeReasonCtxKey = &wsCloseReasonContextKey{"close-reason"}
type wsCloseReasonContextKey struct {
name string
}
func AppendCloseReason(ctx context.Context, reason string) context.Context {
return context.WithValue(ctx, closeReasonCtxKey, reason)
}
func closeReasonForContext(ctx context.Context) string {
reason, _ := ctx.Value(closeReasonCtxKey).(string)
return reason
}

View File

@@ -1,149 +0,0 @@
package transport
import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
)
// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
const (
graphqltransportwsSubprotocol = "graphql-transport-ws"
graphqltransportwsConnectionInitMsg = graphqltransportwsMessageType("connection_init")
graphqltransportwsConnectionAckMsg = graphqltransportwsMessageType("connection_ack")
graphqltransportwsSubscribeMsg = graphqltransportwsMessageType("subscribe")
graphqltransportwsNextMsg = graphqltransportwsMessageType("next")
graphqltransportwsErrorMsg = graphqltransportwsMessageType("error")
graphqltransportwsCompleteMsg = graphqltransportwsMessageType("complete")
graphqltransportwsPingMsg = graphqltransportwsMessageType("ping")
graphqltransportwsPongMsg = graphqltransportwsMessageType("pong")
)
var allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{
graphqltransportwsConnectionInitMsg,
graphqltransportwsConnectionAckMsg,
graphqltransportwsSubscribeMsg,
graphqltransportwsNextMsg,
graphqltransportwsErrorMsg,
graphqltransportwsCompleteMsg,
graphqltransportwsPingMsg,
graphqltransportwsPongMsg,
}
type (
graphqltransportwsMessageExchanger struct {
c *websocket.Conn
}
graphqltransportwsMessage struct {
Payload json.RawMessage `json:"payload,omitempty"`
ID string `json:"id,omitempty"`
Type graphqltransportwsMessageType `json:"type"`
noOp bool
}
graphqltransportwsMessageType string
)
func (me graphqltransportwsMessageExchanger) NextMessage() (message, error) {
_, r, err := me.c.NextReader()
if err != nil {
return message{}, handleNextReaderError(err)
}
var graphqltransportwsMessage graphqltransportwsMessage
if err := jsonDecode(r, &graphqltransportwsMessage); err != nil {
return message{}, errInvalidMsg
}
return graphqltransportwsMessage.toMessage()
}
func (me graphqltransportwsMessageExchanger) Send(m *message) error {
msg := &graphqltransportwsMessage{}
if err := msg.fromMessage(m); err != nil {
return err
}
if msg.noOp {
return nil
}
return me.c.WriteJSON(msg)
}
func (t *graphqltransportwsMessageType) UnmarshalText(text []byte) (err error) {
var found bool
for _, candidate := range allGraphqltransportwsMessageTypes {
if string(candidate) == string(text) {
*t = candidate
found = true
break
}
}
if !found {
err = fmt.Errorf("invalid message type %s", string(text))
}
return err
}
func (t graphqltransportwsMessageType) MarshalText() ([]byte, error) {
return []byte(string(t)), nil
}
func (m graphqltransportwsMessage) toMessage() (message, error) {
var t messageType
var err error
switch m.Type {
default:
err = fmt.Errorf("invalid client->server message type %s", m.Type)
case graphqltransportwsConnectionInitMsg:
t = initMessageType
case graphqltransportwsSubscribeMsg:
t = startMessageType
case graphqltransportwsCompleteMsg:
t = stopMessageType
case graphqltransportwsPingMsg:
t = pingMessageType
case graphqltransportwsPongMsg:
t = pongMessageType
}
return message{
payload: m.Payload,
id: m.ID,
t: t,
}, err
}
func (m *graphqltransportwsMessage) fromMessage(msg *message) (err error) {
m.ID = msg.id
m.Payload = msg.payload
switch msg.t {
default:
err = fmt.Errorf("invalid server->client message type %s", msg.t)
case connectionAckMessageType:
m.Type = graphqltransportwsConnectionAckMsg
case keepAliveMessageType:
m.noOp = true
case connectionErrorMessageType:
m.noOp = true
case dataMessageType:
m.Type = graphqltransportwsNextMsg
case completeMessageType:
m.Type = graphqltransportwsCompleteMsg
case errorMessageType:
m.Type = graphqltransportwsErrorMsg
case pingMessageType:
m.Type = graphqltransportwsPingMsg
case pongMessageType:
m.Type = graphqltransportwsPongMsg
}
return err
}

View File

@@ -1,171 +0,0 @@
package transport
import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
)
// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
const (
graphqlwsSubprotocol = "graphql-ws"
graphqlwsConnectionInitMsg = graphqlwsMessageType("connection_init")
graphqlwsConnectionTerminateMsg = graphqlwsMessageType("connection_terminate")
graphqlwsStartMsg = graphqlwsMessageType("start")
graphqlwsStopMsg = graphqlwsMessageType("stop")
graphqlwsConnectionAckMsg = graphqlwsMessageType("connection_ack")
graphqlwsConnectionErrorMsg = graphqlwsMessageType("connection_error")
graphqlwsDataMsg = graphqlwsMessageType("data")
graphqlwsErrorMsg = graphqlwsMessageType("error")
graphqlwsCompleteMsg = graphqlwsMessageType("complete")
graphqlwsConnectionKeepAliveMsg = graphqlwsMessageType("ka")
)
var allGraphqlwsMessageTypes = []graphqlwsMessageType{
graphqlwsConnectionInitMsg,
graphqlwsConnectionTerminateMsg,
graphqlwsStartMsg,
graphqlwsStopMsg,
graphqlwsConnectionAckMsg,
graphqlwsConnectionErrorMsg,
graphqlwsDataMsg,
graphqlwsErrorMsg,
graphqlwsCompleteMsg,
graphqlwsConnectionKeepAliveMsg,
}
type (
graphqlwsMessageExchanger struct {
c *websocket.Conn
}
graphqlwsMessage struct {
Payload json.RawMessage `json:"payload,omitempty"`
ID string `json:"id,omitempty"`
Type graphqlwsMessageType `json:"type"`
noOp bool
}
graphqlwsMessageType string
)
func (me graphqlwsMessageExchanger) NextMessage() (message, error) {
_, r, err := me.c.NextReader()
if err != nil {
return message{}, handleNextReaderError(err)
}
var graphqlwsMessage graphqlwsMessage
if err := jsonDecode(r, &graphqlwsMessage); err != nil {
return message{}, errInvalidMsg
}
return graphqlwsMessage.toMessage()
}
func (me graphqlwsMessageExchanger) Send(m *message) error {
msg := &graphqlwsMessage{}
if err := msg.fromMessage(m); err != nil {
return err
}
if msg.noOp {
return nil
}
return me.c.WriteJSON(msg)
}
func (t *graphqlwsMessageType) UnmarshalText(text []byte) (err error) {
var found bool
for _, candidate := range allGraphqlwsMessageTypes {
if string(candidate) == string(text) {
*t = candidate
found = true
break
}
}
if !found {
err = fmt.Errorf("invalid message type %s", string(text))
}
return err
}
func (t graphqlwsMessageType) MarshalText() ([]byte, error) {
return []byte(string(t)), nil
}
func (m graphqlwsMessage) toMessage() (message, error) {
var t messageType
var err error
switch m.Type {
default:
err = fmt.Errorf("invalid client->server message type %s", m.Type)
case graphqlwsConnectionInitMsg:
t = initMessageType
case graphqlwsConnectionTerminateMsg:
t = connectionCloseMessageType
case graphqlwsStartMsg:
t = startMessageType
case graphqlwsStopMsg:
t = stopMessageType
case graphqlwsConnectionAckMsg:
t = connectionAckMessageType
case graphqlwsConnectionErrorMsg:
t = connectionErrorMessageType
case graphqlwsDataMsg:
t = dataMessageType
case graphqlwsErrorMsg:
t = errorMessageType
case graphqlwsCompleteMsg:
t = completeMessageType
case graphqlwsConnectionKeepAliveMsg:
t = keepAliveMessageType
}
return message{
payload: m.Payload,
id: m.ID,
t: t,
}, err
}
func (m *graphqlwsMessage) fromMessage(msg *message) (err error) {
m.ID = msg.id
m.Payload = msg.payload
switch msg.t {
default:
err = fmt.Errorf("invalid server->client message type %s", msg.t)
case initMessageType:
m.Type = graphqlwsConnectionInitMsg
case connectionAckMessageType:
m.Type = graphqlwsConnectionAckMsg
case keepAliveMessageType:
m.Type = graphqlwsConnectionKeepAliveMsg
case connectionErrorMessageType:
m.Type = graphqlwsConnectionErrorMsg
case connectionCloseMessageType:
m.Type = graphqlwsConnectionTerminateMsg
case startMessageType:
m.Type = graphqlwsStartMsg
case stopMessageType:
m.Type = graphqlwsStopMsg
case dataMessageType:
m.Type = graphqlwsDataMsg
case completeMessageType:
m.Type = graphqlwsCompleteMsg
case errorMessageType:
m.Type = graphqlwsErrorMsg
case pingMessageType:
m.noOp = true
case pongMessageType:
m.noOp = true
}
return err
}

View File

@@ -1,57 +0,0 @@
package transport
import "context"
type key string
const (
initpayload key = "ws_initpayload_context"
)
// InitPayload is a structure that is parsed from the websocket init message payload. TO use
// request headers for non-websocket, instead wrap the graphql handler in a middleware.
type InitPayload map[string]interface{}
// GetString safely gets a string value from the payload. It returns an empty string if the
// payload is nil or the value isn't set.
func (p InitPayload) GetString(key string) string {
if p == nil {
return ""
}
if value, ok := p[key]; ok {
res, _ := value.(string)
return res
}
return ""
}
// Authorization is a short hand for getting the Authorization header from the
// payload.
func (p InitPayload) Authorization() string {
if value := p.GetString("Authorization"); value != "" {
return value
}
if value := p.GetString("authorization"); value != "" {
return value
}
return ""
}
func withInitPayload(ctx context.Context, payload InitPayload) context.Context {
return context.WithValue(ctx, initpayload, payload)
}
// GetInitPayload gets a map of the data sent with the connection_init message, which is used by
// graphql clients as a stand-in for HTTP headers.
func GetInitPayload(ctx context.Context) InitPayload {
payload, ok := ctx.Value(initpayload).(InitPayload)
if !ok {
return nil
}
return payload
}

View File

@@ -1,116 +0,0 @@
package transport
import (
"encoding/json"
"errors"
"github.com/gorilla/websocket"
)
const (
initMessageType messageType = iota
connectionAckMessageType
keepAliveMessageType
connectionErrorMessageType
connectionCloseMessageType
startMessageType
stopMessageType
dataMessageType
completeMessageType
errorMessageType
pingMessageType
pongMessageType
)
var (
supportedSubprotocols = []string{
graphqlwsSubprotocol,
graphqltransportwsSubprotocol,
}
errWsConnClosed = errors.New("websocket connection closed")
errInvalidMsg = errors.New("invalid message received")
)
type (
messageType int
message struct {
payload json.RawMessage
id string
t messageType
}
messageExchanger interface {
NextMessage() (message, error)
Send(m *message) error
}
)
func (t messageType) String() string {
var text string
switch t {
default:
text = "unknown"
case initMessageType:
text = "init"
case connectionAckMessageType:
text = "connection ack"
case keepAliveMessageType:
text = "keep alive"
case connectionErrorMessageType:
text = "connection error"
case connectionCloseMessageType:
text = "connection close"
case startMessageType:
text = "start"
case stopMessageType:
text = "stop subscription"
case dataMessageType:
text = "data"
case completeMessageType:
text = "complete"
case errorMessageType:
text = "error"
case pingMessageType:
text = "ping"
case pongMessageType:
text = "pong"
}
return text
}
func contains(list []string, elem string) bool {
for _, e := range list {
if e == elem {
return true
}
}
return false
}
func (t *Websocket) injectGraphQLWSSubprotocols() {
// the list of subprotocols is specified by the consumer of the Websocket struct,
// in order to preserve backward compatibility, we inject the graphql specific subprotocols
// at runtime
if !t.didInjectSubprotocols {
defer func() {
t.didInjectSubprotocols = true
}()
for _, subprotocol := range supportedSubprotocols {
if !contains(t.Upgrader.Subprotocols, subprotocol) {
t.Upgrader.Subprotocols = append(t.Upgrader.Subprotocols, subprotocol)
}
}
}
}
func handleNextReaderError(err error) error {
// TODO: should we consider all closure scenarios here for the ws connection?
// for now we only list the error codes from the previous implementation
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
return errWsConnClosed
}
return err
}

View File

@@ -1,58 +0,0 @@
package graphql
import (
"encoding/json"
"fmt"
"io"
"strconv"
)
func MarshalID(s string) Marshaler {
return MarshalString(s)
}
func UnmarshalID(v interface{}) (string, error) {
switch v := v.(type) {
case string:
return v, nil
case json.Number:
return string(v), nil
case int:
return strconv.Itoa(v), nil
case int64:
return strconv.FormatInt(v, 10), nil
case float64:
return fmt.Sprintf("%f", v), nil
case bool:
if v {
return "true", nil
} else {
return "false", nil
}
case nil:
return "null", nil
default:
return "", fmt.Errorf("%T is not a string", v)
}
}
func MarshalIntID(i int) Marshaler {
return WriterFunc(func(w io.Writer) {
writeQuotedString(w, strconv.Itoa(i))
})
}
func UnmarshalIntID(v interface{}) (int, error) {
switch v := v.(type) {
case string:
return strconv.Atoi(v)
case int:
return v, nil
case int64:
return int(v), nil
case json.Number:
return strconv.Atoi(string(v))
default:
return 0, fmt.Errorf("%T is not an int", v)
}
}

View File

@@ -1,79 +0,0 @@
package graphql
import (
"encoding/json"
"fmt"
"io"
"strconv"
)
func MarshalInt(i int) Marshaler {
return WriterFunc(func(w io.Writer) {
io.WriteString(w, strconv.Itoa(i))
})
}
func UnmarshalInt(v interface{}) (int, error) {
switch v := v.(type) {
case string:
return strconv.Atoi(v)
case int:
return v, nil
case int64:
return int(v), nil
case json.Number:
return strconv.Atoi(string(v))
default:
return 0, fmt.Errorf("%T is not an int", v)
}
}
func MarshalInt64(i int64) Marshaler {
return WriterFunc(func(w io.Writer) {
io.WriteString(w, strconv.FormatInt(i, 10))
})
}
func UnmarshalInt64(v interface{}) (int64, error) {
switch v := v.(type) {
case string:
return strconv.ParseInt(v, 10, 64)
case int:
return int64(v), nil
case int64:
return v, nil
case json.Number:
return strconv.ParseInt(string(v), 10, 64)
default:
return 0, fmt.Errorf("%T is not an int", v)
}
}
func MarshalInt32(i int32) Marshaler {
return WriterFunc(func(w io.Writer) {
io.WriteString(w, strconv.FormatInt(int64(i), 10))
})
}
func UnmarshalInt32(v interface{}) (int32, error) {
switch v := v.(type) {
case string:
iv, err := strconv.ParseInt(v, 10, 32)
if err != nil {
return 0, err
}
return int32(iv), nil
case int:
return int32(v), nil
case int64:
return int32(v), nil
case json.Number:
iv, err := strconv.ParseInt(string(v), 10, 32)
if err != nil {
return 0, err
}
return int32(iv), nil
default:
return 0, fmt.Errorf("%T is not an int", v)
}
}

View File

@@ -1,101 +0,0 @@
// introspection implements the spec defined in https://github.com/facebook/graphql/blob/master/spec/Section%204%20--%20Introspection.md#schema-introspection
package introspection
import "github.com/vektah/gqlparser/v2/ast"
type (
Directive struct {
Name string
description string
Locations []string
Args []InputValue
IsRepeatable bool
}
EnumValue struct {
Name string
description string
deprecation *ast.Directive
}
Field struct {
Name string
description string
Type *Type
Args []InputValue
deprecation *ast.Directive
}
InputValue struct {
Name string
description string
DefaultValue *string
Type *Type
}
)
func WrapSchema(schema *ast.Schema) *Schema {
return &Schema{schema: schema}
}
func (f *EnumValue) Description() *string {
if f.description == "" {
return nil
}
return &f.description
}
func (f *EnumValue) IsDeprecated() bool {
return f.deprecation != nil
}
func (f *EnumValue) DeprecationReason() *string {
if f.deprecation == nil {
return nil
}
reason := f.deprecation.Arguments.ForName("reason")
if reason == nil {
return nil
}
return &reason.Value.Raw
}
func (f *Field) Description() *string {
if f.description == "" {
return nil
}
return &f.description
}
func (f *Field) IsDeprecated() bool {
return f.deprecation != nil
}
func (f *Field) DeprecationReason() *string {
if f.deprecation == nil {
return nil
}
reason := f.deprecation.Arguments.ForName("reason")
if reason == nil {
return nil
}
return &reason.Value.Raw
}
func (f *InputValue) Description() *string {
if f.description == "" {
return nil
}
return &f.description
}
func (f *Directive) Description() *string {
if f.description == "" {
return nil
}
return &f.description
}

View File

@@ -1,106 +0,0 @@
package introspection
// Query is the query generated by graphiql to determine type information
const Query = `
query IntrospectionQuery {
__schema {
description
queryType {
name
}
mutationType {
name
}
subscriptionType {
name
}
types {
...FullType
}
directives {
name
description
locations
args {
...InputValue
}
}
}
}
fragment FullType on __Type {
kind
name
description
specifiedByURL
fields(includeDeprecated: true) {
name
description
args {
...InputValue
}
type {
...TypeRef
}
isDeprecated
deprecationReason
}
inputFields {
...InputValue
}
interfaces {
...TypeRef
}
enumValues(includeDeprecated: true) {
name
description
isDeprecated
deprecationReason
}
possibleTypes {
...TypeRef
}
}
fragment InputValue on __InputValue {
name
description
type {
...TypeRef
}
defaultValue
}
fragment TypeRef on __Type {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
ofType {
kind
name
}
}
}
}
}
}
}
}
`

View File

@@ -1,93 +0,0 @@
package introspection
import (
"sort"
"strings"
"github.com/vektah/gqlparser/v2/ast"
)
type Schema struct {
schema *ast.Schema
}
func (s *Schema) Description() *string {
if s.schema.Description == "" {
return nil
}
return &s.schema.Description
}
func (s *Schema) Types() []Type {
typeIndex := map[string]Type{}
typeNames := make([]string, 0, len(s.schema.Types))
for _, typ := range s.schema.Types {
if strings.HasPrefix(typ.Name, "__") {
continue
}
typeNames = append(typeNames, typ.Name)
typeIndex[typ.Name] = *WrapTypeFromDef(s.schema, typ)
}
sort.Strings(typeNames)
types := make([]Type, len(typeNames))
for i, t := range typeNames {
types[i] = typeIndex[t]
}
return types
}
func (s *Schema) QueryType() *Type {
return WrapTypeFromDef(s.schema, s.schema.Query)
}
func (s *Schema) MutationType() *Type {
return WrapTypeFromDef(s.schema, s.schema.Mutation)
}
func (s *Schema) SubscriptionType() *Type {
return WrapTypeFromDef(s.schema, s.schema.Subscription)
}
func (s *Schema) Directives() []Directive {
dIndex := map[string]Directive{}
dNames := make([]string, 0, len(s.schema.Directives))
for _, d := range s.schema.Directives {
dNames = append(dNames, d.Name)
dIndex[d.Name] = s.directiveFromDef(d)
}
sort.Strings(dNames)
res := make([]Directive, len(dNames))
for i, d := range dNames {
res[i] = dIndex[d]
}
return res
}
func (s *Schema) directiveFromDef(d *ast.DirectiveDefinition) Directive {
locs := make([]string, len(d.Locations))
for i, loc := range d.Locations {
locs[i] = string(loc)
}
args := make([]InputValue, len(d.Arguments))
for i, arg := range d.Arguments {
args[i] = InputValue{
Name: arg.Name,
description: arg.Description,
DefaultValue: defaultValue(arg.DefaultValue),
Type: WrapTypeFromType(s.schema, arg.Type),
}
}
return Directive{
Name: d.Name,
description: d.Description,
Locations: locs,
Args: args,
IsRepeatable: d.IsRepeatable,
}
}

View File

@@ -1,191 +0,0 @@
package introspection
import (
"strings"
"github.com/vektah/gqlparser/v2/ast"
)
type Type struct {
schema *ast.Schema
def *ast.Definition
typ *ast.Type
}
func WrapTypeFromDef(s *ast.Schema, def *ast.Definition) *Type {
if def == nil {
return nil
}
return &Type{schema: s, def: def}
}
func WrapTypeFromType(s *ast.Schema, typ *ast.Type) *Type {
if typ == nil {
return nil
}
if !typ.NonNull && typ.NamedType != "" {
return &Type{schema: s, def: s.Types[typ.NamedType]}
}
return &Type{schema: s, typ: typ}
}
func (t *Type) Kind() string {
if t.typ != nil {
if t.typ.NonNull {
return "NON_NULL"
}
if t.typ.Elem != nil {
return "LIST"
}
} else {
return string(t.def.Kind)
}
panic("UNKNOWN")
}
func (t *Type) Name() *string {
if t.def == nil {
return nil
}
return &t.def.Name
}
func (t *Type) Description() *string {
if t.def == nil || t.def.Description == "" {
return nil
}
return &t.def.Description
}
func (t *Type) Fields(includeDeprecated bool) []Field {
if t.def == nil || (t.def.Kind != ast.Object && t.def.Kind != ast.Interface) {
return []Field{}
}
fields := []Field{}
for _, f := range t.def.Fields {
if strings.HasPrefix(f.Name, "__") {
continue
}
if !includeDeprecated && f.Directives.ForName("deprecated") != nil {
continue
}
var args []InputValue
for _, arg := range f.Arguments {
args = append(args, InputValue{
Type: WrapTypeFromType(t.schema, arg.Type),
Name: arg.Name,
description: arg.Description,
DefaultValue: defaultValue(arg.DefaultValue),
})
}
fields = append(fields, Field{
Name: f.Name,
description: f.Description,
Args: args,
Type: WrapTypeFromType(t.schema, f.Type),
deprecation: f.Directives.ForName("deprecated"),
})
}
return fields
}
func (t *Type) InputFields() []InputValue {
if t.def == nil || t.def.Kind != ast.InputObject {
return []InputValue{}
}
res := []InputValue{}
for _, f := range t.def.Fields {
res = append(res, InputValue{
Name: f.Name,
description: f.Description,
Type: WrapTypeFromType(t.schema, f.Type),
DefaultValue: defaultValue(f.DefaultValue),
})
}
return res
}
func defaultValue(value *ast.Value) *string {
if value == nil {
return nil
}
val := value.String()
return &val
}
func (t *Type) Interfaces() []Type {
if t.def == nil || t.def.Kind != ast.Object {
return []Type{}
}
res := []Type{}
for _, intf := range t.def.Interfaces {
res = append(res, *WrapTypeFromDef(t.schema, t.schema.Types[intf]))
}
return res
}
func (t *Type) PossibleTypes() []Type {
if t.def == nil || (t.def.Kind != ast.Interface && t.def.Kind != ast.Union) {
return []Type{}
}
res := []Type{}
for _, pt := range t.schema.GetPossibleTypes(t.def) {
res = append(res, *WrapTypeFromDef(t.schema, pt))
}
return res
}
func (t *Type) EnumValues(includeDeprecated bool) []EnumValue {
if t.def == nil || t.def.Kind != ast.Enum {
return []EnumValue{}
}
res := []EnumValue{}
for _, val := range t.def.EnumValues {
if !includeDeprecated && val.Directives.ForName("deprecated") != nil {
continue
}
res = append(res, EnumValue{
Name: val.Name,
description: val.Description,
deprecation: val.Directives.ForName("deprecated"),
})
}
return res
}
func (t *Type) OfType() *Type {
if t.typ == nil {
return nil
}
if t.typ.NonNull {
// fake non null nodes
cpy := *t.typ
cpy.NonNull = false
return WrapTypeFromType(t.schema, &cpy)
}
return WrapTypeFromType(t.schema, t.typ.Elem)
}
func (t *Type) SpecifiedByURL() *string {
directive := t.def.Directives.ForName("specifiedBy")
if t.def.Kind != ast.Scalar || directive == nil {
return nil
}
// def: directive @specifiedBy(url: String!) on SCALAR
// the argument "url" is required.
url := directive.Arguments.ForName("url")
return &url.Value.Raw
}

View File

@@ -1,93 +0,0 @@
package graphql
import (
"context"
"io"
)
var (
nullLit = []byte(`null`)
trueLit = []byte(`true`)
falseLit = []byte(`false`)
openBrace = []byte(`{`)
closeBrace = []byte(`}`)
openBracket = []byte(`[`)
closeBracket = []byte(`]`)
colon = []byte(`:`)
comma = []byte(`,`)
)
var (
Null = &lit{nullLit}
True = &lit{trueLit}
False = &lit{falseLit}
)
type Marshaler interface {
MarshalGQL(w io.Writer)
}
type Unmarshaler interface {
UnmarshalGQL(v interface{}) error
}
type ContextMarshaler interface {
MarshalGQLContext(ctx context.Context, w io.Writer) error
}
type ContextUnmarshaler interface {
UnmarshalGQLContext(ctx context.Context, v interface{}) error
}
type contextMarshalerAdapter struct {
Context context.Context
ContextMarshaler
}
func WrapContextMarshaler(ctx context.Context, m ContextMarshaler) Marshaler {
return contextMarshalerAdapter{Context: ctx, ContextMarshaler: m}
}
func (a contextMarshalerAdapter) MarshalGQL(w io.Writer) {
err := a.MarshalGQLContext(a.Context, w)
if err != nil {
AddError(a.Context, err)
Null.MarshalGQL(w)
}
}
type WriterFunc func(writer io.Writer)
func (f WriterFunc) MarshalGQL(w io.Writer) {
f(w)
}
type ContextWriterFunc func(ctx context.Context, writer io.Writer) error
func (f ContextWriterFunc) MarshalGQLContext(ctx context.Context, w io.Writer) error {
return f(ctx, w)
}
type Array []Marshaler
func (a Array) MarshalGQL(writer io.Writer) {
writer.Write(openBracket)
for i, val := range a {
if i != 0 {
writer.Write(comma)
}
val.MarshalGQL(writer)
}
writer.Write(closeBracket)
}
type lit struct{ b []byte }
func (l lit) MarshalGQL(w io.Writer) {
w.Write(l.b)
}
func (l lit) MarshalGQLContext(ctx context.Context, w io.Writer) error {
w.Write(l.b)
return nil
}

View File

@@ -1,24 +0,0 @@
package graphql
import (
"encoding/json"
"fmt"
"io"
)
func MarshalMap(val map[string]interface{}) Marshaler {
return WriterFunc(func(w io.Writer) {
err := json.NewEncoder(w).Encode(val)
if err != nil {
panic(err)
}
})
}
func UnmarshalMap(v interface{}) (map[string]interface{}, error) {
if m, ok := v.(map[string]interface{}); ok {
return m, nil
}
return nil, fmt.Errorf("%T is not a map", v)
}

View File

@@ -1,16 +0,0 @@
package graphql
import "context"
func OneShot(resp *Response) ResponseHandler {
var oneshot bool
return func(context context.Context) *Response {
if oneshot {
return nil
}
oneshot = true
return resp
}
}

View File

@@ -1,73 +0,0 @@
package playground
import (
"html/template"
"net/http"
)
var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
<html>
<head>
<title>{{.title}}</title>
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/graphiql@{{.version}}/graphiql.min.css"
integrity="{{.cssSRI}}"
crossorigin="anonymous"
/>
</head>
<body style="margin: 0;">
<div id="graphiql" style="height: 100vh;"></div>
<script
src="https://cdn.jsdelivr.net/npm/react@17.0.2/umd/react.production.min.js"
integrity="{{.reactSRI}}"
crossorigin="anonymous"
></script>
<script
src="https://cdn.jsdelivr.net/npm/react-dom@17.0.2/umd/react-dom.production.min.js"
integrity="{{.reactDOMSRI}}"
crossorigin="anonymous"
></script>
<script
src="https://cdn.jsdelivr.net/npm/graphiql@{{.version}}/graphiql.min.js"
integrity="{{.jsSRI}}"
crossorigin="anonymous"
></script>
<script>
const url = location.protocol + '//' + location.host + '{{.endpoint}}';
const wsProto = location.protocol == 'https:' ? 'wss:' : 'ws:';
const subscriptionUrl = wsProto + '//' + location.host + '{{.endpoint}}';
const fetcher = GraphiQL.createFetcher({ url, subscriptionUrl });
ReactDOM.render(
React.createElement(GraphiQL, {
fetcher: fetcher,
headerEditorEnabled: true,
shouldPersistHeaders: true
}),
document.getElementById('graphiql'),
);
</script>
</body>
</html>
`))
func Handler(title string, endpoint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "text/html")
err := page.Execute(w, map[string]string{
"title": title,
"endpoint": endpoint,
"version": "1.5.16",
"cssSRI": "sha256-HADQowUuFum02+Ckkv5Yu5ygRoLllHZqg0TFZXY7NHI=",
"jsSRI": "sha256-uHp12yvpXC4PC9+6JmITxKuLYwjlW9crq9ywPE5Rxco=",
"reactSRI": "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8=",
"reactDOMSRI": "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=",
})
if err != nil {
panic(err)
}
}
}

View File

@@ -1,20 +0,0 @@
package graphql
import (
"context"
"fmt"
"os"
"runtime/debug"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type RecoverFunc func(ctx context.Context, err interface{}) (userMessage error)
func DefaultRecover(ctx context.Context, err interface{}) error {
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr)
debug.PrintStack()
return gqlerror.Errorf("internal system error")
}

View File

@@ -1,24 +0,0 @@
package graphql
import (
"context"
"encoding/json"
"fmt"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// Errors are intentionally serialized first based on the advice in
// https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107
// and https://github.com/facebook/graphql/pull/384
type Response struct {
Errors gqlerror.List `json:"errors,omitempty"`
Data json.RawMessage `json:"data"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
}
func ErrorResponse(ctx context.Context, messagef string, args ...interface{}) *Response {
return &Response{
Errors: gqlerror.List{{Message: fmt.Sprintf(messagef, args...)}},
}
}

View File

@@ -1,7 +0,0 @@
package graphql
type Query struct{}
type Mutation struct{}
type Subscription struct{}

View File

@@ -1,60 +0,0 @@
package graphql
import (
"context"
"fmt"
"time"
)
type Stats struct {
OperationStart time.Time
Read TraceTiming
Parsing TraceTiming
Validation TraceTiming
// Stats collected by handler extensions. Dont use directly, the extension should provide a type safe way to
// access this.
extension map[string]interface{}
}
type TraceTiming struct {
Start time.Time
End time.Time
}
var ctxTraceStart key = "trace_start"
// StartOperationTrace captures the current time and stores it in context. This will eventually be added to request
// context but we want to grab it as soon as possible. For transports that can only handle a single graphql query
// per http requests you dont need to call this at all, the server will do it for you. For transports that handle
// multiple (eg batching, subscriptions) this should be called before decoding each request.
func StartOperationTrace(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxTraceStart, Now())
}
// GetStartTime should only be called by the handler package, it will be set into request context
// as Stats.Start
func GetStartTime(ctx context.Context) time.Time {
t, ok := ctx.Value(ctxTraceStart).(time.Time)
if !ok {
panic(fmt.Sprintf("missing start time: %T", ctx.Value(ctxTraceStart)))
}
return t
}
func (c *Stats) SetExtension(name string, data interface{}) {
if c.extension == nil {
c.extension = map[string]interface{}{}
}
c.extension[name] = data
}
func (c *Stats) GetExtension(name string) interface{} {
if c.extension == nil {
return nil
}
return c.extension[name]
}
// Now is time.Now, except in tests. Then it can be whatever you want it to be.
var Now = time.Now

View File

@@ -1,70 +0,0 @@
package graphql
import (
"fmt"
"io"
"strconv"
)
const encodeHex = "0123456789ABCDEF"
func MarshalString(s string) Marshaler {
return WriterFunc(func(w io.Writer) {
writeQuotedString(w, s)
})
}
func writeQuotedString(w io.Writer, s string) {
start := 0
io.WriteString(w, `"`)
for i, c := range s {
if c < 0x20 || c == '\\' || c == '"' {
io.WriteString(w, s[start:i])
switch c {
case '\t':
io.WriteString(w, `\t`)
case '\r':
io.WriteString(w, `\r`)
case '\n':
io.WriteString(w, `\n`)
case '\\':
io.WriteString(w, `\\`)
case '"':
io.WriteString(w, `\"`)
default:
io.WriteString(w, `\u00`)
w.Write([]byte{encodeHex[c>>4], encodeHex[c&0xf]})
}
start = i + 1
}
}
io.WriteString(w, s[start:])
io.WriteString(w, `"`)
}
func UnmarshalString(v interface{}) (string, error) {
switch v := v.(type) {
case string:
return v, nil
case int:
return strconv.Itoa(v), nil
case int64:
return strconv.FormatInt(v, 10), nil
case float64:
return fmt.Sprintf("%f", v), nil
case bool:
if v {
return "true", nil
} else {
return "false", nil
}
case nil:
return "null", nil
default:
return "", fmt.Errorf("%T is not a string", v)
}
}

View File

@@ -1,25 +0,0 @@
package graphql
import (
"errors"
"io"
"strconv"
"time"
)
func MarshalTime(t time.Time) Marshaler {
if t.IsZero() {
return Null
}
return WriterFunc(func(w io.Writer) {
io.WriteString(w, strconv.Quote(t.Format(time.RFC3339Nano)))
})
}
func UnmarshalTime(v interface{}) (time.Time, error) {
if tmpStr, ok := v.(string); ok {
return time.Parse(time.RFC3339Nano, tmpStr)
}
return time.Time{}, errors.New("time should be RFC3339Nano formatted string")
}

View File

@@ -1,81 +0,0 @@
package graphql
import (
"encoding/json"
"fmt"
"io"
"strconv"
)
func MarshalUint(i uint) Marshaler {
return WriterFunc(func(w io.Writer) {
_, _ = io.WriteString(w, strconv.FormatUint(uint64(i), 10))
})
}
func UnmarshalUint(v interface{}) (uint, error) {
switch v := v.(type) {
case string:
u64, err := strconv.ParseUint(v, 10, 64)
return uint(u64), err
case int:
return uint(v), nil
case int64:
return uint(v), nil
case json.Number:
u64, err := strconv.ParseUint(string(v), 10, 64)
return uint(u64), err
default:
return 0, fmt.Errorf("%T is not an uint", v)
}
}
func MarshalUint64(i uint64) Marshaler {
return WriterFunc(func(w io.Writer) {
_, _ = io.WriteString(w, strconv.FormatUint(i, 10))
})
}
func UnmarshalUint64(v interface{}) (uint64, error) {
switch v := v.(type) {
case string:
return strconv.ParseUint(v, 10, 64)
case int:
return uint64(v), nil
case int64:
return uint64(v), nil
case json.Number:
return strconv.ParseUint(string(v), 10, 64)
default:
return 0, fmt.Errorf("%T is not an uint", v)
}
}
func MarshalUint32(i uint32) Marshaler {
return WriterFunc(func(w io.Writer) {
_, _ = io.WriteString(w, strconv.FormatUint(uint64(i), 10))
})
}
func UnmarshalUint32(v interface{}) (uint32, error) {
switch v := v.(type) {
case string:
iv, err := strconv.ParseInt(v, 10, 32)
if err != nil {
return 0, err
}
return uint32(iv), nil
case int:
return uint32(v), nil
case int64:
return uint32(v), nil
case json.Number:
iv, err := strconv.ParseUint(string(v), 10, 32)
if err != nil {
return 0, err
}
return uint32(iv), nil
default:
return 0, fmt.Errorf("%T is not an uint", v)
}
}

Some files were not shown because too many files have changed in this diff Show More