Skip to content

Commit

Permalink
[red-knot] Condense literals display by types (#13185)
Browse files Browse the repository at this point in the history
Co-authored-by: Micha Reiser <[email protected]>
  • Loading branch information
Slyces and MichaReiser committed Sep 3, 2024
1 parent 599103c commit 46e687e
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 57 deletions.
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ pub(crate) mod site_packages;
pub mod types;

type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;
type FxOrderMap<K, V> = ordermap::map::OrderMap<K, V, BuildHasherDefault<FxHasher>>;
13 changes: 13 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,19 @@ impl<'db> Type<'db> {
matches!(self, Type::Never)
}

/// Returns `true` if this type should be displayed as a literal value.
pub const fn is_literal(&self) -> bool {
matches!(
self,
Type::IntLiteral(_)
| Type::BooleanLiteral(_)
| Type::StringLiteral(_)
| Type::BytesLiteral(_)
| Type::Class(_)
| Type::Function(_)
)
}

pub fn may_be_unbound(&self, db: &'db dyn Db) -> bool {
match self {
Type::Unbound => true,
Expand Down
249 changes: 198 additions & 51 deletions crates/red_knot_python_semantic/src/types/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ use ruff_python_ast::str::Quote;
use ruff_python_literal::escape::AsciiEscape;

use crate::types::{IntersectionType, Type, UnionType};
use crate::Db;
use crate::{Db, FxOrderMap};

impl<'db> Type<'db> {
pub fn display(&'db self, db: &'db dyn Db) -> DisplayType<'db> {
DisplayType { ty: self, db }
}

fn representation(&'db self, db: &'db dyn Db) -> DisplayRepresentation<'db> {
DisplayRepresentation { db, ty: self }
}
}

#[derive(Copy, Clone)]
Expand All @@ -21,6 +25,31 @@ pub struct DisplayType<'db> {
}

impl Display for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let representation = self.ty.representation(self.db);
if self.ty.is_literal() {
write!(f, "Literal[{representation}]",)
} else {
representation.fmt(f)
}
}
}

impl std::fmt::Debug for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

/// Writes the string representation of a type, which is the value displayed either as
/// `Literal[<repr>]` or `Literal[<repr1>, <repr2>]` for literal types or as `<repr>` for
/// non literals
struct DisplayRepresentation<'db> {
ty: &'db Type<'db>,
db: &'db dyn Db,
}

impl std::fmt::Display for DisplayRepresentation<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.ty {
Type::Any => f.write_str("Any"),
Expand All @@ -32,39 +61,27 @@ impl Display for DisplayType<'_> {
write!(f, "<module '{:?}'>", file.path(self.db))
}
// TODO functions and classes should display using a fully qualified name
Type::Class(class) => write!(f, "Literal[{}]", class.name(self.db)),
Type::Class(class) => f.write_str(class.name(self.db)),
Type::Instance(class) => f.write_str(class.name(self.db)),
Type::Function(function) => write!(f, "Literal[{}]", function.name(self.db)),
Type::Function(function) => f.write_str(function.name(self.db)),
Type::Union(union) => union.display(self.db).fmt(f),
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),
Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
Type::BooleanLiteral(boolean) => {
write!(f, "Literal[{}]", if *boolean { "True" } else { "False" })
Type::IntLiteral(n) => write!(f, "{n}"),
Type::BooleanLiteral(boolean) => f.write_str(if *boolean { "True" } else { "False" }),
Type::StringLiteral(string) => {
write!(f, r#""{}""#, string.value(self.db).replace('"', r#"\""#))
}
Type::StringLiteral(string) => write!(
f,
r#"Literal["{}"]"#,
string.value(self.db).replace('"', r#"\""#)
),
Type::LiteralString => write!(f, "LiteralString"),
Type::LiteralString => f.write_str("LiteralString"),
Type::BytesLiteral(bytes) => {
let escape =
AsciiEscape::with_preferred_quote(bytes.value(self.db).as_ref(), Quote::Double);

f.write_str("Literal[")?;
escape.bytes_repr().write(f)?;
f.write_str("]")
escape.bytes_repr().write(f)
}
}
}
}

impl std::fmt::Debug for DisplayType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

impl<'db> UnionType<'db> {
fn display(&'db self, db: &'db dyn Db) -> DisplayUnionType<'db> {
DisplayUnionType { db, ty: self }
Expand All @@ -78,46 +95,62 @@ struct DisplayUnionType<'db> {

impl Display for DisplayUnionType<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let union = self.ty;
let elements = self.ty.elements(self.db);

let (int_literals, other_types): (Vec<Type>, Vec<Type>) = union
.elements(self.db)
.iter()
.copied()
.partition(|ty| matches!(ty, Type::IntLiteral(_)));
// Group literal types by kind.
let mut grouped_literals = FxOrderMap::default();

for element in elements {
if let Ok(literal_kind) = LiteralTypeKind::try_from(*element) {
grouped_literals
.entry(literal_kind)
.or_insert_with(Vec::new)
.push(*element);
}
}

let mut first = true;
if !int_literals.is_empty() {
f.write_str("Literal[")?;
let mut nums: Vec<_> = int_literals
.into_iter()
.filter_map(|ty| {
if let Type::IntLiteral(n) = ty {
Some(n)
} else {
None
}
})
.collect();
nums.sort_unstable();
for num in nums {

// Print all types, but write all literals together (while preserving their position).
for ty in elements {
if let Ok(literal_kind) = LiteralTypeKind::try_from(*ty) {
let Some(mut literals) = grouped_literals.remove(&literal_kind) else {
continue;
};

if !first {
f.write_str(", ")?;
f.write_str(" | ")?;
};

f.write_str("Literal[")?;

if literal_kind == LiteralTypeKind::IntLiteral {
literals.sort_unstable_by_key(|ty| match ty {
Type::IntLiteral(n) => *n,
_ => panic!("Expected only int literals when kind is IntLiteral"),
});
}
write!(f, "{num}")?;
first = false;

for (i, literal_ty) in literals.iter().enumerate() {
if i > 0 {
f.write_str(", ")?;
}
literal_ty.representation(self.db).fmt(f)?;
}
f.write_str("]")?;
} else {
if !first {
f.write_str(" | ")?;
};

ty.display(self.db).fmt(f)?;
}
f.write_str("]")?;
}

for ty in other_types {
if !first {
f.write_str(" | ")?;
};
first = false;
write!(f, "{}", ty.display(self.db))?;
}

debug_assert!(grouped_literals.is_empty());

Ok(())
}
}
Expand All @@ -128,6 +161,30 @@ impl std::fmt::Debug for DisplayUnionType<'_> {
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
enum LiteralTypeKind {
Class,
Function,
IntLiteral,
StringLiteral,
BytesLiteral,
}

impl TryFrom<Type<'_>> for LiteralTypeKind {
type Error = ();

fn try_from(value: Type<'_>) -> Result<Self, Self::Error> {
match value {
Type::Class(_) => Ok(Self::Class),
Type::Function(_) => Ok(Self::Function),
Type::IntLiteral(_) => Ok(Self::IntLiteral),
Type::StringLiteral(_) => Ok(Self::StringLiteral),
Type::BytesLiteral(_) => Ok(Self::BytesLiteral),
_ => Err(()),
}
}
}

impl<'db> IntersectionType<'db> {
fn display(&'db self, db: &'db dyn Db) -> DisplayIntersectionType<'db> {
DisplayIntersectionType { db, ty: self }
Expand Down Expand Up @@ -167,3 +224,93 @@ impl std::fmt::Debug for DisplayIntersectionType<'_> {
std::fmt::Display::fmt(self, f)
}
}

#[cfg(test)]
mod tests {
use ruff_db::files::system_path_to_file;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};

use crate::db::tests::TestDb;
use crate::types::{
global_symbol_ty_by_name, BytesLiteralType, StringLiteralType, Type, UnionBuilder,
};
use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings};

fn setup_db() -> TestDb {
let db = TestDb::new();

let src_root = SystemPathBuf::from("/src");
db.memory_file_system()
.create_directory_all(&src_root)
.unwrap();

Program::from_settings(
&db,
&ProgramSettings {
target_version: PythonVersion::default(),
search_paths: SearchPathSettings::new(src_root),
},
)
.expect("Valid search path settings");

db
}

#[test]
fn test_condense_literal_display_by_type() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/main.py",
"
def foo(x: int) -> int:
return x + 1
def bar(s: str) -> str:
return s
class A: ...
class B: ...
",
)?;
let mod_file = system_path_to_file(&db, "src/main.py").expect("Expected file to exist.");

let vec: Vec<Type<'_>> = vec![
Type::Unknown,
Type::IntLiteral(-1),
global_symbol_ty_by_name(&db, mod_file, "A"),
Type::StringLiteral(StringLiteralType::new(&db, Box::from("A"))),
Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([0]))),
Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([7]))),
Type::IntLiteral(0),
Type::IntLiteral(1),
Type::StringLiteral(StringLiteralType::new(&db, Box::from("B"))),
global_symbol_ty_by_name(&db, mod_file, "foo"),
global_symbol_ty_by_name(&db, mod_file, "bar"),
global_symbol_ty_by_name(&db, mod_file, "B"),
Type::BooleanLiteral(true),
Type::None,
];
let builder = vec.iter().fold(UnionBuilder::new(&db), |builder, literal| {
builder.add(*literal)
});
let Type::Union(union) = builder.build() else {
panic!("expected a union");
};
let display = format!("{}", union.display(&db));
assert_eq!(
display,
concat!(
"Unknown | ",
"Literal[-1, 0, 1] | ",
"Literal[A, B] | ",
"Literal[\"A\", \"B\"] | ",
"Literal[b\"\\x00\", b\"\\x07\"] | ",
"Literal[foo, bar] | ",
"Literal[True] | ",
"None"
)
);
Ok(())
}
}
12 changes: 6 additions & 6 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3093,7 +3093,7 @@ mod tests {
",
)?;

assert_public_ty(&db, "src/a.py", "x", "Literal[3] | Unbound");
assert_public_ty(&db, "src/a.py", "x", "Unbound | Literal[3]");
Ok(())
}

Expand All @@ -3119,8 +3119,8 @@ mod tests {
)?;

assert_public_ty(&db, "src/a.py", "x", "Literal[3, 4, 5]");
assert_public_ty(&db, "src/a.py", "r", "Literal[2] | Unbound");
assert_public_ty(&db, "src/a.py", "s", "Literal[5] | Unbound");
assert_public_ty(&db, "src/a.py", "r", "Unbound | Literal[2]");
assert_public_ty(&db, "src/a.py", "s", "Unbound | Literal[5]");
Ok(())
}

Expand Down Expand Up @@ -3356,7 +3356,7 @@ mod tests {

assert_eq!(
y_ty.display(&db).to_string(),
"Literal[1] | Literal[copyright]"
"Literal[copyright] | Literal[1]"
);

Ok(())
Expand Down Expand Up @@ -3389,7 +3389,7 @@ mod tests {
let y_ty = symbol_ty_by_name(&db, class_scope, "y");
let x_ty = symbol_ty_by_name(&db, class_scope, "x");

assert_eq!(x_ty.display(&db).to_string(), "Literal[2] | Unbound");
assert_eq!(x_ty.display(&db).to_string(), "Unbound | Literal[2]");
assert_eq!(y_ty.display(&db).to_string(), "Literal[1]");

Ok(())
Expand Down Expand Up @@ -3522,7 +3522,7 @@ mod tests {
",
)?;

assert_public_ty(&db, "/src/a.py", "x", "Literal[1] | None");
assert_public_ty(&db, "/src/a.py", "x", "None | Literal[1]");
assert_public_ty(&db, "/src/a.py", "y", "Literal[0, 1]");

Ok(())
Expand Down

0 comments on commit 46e687e

Please sign in to comment.