Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,15 @@
require_relative 'aws-sdk-code-generator/rbs/error_list'
require_relative 'aws-sdk-code-generator/rbs/method_signature'
require_relative 'aws-sdk-code-generator/rbs/keyword_argument_builder'
require_relative 'aws-sdk-code-generator/rbs/input_type_alias_collector'
require_relative 'aws-sdk-code-generator/rbs/resource_action'
require_relative 'aws-sdk-code-generator/rbs/resource_association'
require_relative 'aws-sdk-code-generator/rbs/resource_batch_action'
require_relative 'aws-sdk-code-generator/rbs/resource_client_request'
require_relative 'aws-sdk-code-generator/rbs/waiter'
require_relative 'aws-sdk-code-generator/views/rbs/async_client_class'
require_relative 'aws-sdk-code-generator/views/rbs/client_class'
require_relative 'aws-sdk-code-generator/views/rbs/params'
require_relative 'aws-sdk-code-generator/views/rbs/errors_module'
require_relative 'aws-sdk-code-generator/views/rbs/resource_class'
require_relative 'aws-sdk-code-generator/views/rbs/root_resource_class'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,17 @@ def rbs_files(options = {})
Enumerator.new do |y|
prefix = options.fetch(:prefix, '')
codegenerated_plugins = codegen_plugins(prefix)
type_alias_collector = RBS::InputTypeAliasCollector.new(api: @service.api)
aliased_shapes = type_alias_collector.shapes_to_alias
if aliased_shapes.any?
y.yield("#{prefix}/params.rbs", Views::RBS::Params.new(
service_name: @service.name,
api: @service.api,
aliased_shapes: aliased_shapes
).render)
end
unless @service.h2_required_setting?
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are done for regular Client but what about AsyncClient?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh oops, good catch. I didn't know Ruby published separate client type for bidirectional streaming lol. Added to async clients as well as pointed out.

client_class = client_class_rbs(codegenerated_plugins)
client_class = client_class_rbs(codegenerated_plugins, aliased_shapes)
y.yield("#{prefix}/client.rbs", client_class.render)
y.yield("#{prefix}/resource.rbs", Views::RBS::RootResourceClass.new(
service_name: @service.name,
Expand All @@ -124,7 +133,12 @@ def rbs_files(options = {})
paginators: @service.paginators
).render)
end
y.yield("#{prefix}/async_client.rbs", async_client_class_rbs(codegenerated_plugins).render) if @service.h2_setting?
if @service.h2_setting?
y.yield("#{prefix}/async_client.rbs", async_client_class_rbs(
codegenerated_plugins,
aliased_shapes
).render)
end
y.yield("#{prefix}/errors.rbs", Views::RBS::ErrorsModule.new(service: @service).render)
if @waiters
y.yield("#{prefix}/waiters.rbs", Views::RBS::WaitersModule.new(
Expand Down Expand Up @@ -197,7 +211,7 @@ def client_class(codegenerated_plugins)
).render
end

def client_class_rbs(codegenerated_plugins)
def client_class_rbs(codegenerated_plugins, aliased_shapes)
Views::RBS::ClientClass.new(
service_name: @service.name,
codegenerated_plugins: codegenerated_plugins,
Expand All @@ -209,7 +223,8 @@ def client_class_rbs(codegenerated_plugins)
protocol: @service.protocol,
add_plugins: @service.add_plugins,
remove_plugins: @service.remove_plugins,
protocol_settings: @service.protocol_settings
protocol_settings: @service.protocol_settings,
aliased_shapes: aliased_shapes
)
end

Expand All @@ -233,7 +248,7 @@ def async_client_class(codegenerated_plugins)
).render
end

def async_client_class_rbs(codegenerated_plugins)
def async_client_class_rbs(codegenerated_plugins, aliased_shapes)
Views::RBS::AsyncClientClass.new(
service_name: @service.name,
codegenerated_plugins: codegenerated_plugins,
Expand All @@ -245,7 +260,8 @@ def async_client_class_rbs(codegenerated_plugins)
add_plugins: @service.add_plugins,
remove_plugins: @service.remove_plugins,
protocol_settings: @service.protocol_settings,
async_client: true
async_client: true,
aliased_shapes: aliased_shapes
)
end

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# frozen_string_literal: true

module AwsSdkCodeGenerator
module RBS
# Collects structure shapes referenced more than once in input shape.
class InputTypeAliasCollector
def initialize(api:)
@api = api
@shape_usage_count = Hash.new(0)
@size_cache = {}
end
# Returns a topologically sorted array of shape names to render as type aliases.
# Leaf dependencies come first so aliases can reference each other.
# params.rbs uses aliases in this list to define RBS type aliases.
# client_class.rbs uses aliases in this list to deduplicate content.
def shapes_to_alias
count_shape_usage
aliased = @shape_usage_count.select do |shape_name, count|
shape = @api['shapes'][shape_name]
shape['type'] == 'structure' &&
count > 1 &&
rendered_rbs_line_count_heuristic(shape_name) > 5
end.keys.to_set
topological_sort(aliased)
end

private

def topological_sort(shape_names)
sorted = []
visited = Set.new
shape_names.each { |name| topo_visit(name, shape_names, visited, sorted) }
sorted
end

def topo_visit(name, shape_names, visited, sorted)
return if visited.include?(name)

visited << name
shape = @api['shapes'][name]
shape['members']&.each_value do |ref|
dep = ref['shape']
topo_visit(dep, shape_names, visited, sorted) if shape_names.include?(dep)
end
sorted << name
end

def count_shape_usage
@api['operations'].each_value do |op|
input_shape_name = op.dig('input', 'shape')
next unless input_shape_name

walk_shape(input_shape_name, Set.new)
end
end

def walk_shape(shape_name, ancestors)
return if ancestors.include?(shape_name)

ancestors += [shape_name]

shape = @api['shapes'][shape_name]
return unless shape && shape['type'] == 'structure'

shape['members']&.each_value do |member_ref|
member_shape_name = member_ref['shape']
member_shape = @api['shapes'][member_shape_name]
next unless member_shape

case member_shape['type']
when 'structure'
@shape_usage_count[member_shape_name] += 1
walk_shape(member_shape_name, ancestors)
when 'list'
walk_list(member_shape, ancestors)
when 'map'
walk_map(member_shape, ancestors)
end
end
end

def walk_list(list_shape, ancestors)
member_ref = list_shape['member']
member_shape = @api['shapes'][member_ref['shape']]
return unless member_shape

case member_shape['type']
when 'structure'
@shape_usage_count[member_ref['shape']] += 1
walk_shape(member_ref['shape'], ancestors)
when 'list'
walk_list(member_shape, ancestors)
when 'map'
walk_map(member_shape, ancestors)
end
end

def walk_map(map_shape, ancestors)
value_ref = map_shape['value']
value_shape = @api['shapes'][value_ref['shape']]
return unless value_shape

case value_shape['type']
when 'structure'
@shape_usage_count[value_ref['shape']] += 1
walk_shape(value_ref['shape'], ancestors)
when 'list'
walk_list(value_shape, ancestors)
when 'map'
walk_map(value_shape, ancestors)
end
end

def rendered_rbs_line_count_heuristic(shape_name, visited = Set.new)
return 0 if visited.include?(shape_name)

# Cache results to deduplicate calculation for nested structure types that get used multiple times in the model.
return @size_cache[shape_name] if @size_cache.key?(shape_name)

visited += [shape_name]
shape = @api['shapes'][shape_name]
return @size_cache[shape_name] = 1 unless shape['members']

@size_cache[shape_name] = shape['members'].sum do |_, ref|
member_line_count(ref, visited)
end
end

def member_line_count(ref, visited)
child = @api['shapes'][ref['shape']]
case child['type']
# If it's a structure, add 2 lines then recurse (one line each for '{' and '}' in rendered RBS)
when 'structure' then 2 + rendered_rbs_line_count_heuristic(ref['shape'], visited)
when 'list'
member_shape = @api['shapes'][child['member']['shape']]
# If it's a list with structure member, add 2 lines & recurse.
# (one line each for 'Hash[::String, {' and '}]' in rendered RBS)
# If it's a list with primitive member, add 1 and return
# (e.g., renders as 1-liner, like 'list: Array[::String]')
member_shape['type'] == 'structure' ? 2 +
rendered_rbs_line_count_heuristic(child['member']['shape'], visited) : 1
when 'map'
value_shape = @api['shapes'][child['value']['shape']]
# If it's a map with structure member, add 2 lines & recurse.
# (one line each for 'Hash[::String, {' and '}]' in rendered RBS)
# If it's a map with primitive member, add 1 and return
# (e.g., renders as 1-liner, like 'map: Hash[::String, ::String]')
value_shape['type'] == 'structure' ? 2 +
rendered_rbs_line_count_heuristic(child['value']['shape'], visited) : 1
else 1
end
end
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def format(indent: '')
result.join(joint)
end

def format_as_alias(indent: '')
struct(@shape, indent, [])
end

private

def struct(struct_shape, i, visited)
members_str = struct_members(struct_shape, i, visited, keyword: false)
result = ["{"]
Expand Down Expand Up @@ -67,12 +73,16 @@ def struct_member(struct, member_name, member_ref, i, visited, keyword:)
end

def ref_value(ref, i, visited)
if visited.include?(ref['shape'])
return "untyped"
else
visited = visited + [ref['shape']]
return "untyped" if visited.include?(ref['shape'])

# If this shape should be aliased, emit the alias reference
if @options[:aliased_shapes]&.include?(ref['shape'])
alias_name = underscore(ref['shape'])
return "Params::#{alias_name}"
end

visited = visited + [ref['shape']]

s = shape(ref)
case s['type']
when 'structure'
Expand Down Expand Up @@ -115,7 +125,7 @@ def scalar_list(member_ref, i, visited)

def complex_list(member_ref, i, visited)
newline_indent = newline ? "\n#{i}" : ""
"Array[#{newline_indent}#{more_indent}#{ref_value(member_ref, i + more_indent, visited)},#{newline_indent}]"
"Array[#{newline_indent}#{more_indent}#{ref_value(member_ref, i + more_indent, visited)}#{newline_indent}]"
end

def complex?(ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def initialize(options)
@plugins = PluginList.new(options)
@codegenerated_plugins = options.fetch(:codegenerated_plugins)
@protocol_settings = options.fetch(:protocol_settings, {})
@aliased_shapes = options.fetch(:aliased_shapes, Set.new).to_set
end

# @return [String|nil]
Expand Down Expand Up @@ -68,7 +69,7 @@ def operations
api: @api,
shape: input_shape,
newline: true,
options: options
options: options.merge(aliased_shapes: @aliased_shapes)
)
arguments = builder.format(indent: indent)
include_required = input_shape['required']&.empty?&.!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def initialize(options)
@codegenerated_plugins = options.fetch(:codegenerated_plugins)
@waiters = AwsSdkCodeGenerator::RBS::Waiter.build_list(api: @api, waiters:options.fetch(:waiters))
@protocol_settings = options.fetch(:protocol_settings, {})
@aliased_shapes = options.fetch(:aliased_shapes, Set.new).to_set
end

# @return [String|nil]
Expand Down Expand Up @@ -60,6 +61,7 @@ def operations
api: @api,
shape: input_shape,
newline: true,
options: { aliased_shapes: @aliased_shapes }
)
arguments = builder.format(indent: indent)
include_required = input_shape["required"]&.empty?&.!
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# frozen_string_literal: true

module AwsSdkCodeGenerator
module Views
module RBS
class Params < View
include Helper

def initialize(options)
@service_name = options.fetch(:service_name)
@api = options.fetch(:api)
@aliased_shapes = options.fetch(:aliased_shapes)
end

def generated_src_warning
GENERATED_SRC_WARNING
end

def service_name
@service_name
end

def type_aliases
aliased_set = @aliased_shapes.to_set
@aliased_shapes.map do |shape_name|
shape = @api['shapes'][shape_name]
builder = AwsSdkCodeGenerator::RBS::KeywordArgumentBuilder.new(
api: @api,
shape: shape,
newline: true,
options: { aliased_shapes: aliased_set - [shape_name] }
)
{
'name' => underscore(shape_name),
'definition' => builder.format_as_alias(indent: ' '),
}
end
end
end
end
end
end
Loading
Loading