Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/ipo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//! Tools for interprocedural optimizations (aka "IPO"s).

// FIXME(eddyb) perhaps make all IPOs sub-modules of this module?

use indexmap::IndexSet;
use rspirv::dr::Module;
use rspirv::spirv::Op;
use rustc_data_structures::fx::FxHashMap;

// FIXME(eddyb) use newtyped indices and `IndexVec`.
type FuncIdx = usize;

/// Call graph of the functions in a SPIR-V module.
///
/// Functions are identified by their index into `module.functions`
/// (see `FuncIdx`), with entry points acting as the graph's roots.
pub struct CallGraph {
    /// Indices of the functions targeted by `OpEntryPoint` instructions.
    pub entry_points: IndexSet<FuncIdx>,

    /// `callees[i].contains(j)` implies `functions[i]` calls `functions[j]`.
    callees: Vec<IndexSet<FuncIdx>>,
}

impl CallGraph {
    /// Build the call graph of `module`, by scanning every function's
    /// `OpFunctionCall` instructions, and every `OpEntryPoint` declaration.
    pub fn collect(module: &Module) -> Self {
        // Map `OpFunction` result IDs back to indices in `module.functions`.
        let func_id_to_idx: FxHashMap<_, _> = module
            .functions
            .iter()
            .enumerate()
            .map(|(i, func)| (func.def_id().unwrap(), i))
            .collect();
        let entry_points = module
            .entry_points
            .iter()
            .map(|entry| {
                assert_eq!(entry.class.opcode, Op::EntryPoint);
                // Operand 1 of `OpEntryPoint` is the entry's `OpFunction` ID.
                func_id_to_idx[&entry.operands[1].unwrap_id_ref()]
            })
            .collect();
        let callees = module
            .functions
            .iter()
            .map(|func| {
                func.all_inst_iter()
                    .filter(|inst| inst.class.opcode == Op::FunctionCall)
                    .filter_map(|inst| {
                        // FIXME(eddyb) `func_id_to_idx` should always have an
                        // entry for a callee ID, but when run early enough
                        // (before zombie removal), the callee ID might not
                        // point to an `OpFunction` (unsure what, `OpUndef`?).
                        func_id_to_idx
                            .get(&inst.operands[0].unwrap_id_ref())
                            .copied()
                    })
                    .collect()
            })
            .collect();
        Self {
            entry_points,
            callees,
        }
    }

    /// Order functions using a post-order traversal, i.e. callees before callers.
    // FIXME(eddyb) replace this with `rustc_data_structures::graph::iterate`
    // (or similar).
    pub fn post_order(&self) -> Vec<FuncIdx> {
        let num_funcs = self.callees.len();

        // FIXME(eddyb) use a proper bitset.
        let mut visited = vec![false; num_funcs];
        let mut post_order = Vec::with_capacity(num_funcs);

        // Visit the call graph with entry points as roots.
        for &entry in &self.entry_points {
            self.post_order_step(entry, &mut visited, &mut post_order);
        }

        // Also visit any functions that were not reached from entry points
        // (they might be dead but they should be processed nonetheless).
        for func in 0..num_funcs {
            if !visited[func] {
                self.post_order_step(func, &mut visited, &mut post_order);
            }
        }

        post_order
    }

    /// Append, to `post_order`, the post-order traversal of the subgraph
    /// reachable from `root` (skipping anything already in `visited`).
    ///
    /// Uses an explicit worklist stack instead of recursion, so that deep
    /// call chains can't overflow the thread's stack.
    fn post_order_step(&self, root: FuncIdx, visited: &mut [bool], post_order: &mut Vec<FuncIdx>) {
        if visited[root] {
            return;
        }
        visited[root] = true;

        // Each stack entry pairs a function with how many of its callees
        // have already been descended into.
        let mut stack = vec![(root, 0)];
        while let Some(&mut (func, ref mut next_callee)) = stack.last_mut() {
            let callee = self.callees[func].get_index(*next_callee).copied();
            match callee {
                Some(callee) => {
                    *next_callee += 1;
                    if !visited[callee] {
                        visited[callee] = true;
                        stack.push((callee, 0));
                    }
                }
                None => {
                    // All callees of `func` have been emitted, now `func`
                    // itself can be (callees before callers).
                    stack.pop();
                    post_order.push(func);
                }
            }
        }
    }
}
10 changes: 10 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ mod destructure_composites;
mod duplicates;
mod import_export_link;
mod inline;
mod ipo;
mod mem2reg;
mod param_weakening;
mod peephole_opts;
mod simple_passes;
mod specializer;
Expand Down Expand Up @@ -129,6 +131,14 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
import_export_link::run(sess, &mut output)?;
}

// HACK(eddyb) this has to run before the `remove_zombies` pass, so that any
// zombies that are passed as call arguments, but eventually unused, won't
// be (incorrectly) considered used.
{
let _timer = sess.timer("link_remove_unused_params");
output = param_weakening::remove_unused_params(output);
}

{
let _timer = sess.timer("link_remove_zombies");
zombies::remove_zombies(sess, &mut output);
Expand Down
114 changes: 114 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/param_weakening.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//! Interprocedural optimizations that "weaken" function parameters, i.e. they
//! replace parameter types with "simpler" ones, or outright remove parameters,
//! based on how those parameters are used in the function and/or what arguments
//! get passed from callers.
//!
use crate::linker::ipo::CallGraph;
use indexmap::IndexMap;
use rspirv::dr::{Builder, Module, Operand};
use rspirv::spirv::{Op, Word};
use rustc_data_structures::fx::FxHashMap;
use rustc_index::bit_set::BitSet;
use std::mem;

pub fn remove_unused_params(module: Module) -> Module {
let call_graph = CallGraph::collect(&module);

// Gather all of the unused parameters for each function, transitively.
// (i.e. parameters which are passed, as call arguments, to functions that
// won't use them, are also considered unused, through any number of calls)
let mut unused_params_per_func_id: IndexMap<Word, BitSet<usize>> = IndexMap::new();
for func_idx in call_graph.post_order() {
// Skip entry points, as they're the only "exported" functions, at least
// at link-time (likely only relevant to `Kernel`s, but not `Shader`s).
if call_graph.entry_points.contains(&func_idx) {
continue;
}

let func = &module.functions[func_idx];

let params_id_to_idx: FxHashMap<Word, usize> = func
.parameters
.iter()
.enumerate()
.map(|(i, p)| (p.result_id.unwrap(), i))
.collect();
let mut unused_params = BitSet::new_filled(func.parameters.len());
for inst in func.all_inst_iter() {
// If this is a call, we can ignore the arguments passed to the
// callee parameters we already determined to be unused, because
// those parameters (and matching arguments) will get removed later.
let (operands, ignore_operands) = if inst.class.opcode == Op::FunctionCall {
(
&inst.operands[1..],
unused_params_per_func_id.get(&inst.operands[0].unwrap_id_ref()),
)
} else {
(&inst.operands[..], None)
};

for (i, operand) in operands.iter().enumerate() {
if let Some(ignore_operands) = ignore_operands {
if ignore_operands.contains(i) {
continue;
}
}

if let Operand::IdRef(id) = operand {
if let Some(&param_idx) = params_id_to_idx.get(id) {
unused_params.remove(param_idx);
}
}
}
}

if !unused_params.is_empty() {
unused_params_per_func_id.insert(func.def_id().unwrap(), unused_params);
}
}

// Remove unused parameters and call arguments for unused parameters.
let mut builder = Builder::new_from_module(module);
for func_idx in 0..builder.module_ref().functions.len() {
let func = &mut builder.module_mut().functions[func_idx];
let unused_params = unused_params_per_func_id.get(&func.def_id().unwrap());
if let Some(unused_params) = unused_params {
func.parameters = mem::take(&mut func.parameters)
.into_iter()
.enumerate()
.filter(|&(i, _)| !unused_params.contains(i))
.map(|(_, p)| p)
.collect();
}

for inst in func.all_inst_iter_mut() {
if inst.class.opcode == Op::FunctionCall {
if let Some(unused_callee_params) =
unused_params_per_func_id.get(&inst.operands[0].unwrap_id_ref())
{
inst.operands = mem::take(&mut inst.operands)
.into_iter()
.enumerate()
.filter(|&(i, _)| i == 0 || !unused_callee_params.contains(i - 1))
.map(|(_, o)| o)
.collect();
}
}
}

// Regenerate the function type from remaining parameters, if necessary.
if unused_params.is_some() {
let return_type = func.def.as_mut().unwrap().result_type.unwrap();
let new_param_types: Vec<_> = func
.parameters
.iter()
.map(|inst| inst.result_type.unwrap())
.collect();
let new_func_type = builder.type_function(return_type, new_param_types);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hrmpf. O(n^2) linear search here for dedup isn't great, but I guess it's fine, usually not many functions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"usually not many" - famous last words ;)

let func = &mut builder.module_mut().functions[func_idx];
func.def.as_mut().unwrap().operands[1] = Operand::IdRef(new_func_type);
}
}

builder.module()
}
85 changes: 2 additions & 83 deletions crates/rustc_codegen_spirv/src/linker/specializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@
//! expand_params: Option<Vec<usize>>,
//! ```

use crate::linker::ipo::CallGraph;
use crate::spirv_type_constraints::{self, InstSig, StorageClassPat, TyListPat, TyPat};
use indexmap::{IndexMap, IndexSet};
use indexmap::IndexMap;
use rspirv::dr::{Builder, Function, Instruction, Module, Operand};
use rspirv::spirv::{Op, StorageClass, Word};
use rustc_data_structures::captures::Captures;
Expand Down Expand Up @@ -186,88 +187,6 @@ pub fn specialize(module: Module, specialization: impl Specialization) -> Module
expander.expand_module()
}

// FIXME(eddyb) use newtyped indices and `IndexVec`.
type FuncIdx = usize;

struct CallGraph {
entry_points: IndexSet<FuncIdx>,

/// `callees[i].contains(j)` implies `functions[i]` calls `functions[j]`.
callees: Vec<IndexSet<FuncIdx>>,
}

impl CallGraph {
fn collect(module: &Module) -> Self {
let func_id_to_idx: FxHashMap<_, _> = module
.functions
.iter()
.enumerate()
.map(|(i, func)| (func.def_id().unwrap(), i))
.collect();
let entry_points = module
.entry_points
.iter()
.map(|entry| {
assert_eq!(entry.class.opcode, Op::EntryPoint);
func_id_to_idx[&entry.operands[1].unwrap_id_ref()]
})
.collect();
let callees = module
.functions
.iter()
.map(|func| {
func.all_inst_iter()
.filter(|inst| inst.class.opcode == Op::FunctionCall)
.map(|inst| func_id_to_idx[&inst.operands[0].unwrap_id_ref()])
.collect()
})
.collect();
Self {
entry_points,
callees,
}
}

/// Order functions using a post-order traversal, i.e. callees before callers.
// FIXME(eddyb) replace this with `rustc_data_structures::graph::iterate`
// (or similar).
fn post_order(&self) -> Vec<FuncIdx> {
let num_funcs = self.callees.len();

// FIXME(eddyb) use a proper bitset.
let mut visited = vec![false; num_funcs];
let mut post_order = Vec::with_capacity(num_funcs);

// Visit the call graph with entry points as roots.
for &entry in &self.entry_points {
self.post_order_step(entry, &mut visited, &mut post_order);
}

// Also visit any functions that were not reached from entry points
// (they might be dead but they should be processed nonetheless).
for func in 0..num_funcs {
if !visited[func] {
self.post_order_step(func, &mut visited, &mut post_order);
}
}

post_order
}

fn post_order_step(&self, func: FuncIdx, visited: &mut [bool], post_order: &mut Vec<FuncIdx>) {
if visited[func] {
return;
}
visited[func] = true;

for &callee in &self.callees[func] {
self.post_order_step(callee, visited, post_order);
}

post_order.push(func);
}
}

// HACK(eddyb) `Copy` version of `Operand` that only includes the cases that
// are relevant to the inference algorithm (and is also smaller).
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
Expand Down
Loading