Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,15 @@
require_relative 'aws-sdk-code-generator/rbs/error_list'
require_relative 'aws-sdk-code-generator/rbs/method_signature'
require_relative 'aws-sdk-code-generator/rbs/keyword_argument_builder'
require_relative 'aws-sdk-code-generator/rbs/input_type_alias_collector'
require_relative 'aws-sdk-code-generator/rbs/resource_action'
require_relative 'aws-sdk-code-generator/rbs/resource_association'
require_relative 'aws-sdk-code-generator/rbs/resource_batch_action'
require_relative 'aws-sdk-code-generator/rbs/resource_client_request'
require_relative 'aws-sdk-code-generator/rbs/waiter'
require_relative 'aws-sdk-code-generator/views/rbs/async_client_class'
require_relative 'aws-sdk-code-generator/views/rbs/client_class'
require_relative 'aws-sdk-code-generator/views/rbs/params'
require_relative 'aws-sdk-code-generator/views/rbs/errors_module'
require_relative 'aws-sdk-code-generator/views/rbs/resource_class'
require_relative 'aws-sdk-code-generator/views/rbs/root_resource_class'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,17 @@ def rbs_files(options = {})
prefix = options.fetch(:prefix, '')
codegenerated_plugins = codegen_plugins(prefix)
unless @service.h2_required_setting?
Comment thread
jterapin marked this conversation as resolved.
client_class = client_class_rbs(codegenerated_plugins)
collector = RBS::InputTypeAliasCollector.new(api: @service.api)
aliased_shapes = collector.shapes_to_alias
client_class = client_class_rbs(codegenerated_plugins, aliased_shapes)
y.yield("#{prefix}/client.rbs", client_class.render)
if aliased_shapes.any?
y.yield("#{prefix}/params.rbs", Views::RBS::Params.new(
service_name: @service.name,
api: @service.api,
aliased_shapes: aliased_shapes
).render)
end
y.yield("#{prefix}/resource.rbs", Views::RBS::RootResourceClass.new(
service_name: @service.name,
client_class: client_class,
Expand Down Expand Up @@ -197,7 +206,7 @@ def client_class(codegenerated_plugins)
).render
end

def client_class_rbs(codegenerated_plugins)
def client_class_rbs(codegenerated_plugins, aliased_shapes)
Views::RBS::ClientClass.new(
service_name: @service.name,
codegenerated_plugins: codegenerated_plugins,
Expand All @@ -209,7 +218,8 @@ def client_class_rbs(codegenerated_plugins)
protocol: @service.protocol,
add_plugins: @service.add_plugins,
remove_plugins: @service.remove_plugins,
protocol_settings: @service.protocol_settings
protocol_settings: @service.protocol_settings,
aliased_shapes: aliased_shapes
)
end

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# frozen_string_literal: true

module AwsSdkCodeGenerator
module RBS
# Collects structure shapes referenced more than once in input shape.
class InputTypeAliasCollector
def initialize(api:)
@api = api
@shape_usage_count = Hash.new(0)
@size_cache = {}
end
# Returns a topologically sorted array of shape names to render as type aliases.
# Leaf dependencies come first so aliases can reference each other.
# params.rbs uses aliases in this list to define RBS type aliases.
# client_class.rbs uses aliases in this list to deduplicate content.
def shapes_to_alias
count_shape_usage
aliased = @shape_usage_count.select do |shape_name, count|
shape = @api['shapes'][shape_name]
shape['type'] == 'structure' &&
count > 1 &&
rendered_rbs_line_count_heuristic(shape_name) > 5
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Could turn this 5 into a const, i.e. MIN_LINES_FOR_ALIAS or add a small # min lines for alias

i recognized that the 5 is mentioned in the pr description but we could self-document here.

end.keys.to_set
topological_sort(aliased)
end

private

def topological_sort(shape_names)
sorted = []
visited = Set.new
shape_names.each { |name| topo_visit(name, shape_names, visited, sorted) }
sorted
end

def topo_visit(name, shape_names, visited, sorted)
return if visited.include?(name)

visited << name
shape = @api['shapes'][name]
shape['members']&.each_value do |ref|
dep = ref['shape']
topo_visit(dep, shape_names, visited, sorted) if shape_names.include?(dep)
end
sorted << name
end

def count_shape_usage
@api['operations'].each_value do |op|
input_shape_name = op.dig('input', 'shape')
next unless input_shape_name

walk_shape(input_shape_name, Set.new)
end
end

def walk_shape(shape_name, ancestors)
return if ancestors.include?(shape_name)

ancestors += [shape_name]

shape = @api['shapes'][shape_name]
return unless shape && shape['type'] == 'structure'

shape['members']&.each_value do |member_ref|
member_shape_name = member_ref['shape']
member_shape = @api['shapes'][member_shape_name]
next unless member_shape

case member_shape['type']
when 'structure'
@shape_usage_count[member_shape_name] += 1
walk_shape(member_shape_name, ancestors)
when 'list'
walk_list(member_shape, ancestors)
when 'map'
walk_map(member_shape, ancestors)
end
end
end

def walk_list(list_shape, ancestors)
member_ref = list_shape['member']
member_shape = @api['shapes'][member_ref['shape']]
return unless member_shape

case member_shape['type']
when 'structure'
@shape_usage_count[member_ref['shape']] += 1
walk_shape(member_ref['shape'], ancestors)
when 'list'
walk_list(member_shape, ancestors)
when 'map'
walk_map(member_shape, ancestors)
end
end

def walk_map(map_shape, ancestors)
value_ref = map_shape['value']
value_shape = @api['shapes'][value_ref['shape']]
return unless value_shape

case value_shape['type']
when 'structure'
@shape_usage_count[value_ref['shape']] += 1
walk_shape(value_ref['shape'], ancestors)
when 'list'
walk_list(value_shape, ancestors)
when 'map'
walk_map(value_shape, ancestors)
end
end

def rendered_rbs_line_count_heuristic(shape_name, visited = Set.new)
return 0 if visited.include?(shape_name)

# Cache results to deduplicate calculation for nested structure types that get used multiple times in the model.
@size_cache ||= {}
Comment thread
jterapin marked this conversation as resolved.
Outdated
return @size_cache[shape_name] if @size_cache.key?(shape_name)

visited += [shape_name]
shape = @api['shapes'][shape_name]
return @size_cache[shape_name] = 1 unless shape['members']

@size_cache[shape_name] = shape['members'].sum do |_, ref|
member_line_count(ref, visited)
end
end

def member_line_count(ref, visited)
child = @api['shapes'][ref['shape']]
case child['type']
# If it's a structure, add 2 lines then recurse (one line each for '{' and '}' in rendered RBS)
when 'structure' then 2 + rendered_rbs_line_count_heuristic(ref['shape'], visited)
when 'list'
member_shape = @api['shapes'][child['member']['shape']]
# If it's a list with structure member, add 2 lines & recurse.
# (one line each for 'Hash[::String, {' and '}]' in rendered RBS)
# If it's a list with primitive member, add 1 and return
# (e.g., renders as 1-liner, like 'list: Array[::String]')
member_shape['type'] == 'structure' ? 2 +
rendered_rbs_line_count_heuristic(child['member']['shape'], visited) : 1
when 'map'
value_shape = @api['shapes'][child['value']['shape']]
# If it's a map with structure member, add 2 lines & recurse.
# (one line each for 'Hash[::String, {' and '}]' in rendered RBS)
# If it's a map with primitive member, add 1 and return
# (e.g., renders as 1-liner, like 'map: Hash[::String, ::String]')
value_shape['type'] == 'structure' ? 2 +
rendered_rbs_line_count_heuristic(child['value']['shape'], visited) : 1
else 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to align with case/when blocks. For the files you touch, I'll recommend running rubocop (docs) to catch these. I.e.

rubocop <file-path.rb>

If the suggested changes makes sense, you can do: rubocop <file-path.rb> -a to automatically apply the fixes.

end
end
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ class KeywordArgumentBuilder

attr_reader :newline

def initialize(api:, shape:, newline:, options: {})
def initialize(api:, shape:, newline:, options: {}, aliased_shapes: Set.new, alias_namespace: nil)
@api = api
@shape = shape
@newline = newline
@options = options
@aliased_shapes = aliased_shapes
@alias_namespace = alias_namespace
Comment thread
jterapin marked this conversation as resolved.
Outdated
end

def format(indent: '')
Expand Down Expand Up @@ -69,10 +71,16 @@ def struct_member(struct, member_name, member_ref, i, visited, keyword:)
def ref_value(ref, i, visited)
if visited.include?(ref['shape'])
return "untyped"
else
visited = visited + [ref['shape']]
end
Comment thread
jterapin marked this conversation as resolved.
Outdated

# If this shape should be aliased, emit the alias reference
if @aliased_shapes.include?(ref['shape'])
alias_name = Underscore.underscore(ref['shape'])
return @alias_namespace ? "#{@alias_namespace}::#{alias_name}" : alias_name
end

visited = visited + [ref['shape']]

s = shape(ref)
case s['type']
when 'structure'
Expand Down Expand Up @@ -115,7 +123,7 @@ def scalar_list(member_ref, i, visited)

def complex_list(member_ref, i, visited)
newline_indent = newline ? "\n#{i}" : ""
"Array[#{newline_indent}#{more_indent}#{ref_value(member_ref, i + more_indent, visited)},#{newline_indent}]"
"Array[#{newline_indent}#{more_indent}#{ref_value(member_ref, i + more_indent, visited)}#{newline_indent}]"
end

def complex?(ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def initialize(options)
@codegenerated_plugins = options.fetch(:codegenerated_plugins)
@waiters = AwsSdkCodeGenerator::RBS::Waiter.build_list(api: @api, waiters:options.fetch(:waiters))
@protocol_settings = options.fetch(:protocol_settings, {})
@aliased_shapes = options.fetch(:aliased_shapes, Set.new).to_set
end

# @return [String|nil]
Expand Down Expand Up @@ -60,6 +61,8 @@ def operations
api: @api,
shape: input_shape,
newline: true,
aliased_shapes: @aliased_shapes,
alias_namespace: 'Params',
Comment thread
jterapin marked this conversation as resolved.
Outdated
)
arguments = builder.format(indent: indent)
include_required = input_shape["required"]&.empty?&.!
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# frozen_string_literal: true

module AwsSdkCodeGenerator
module Views
module RBS
class Params < View
include Helper

def initialize(options)
@service_name = options.fetch(:service_name)
@api = options.fetch(:api)
@aliased_shapes = options.fetch(:aliased_shapes)
end

def generated_src_warning
GENERATED_SRC_WARNING
end

def service_name
@service_name
end

def type_aliases
aliased_set = @aliased_shapes.to_set
@aliased_shapes.map do |shape_name|
shape = @api['shapes'][shape_name]
builder = AwsSdkCodeGenerator::RBS::KeywordArgumentBuilder.new(
api: @api,
shape: shape,
newline: true,
aliased_shapes: aliased_set - [shape_name],
alias_namespace: 'Params',
)
{
'name' => Underscore.underscore(shape_name),
'definition' => builder.struct(shape, ' ', []),
Comment thread
jterapin marked this conversation as resolved.
Outdated
}
end
end
end
end
end
end
Loading
Loading