Skip to content

Commit 992cfa2

Browse files
Propagate composite type field types from parser to plugin proto
1 parent b2933a8 commit 992cfa2

File tree

8 files changed

+228
-155
lines changed

8 files changed

+228
-155
lines changed

Diff for: internal/cmd/shim.go

+16-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/sqlc-dev/sqlc/internal/config/convert"
77
"github.com/sqlc-dev/sqlc/internal/info"
88
"github.com/sqlc-dev/sqlc/internal/plugin"
9+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
910
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
1011
)
1112

@@ -59,6 +60,18 @@ func pluginWASM(p config.Plugin) *plugin.Codegen_WASM {
5960
return nil
6061
}
6162

63+
func identifierSlice(types []*ast.TypeName) []*plugin.Identifier {
64+
ids := []*plugin.Identifier{}
65+
for _, typ := range types {
66+
ids = append(ids, &plugin.Identifier{
67+
Catalog: typ.Catalog,
68+
Schema: typ.Schema,
69+
Name: typ.Name,
70+
})
71+
}
72+
return ids
73+
}
74+
6275
func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
6376
var schemas []*plugin.Schema
6477
for _, s := range c.Schemas {
@@ -74,8 +87,9 @@ func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
7487
})
7588
case *catalog.CompositeType:
7689
cts = append(cts, &plugin.CompositeType{
77-
Name: typ.Name,
78-
Comment: typ.Comment,
90+
Name: typ.Name,
91+
Comment: typ.Comment,
92+
ColTypeNames: identifierSlice(typ.ColTypeNames),
7993
})
8094
}
8195
}

Diff for: internal/engine/postgresql/convert.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,8 @@ func convertCompositeTypeStmt(n *pg.CompositeTypeStmt) *ast.CompositeTypeStmt {
889889
}
890890
rel := parseRelationFromRangeVar(n.Typevar)
891891
return &ast.CompositeTypeStmt{
892-
TypeName: rel.TypeName(),
892+
TypeName: rel.TypeName(),
893+
ColDefList: convertSlice(n.GetColdeflist()),
893894
}
894895
}
895896

Diff for: internal/engine/postgresql/parse.go

+20-3
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,26 @@ func translate(node *nodes.Node) (ast.Node, error) {
392392
case *nodes.Node_CompositeTypeStmt:
393393
n := inner.CompositeTypeStmt
394394
rel := parseRelationFromRangeVar(n.Typevar)
395-
return &ast.CompositeTypeStmt{
396-
TypeName: rel.TypeName(),
397-
}, nil
395+
stmt := &ast.CompositeTypeStmt{
396+
TypeName: rel.TypeName(),
397+
ColDefList: &ast.List{},
398+
}
399+
for _, val := range n.GetColdeflist() {
400+
switch item := val.Node.(type) {
401+
case *nodes.Node_ColumnDef:
402+
rel, err := parseRelationFromNodes(item.ColumnDef.TypeName.Names)
403+
if err != nil {
404+
return nil, err
405+
}
406+
stmt.ColDefList.Items = append(stmt.ColDefList.Items, &ast.ColumnDef{
407+
Colname: item.ColumnDef.GetColname(),
408+
TypeName: rel.TypeName(),
409+
IsLocal: true,
410+
CollClause: convertCollateClause(item.ColumnDef.GetCollClause()),
411+
})
412+
}
413+
}
414+
return stmt, nil
398415

399416
case *nodes.Node_CreateStmt:
400417
n := inner.CreateStmt

Diff for: internal/plugin/codegen.pb.go

+142-129
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: internal/sql/ast/composite_type_stmt.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package ast
22

33
type CompositeTypeStmt struct {
4-
TypeName *TypeName
4+
TypeName *TypeName
5+
ColDefList *List
56
}
67

78
func (n *CompositeTypeStmt) Pos() int {

Diff for: internal/sql/astutils/rewrite.go

+1
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.
495495

496496
case *ast.CompositeTypeStmt:
497497
a.apply(n, "TypeName", nil, n.TypeName)
498+
a.apply(n, "ColDefList", nil, n.ColDefList)
498499

499500
case *ast.Const:
500501
a.apply(n, "Xpr", nil, n.Xpr)

Diff for: internal/sql/catalog/types.go

+44-19
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ func (e *Enum) isType() {
2828
}
2929

3030
type CompositeType struct {
31-
Name string
32-
Comment string
31+
Name string
32+
ColTypeNames []*ast.TypeName
33+
Comment string
3334
}
3435

3536
func (ct *CompositeType) isType() {
@@ -101,6 +102,16 @@ func stringSlice(list *ast.List) []string {
101102
return items
102103
}
103104

105+
func columnTypeNamesSlice(list *ast.List) []*ast.TypeName {
106+
items := []*ast.TypeName{}
107+
for _, item := range list.Items {
108+
if n, ok := item.(*ast.ColumnDef); ok {
109+
items = append(items, n.TypeName)
110+
}
111+
}
112+
return items
113+
}
114+
104115
func (c *Catalog) getType(rel *ast.TypeName) (Type, int, error) {
105116
ns := rel.Schema
106117
if ns == "" {
@@ -136,7 +147,8 @@ func (c *Catalog) createCompositeType(stmt *ast.CompositeTypeStmt) error {
136147
return sqlerr.TypeExists(tbl.Name)
137148
}
138149
schema.Types = append(schema.Types, &CompositeType{
139-
Name: stmt.TypeName.Name,
150+
Name: stmt.TypeName.Name,
151+
ColTypeNames: columnTypeNamesSlice(stmt.ColDefList),
140152
})
141153
return nil
142154
}
@@ -277,16 +289,11 @@ func (c *Catalog) alterTypeSetSchema(stmt *ast.AlterTypeSetSchemaStmt) error {
277289
oldSchema.Types = append(oldSchema.Types[:idx], oldSchema.Types[idx+1:]...)
278290
newSchema.Types = append(newSchema.Types, typ)
279291

280-
// Update all the table columns with the new type
281-
for _, schema := range c.Schemas {
282-
for _, table := range schema.Tables {
283-
for _, column := range table.Columns {
284-
if column.Type == oldType {
285-
column.Type.Schema = *stmt.NewSchema
286-
}
287-
}
292+
c.updateTypeNames(func(t *ast.TypeName) {
293+
if *t == oldType {
294+
t.Schema = *stmt.NewSchema
288295
}
289-
}
296+
})
290297
return nil
291298
}
292299

@@ -343,8 +350,9 @@ func (c *Catalog) renameType(stmt *ast.RenameTypeStmt) error {
343350

344351
case *CompositeType:
345352
schema.Types[idx] = &CompositeType{
346-
Name: newName,
347-
Comment: typ.Comment,
353+
Name: newName,
354+
ColTypeNames: typ.ColTypeNames,
355+
Comment: typ.Comment,
348356
}
349357

350358
case *Enum:
@@ -359,16 +367,33 @@ func (c *Catalog) renameType(stmt *ast.RenameTypeStmt) error {
359367

360368
}
361369

362-
// Update all the table columns with the new type
370+
c.updateTypeNames(func(t *ast.TypeName) {
371+
if *t == *stmt.Type {
372+
t.Name = newName
373+
}
374+
})
375+
376+
return nil
377+
}
378+
379+
func (c *Catalog) updateTypeNames(typeUpdater func(t *ast.TypeName)) error {
363380
for _, schema := range c.Schemas {
381+
// Update all the table columns with the new type
364382
for _, table := range schema.Tables {
365383
for _, column := range table.Columns {
366-
if column.Type == *stmt.Type {
367-
column.Type.Name = newName
368-
}
384+
typeUpdater(&column.Type)
385+
}
386+
}
387+
// Update all the composite fields with the new type
388+
for _, typ := range schema.Types {
389+
composite, ok := typ.(*CompositeType)
390+
if !ok {
391+
continue
392+
}
393+
for _, fieldType := range composite.ColTypeNames {
394+
typeUpdater(fieldType)
369395
}
370396
}
371397
}
372-
373398
return nil
374399
}

Diff for: protos/plugin/codegen.proto

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ message Schema {
6161
message CompositeType {
6262
string name = 1;
6363
string comment = 2;
64+
repeated Identifier col_type_names = 3;
6465
}
6566

6667
message Enum {

0 commit comments

Comments
 (0)