Files
bash-for-windows/internal/shell/shell.go
Cametendo 8c6a2ab4c2 Add control flow, I/O redirection, functions, coreutils, history/completion
- I/O redirection: >, >>, <, 2>, 2>&1, &>
- Job control: background & operator
- if/elif/else/fi, for/do/done, while/until loops
- Shell functions with local/declare, positional param save/restore
- Exit code propagation via exitCodeErr sentinel
- Arithmetic expansion $((expr)) with bare variable names
- Command substitution $(cmd) with pipeline support
- Glob expansion, tilde expansion, ${VAR:-default} and other forms
- Tab completion and command history via chzyer/readline
- Inline comment stripping (# outside quotes)
- Builtins: test/[, read, printf, tr, sed, cut, tail, tee, xargs,
  basename, dirname, date, sleep, uniq, sort, wc, head, grep, find,
  true, false, break, continue, return, shift, set, unset, export,
  declare/local, source, alias, jobs, command, which, env
- Bug fixes: tokenizer parenDepth double-count for $((,
  splitPipe not paren-aware (broke pipelines in $()),
  local/declare TrimLeft stripping valid var name chars,
  parseBlocks missing nested keywords after do/then/else

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-26 12:50:06 +02:00

510 lines
11 KiB
Go

package shell
import (
	"fmt"
	"io"
	"os"
	"sort"
	"strings"
)
// Control-flow sentinels. They travel as errors so enclosing loop and
// function executors can unwind; Execute and executeChain return them to
// their caller untouched instead of recording them as command failures.
// (Presumably raised by the break/continue/return builtins — defined
// elsewhere.)
type (
	breakErr    struct{ n int }    // numeric argument, rendered as "break N"
	continueErr struct{ n int }    // numeric argument, rendered as "continue N"
	returnErr   struct{ code int } // status, rendered as "return N"
)

func (b breakErr) Error() string    { return "break " + fmt.Sprint(b.n) }
func (c continueErr) Error() string { return "continue " + fmt.Sprint(c.n) }
func (r returnErr) Error() string   { return "return " + fmt.Sprint(r.code) }

// exitCodeErr carries a non-zero exit code that sets $? without a message.
// Used by functions, test, false, etc. Its Error text is intentionally empty;
// setExitCode copies code into $?.
type exitCodeErr struct{ code int }

func (exitCodeErr) Error() string { return "" }
// Shell is the interpreter state used by every command it runs: shell
// variables, registered builtins, user-defined functions, the most recent
// exit status, the I/O streams commands use, and the positional parameters.
// Construct it with New, which allocates the maps and sets default variables.
type Shell struct {
// vars holds shell variables, including the special parameters "?", "#",
// "@", "*", "!" and the positional "1".."N" written by SetArgs. GetVar falls
// back to the process environment for names not present here.
vars map[string]string
// builtins maps a builtin command name to its implementation; presumably
// filled by initBuiltins (called from New, defined elsewhere).
builtins map[string]func([]string) error
funcs map[string]string // function name → body
// lastExit mirrors $? as an int; setExitCode keeps the two in sync.
lastExit int
// Stdin, Stdout and Stderr are the streams commands read from and write to.
// withIO swaps them temporarily and then restores them.
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
// args are the current positional parameters, as last passed to SetArgs.
args []string
}
// New returns a Shell bound to the process's standard streams. It registers
// the builtins (initBuiltins), then seeds SHELL, BASH_VERSION and the
// special parameters $?, $#, $@, $* and $!. PWD and HOME are taken from the
// OS when they can be determined.
func New() *Shell {
	sh := &Shell{
		vars:   make(map[string]string),
		funcs:  make(map[string]string),
		Stdin:  os.Stdin,
		Stdout: os.Stdout,
		Stderr: os.Stderr,
	}
	sh.initBuiltins()

	defaults := map[string]string{
		"SHELL":        "bash-for-windows",
		"BASH_VERSION": "5.2.15(1)-release",
		"?":            "0",
		"#":            "0",
		"@":            "",
		"*":            "",
		"!":            "",
	}
	for name, value := range defaults {
		sh.vars[name] = value
	}

	if wd, err := os.Getwd(); err == nil {
		sh.vars["PWD"] = wd
	}
	if home, err := os.UserHomeDir(); err == nil {
		sh.vars["HOME"] = home
	}
	return sh
}
// SetArgs replaces the positional parameters with args and refreshes $#,
// $@, $* and $1..$N to match.
//
// If the new list is shorter than the previous one, the positional
// variables that no longer exist ($N+1 up to the old count) are deleted.
// Otherwise they would leak through as stale values: after replacing three
// arguments with one, $2 and $3 would still report the old ones.
func (s *Shell) SetArgs(args []string) {
	for i := len(args) + 1; i <= len(s.args); i++ {
		delete(s.vars, fmt.Sprintf("%d", i))
	}
	s.args = args
	joined := strings.Join(args, " ")
	s.vars["#"] = fmt.Sprintf("%d", len(args))
	s.vars["@"] = joined
	s.vars["*"] = joined
	for i, a := range args {
		s.vars[fmt.Sprintf("%d", i+1)] = a
	}
}
// GetVar returns the value of the shell variable name. A name the shell has
// never set falls back to the process environment, so the result is ""
// only when neither defines it (or it is set to the empty string).
func (s *Shell) GetVar(name string) string {
	value, found := s.vars[name]
	if !found {
		value = os.Getenv(name)
	}
	return value
}

// SetVar assigns value to the shell variable name. Only the shell's own
// variable table is written; the process environment is left unchanged.
func (s *Shell) SetVar(name, value string) {
	s.vars[name] = value
}
// Execute runs commands from the given input string, which may span lines.
//
// Backslash-newline continuations are first replaced by a single space, for
// both LF and CRLF line endings. Without the CRLF case, scripts saved with
// Windows line endings keep a dangling backslash and the continuation is
// lost. The input is then grouped by parseBlocks and each block runs in
// order. A failing block only updates $? and execution carries on;
// break/continue/return signals are returned to the caller immediately so an
// enclosing loop or function can act on them. Every other error is absorbed,
// and nil is returned.
func (s *Shell) Execute(input string) error {
	input = strings.ReplaceAll(input, "\\\r\n", " ")
	input = strings.ReplaceAll(input, "\\\n", " ")
	for _, block := range parseBlocks(input) {
		err := s.executeBlock(block)
		if err == nil {
			continue
		}
		switch err.(type) {
		case breakErr, continueErr, returnErr:
			return err
		}
		s.setExitCode(err)
	}
	return nil
}
// IsIncomplete returns true if the input is an incomplete multi-line construct,
// i.e. an interactive caller should read another line before executing it.
//
// Input is incomplete when a single or double quote is left open, or when
// if/for/while/until blocks, "{" groups or function bodies have been opened
// more often than closed. Quotes inside comments ('#' at the start of a word,
// outside quotes, up to end of line) are ignored, so a comment such as
// "# don't" cannot hold the prompt open.
//
// Keyword nesting is counted the way parseBlocks counts it:
//   - keywords that follow do/then/else/elif in the same statement are
//     included (via embeddedKwDepth), so "for x in a; do if true; ...; fi"
//     still waits for its "done";
//   - a function header counts as open whenever its "{" is on that
//     statement, even with body text after it ("f() { echo hi"), unless the
//     statement also ends with "}";
//   - a stray closer never makes the depth negative, so it cannot cancel a
//     later opener.
func IsIncomplete(input string) bool {
	// Pass 1: unterminated quotes.
	inSingle, inDouble := false, false
	for i := 0; i < len(input); i++ {
		switch c := input[i]; {
		case c == '\'' && !inDouble:
			inSingle = !inSingle
		case c == '"' && !inSingle:
			inDouble = !inDouble
		case c == '#' && !inSingle && !inDouble &&
			(i == 0 || strings.IndexByte(" \t\r\n;", input[i-1]) >= 0):
			// Skip the comment; the loop's i++ then steps past the newline.
			if nl := strings.IndexByte(input[i:], '\n'); nl >= 0 {
				i += nl
			} else {
				i = len(input)
			}
		}
	}
	if inSingle || inDouble {
		return true
	}

	// Pass 2: block nesting, one statement at a time.
	kwDelta := func(word string) int {
		switch word {
		case "if", "for", "while", "until":
			return 1
		case "fi", "done", "esac":
			return -1
		}
		return 0
	}
	depth := 0
	for _, stmt := range splitStatements(input) {
		if strings.HasPrefix(stmt, "#") {
			continue
		}
		if brace := strings.Index(stmt, "{"); brace >= 0 && isFuncDefStart(stmt) {
			// Function header opening its body here. As in parseBlocks, a
			// header that also ends with "}" is treated as self-contained.
			if !strings.HasSuffix(stmt, "}") {
				body := strings.TrimSpace(stmt[brace+1:])
				depth += 1 + kwDelta(firstWord(body)) + embeddedKwDepth(body)
			}
		} else {
			switch w := firstWord(stmt); w {
			case "{":
				depth++
			case "}":
				depth--
			default:
				depth += kwDelta(w)
			}
			depth += embeddedKwDepth(stmt)
		}
		if depth < 0 {
			depth = 0
		}
	}
	return depth > 0
}
// parseBlocks groups statements into logical execution units.
// Multi-line if/for/while/function blocks are gathered into single entries.
//
// input is split with splitStatements; blank statements and statements that
// start with "#" are dropped. The statements of one unit are re-joined with
// "\n", so the ';' separators that split them become newlines in the result.
//
// Two counters track nesting:
//   - kwDepth: if/for/while/until opened vs fi/done/esac closed, outside any
//     function. A unit is emitted whenever the count is back at zero (a
//     negative count is reset to zero), so a plain command is its own unit.
//   - funcKwDepth: the same count inside a function body, so that only a "}"
//     seen at keyword depth zero ends the function.
//
// NOTE(review): a "{" statement never increments kwDepth here (IsIncomplete
// does count "{"), so brace groups outside functions are not kept together.
// NOTE(review): the function-header check runs even while kwDepth > 0, so a
// function defined inside an if/for body is appended to the enclosing
// statements and flushed together with them when the function closes.
// NOTE(review): "{ ... }" groups nested inside a function body are not
// counted, so an inner "}" at funcKwDepth 0 ends the function early.
func parseBlocks(input string) []string {
stmts := splitStatements(input)
var blocks []string
var current []string
kwDepth := 0 // if/for/while/until → fi/done nesting
inFunc := false
funcKwDepth := 0 // keyword nesting inside a function body
for _, stmt := range stmts {
stmt = strings.TrimSpace(stmt)
if stmt == "" || strings.HasPrefix(stmt, "#") {
continue
}
w := firstWord(stmt)
if !inFunc {
// Detect function definition opening with `{`
if isFuncDefStart(stmt) && strings.Contains(stmt, "{") {
braceIdx := strings.Index(stmt, "{")
// Count keywords after the { on this same line. stmt is a single
// statement, so only the first word after the brace is examined;
// keywords following do/then there (embeddedKwDepth) are not counted.
funcKwDepth = 0
for _, p := range splitStatements(stmt[braceIdx+1:]) {
switch firstWord(p) {
case "if", "for", "while", "until":
funcKwDepth++
case "fi", "done", "esac":
funcKwDepth--
}
}
// If the line also ends with } it's a self-contained function.
// NOTE(review): a header such as `f() { echo ${x}` also ends with "}"
// (from ${x}) and is wrongly treated as self-contained.
if strings.HasSuffix(stmt, "}") {
current = append(current, stmt)
blocks = append(blocks, strings.Join(current, "\n"))
current = nil
funcKwDepth = 0
continue
}
inFunc = true
current = append(current, stmt)
continue
}
// Openers are counted before the statement is appended, closers after,
// so the closing statement stays in the same unit as its opener.
switch w {
case "if", "for", "while", "until":
kwDepth++
}
kwDepth += embeddedKwDepth(stmt)
current = append(current, stmt)
switch w {
case "fi", "done", "esac":
kwDepth--
case "}":
if kwDepth > 0 {
kwDepth--
}
}
// Back at top level: flush the accumulated unit.
if kwDepth <= 0 && len(current) > 0 {
kwDepth = 0
blocks = append(blocks, strings.Join(current, "\n"))
current = nil
}
} else {
// Inside function body — watch for } at funcKwDepth==0
if w == "}" && funcKwDepth <= 0 {
current = append(current, stmt)
blocks = append(blocks, strings.Join(current, "\n"))
current = nil
inFunc = false
funcKwDepth = 0
continue
}
switch w {
case "if", "for", "while", "until":
funcKwDepth++
case "fi", "done", "esac":
funcKwDepth--
}
funcKwDepth += embeddedKwDepth(stmt)
current = append(current, stmt)
}
}
// Whatever remains (e.g. an unterminated block) is emitted as a final unit.
if len(current) > 0 {
blocks = append(blocks, strings.Join(current, "\n"))
}
return blocks
}
// embeddedKwDepth returns the net nesting change contributed by block
// keywords that directly follow do, then, else or elif inside stmt, such as
// the "if" in "do if true": +1 for if/for/while/until, -1 for fi/done/esac.
// The statement's first word is never counted; callers handle it.
func embeddedKwDepth(stmt string) int {
	fields := strings.Fields(stmt)
	net := 0
	for i := 1; i < len(fields); i++ {
		prev, cur := fields[i-1], fields[i]
		if prev != "do" && prev != "then" && prev != "else" && prev != "elif" {
			continue
		}
		switch cur {
		case "if", "for", "while", "until":
			net++
		case "fi", "done", "esac":
			net--
		}
	}
	return net
}
// splitStatements splits input on semicolons and newlines, respecting quotes.
// Each statement is trimmed and empty statements are dropped.
//
// A separator does not split when it appears:
//   - inside single or double quotes;
//   - inside a (...) group, so command substitutions like $(cd /tmp; pwd)
//     and subshells stay in one piece (a ')' without a matching '(' is
//     ignored rather than driving the depth negative);
//   - inside a comment. A '#' that starts a word outside quotes begins a
//     comment running to the end of the line. Its text is kept, but quotes
//     and semicolons in it are not interpreted, so "# don't" cannot open a
//     quote that swallows the rest of the script. The newline that ends the
//     comment still separates statements.
//
// Backslash escapes are not interpreted.
func splitStatements(input string) []string {
	var result []string
	var current strings.Builder
	inSingle, inDouble := false, false
	parenDepth := 0
	flush := func() {
		if s := strings.TrimSpace(current.String()); s != "" {
			result = append(result, s)
		}
		current.Reset()
	}
	for i := 0; i < len(input); i++ {
		c := input[i]
		switch {
		case c == '\'' && !inDouble:
			inSingle = !inSingle
		case c == '"' && !inSingle:
			inDouble = !inDouble
		case inSingle || inDouble:
			// Quoted text is copied verbatim.
		case c == '#' && (i == 0 || strings.IndexByte(" \t\r\n;", input[i-1]) >= 0):
			end := strings.IndexByte(input[i:], '\n')
			if end < 0 {
				end = len(input) - i
			}
			current.WriteString(input[i : i+end])
			i += end - 1 // resume at the newline so it still separates
			continue
		case c == '(':
			parenDepth++
		case c == ')':
			if parenDepth > 0 {
				parenDepth--
			}
		case (c == ';' || c == '\n') && parenDepth == 0:
			flush()
			continue
		}
		current.WriteByte(c)
	}
	flush()
	return result
}
// firstWord returns the first whitespace-separated word of s, or "" when s
// is empty or consists only of whitespace.
func firstWord(s string) string {
	if fields := strings.Fields(s); len(fields) > 0 {
		return fields[0]
	}
	return ""
}
// afterWord returns the text that follows the first word of s, trimmed of
// surrounding whitespace, or "" when s has no space or tab after its first
// word.
//
// Leading whitespace is skipped first so the result complements firstWord.
// Previously, input such as "  echo hi" returned "echo hi" instead of "hi".
// Words are still delimited by space or tab only.
func afterWord(s string) string {
	s = strings.TrimLeft(s, " \t\r\n")
	if i := strings.IndexAny(s, " \t"); i >= 0 {
		return strings.TrimSpace(s[i:])
	}
	return ""
}
// isFuncDefStart reports whether stmt begins a shell function definition.
// Two forms are accepted:
//   - "function NAME ..." (the keyword followed by a space or tab);
//   - "NAME()" where NAME is a valid identifier, optionally with blanks
//     before or inside the parentheses ("name () {", "name ( ) {"), as bash
//     allows.
//
// The "()" pair is required, so a command such as "echo (x)" is not
// mistaken for a definition. stmt is expected to be trimmed already.
func isFuncDefStart(stmt string) bool {
	if strings.HasPrefix(stmt, "function ") || strings.HasPrefix(stmt, "function\t") {
		return true
	}
	open := strings.IndexByte(stmt, '(')
	if open < 0 {
		return false
	}
	if !isValidIdentifier(strings.TrimRight(stmt[:open], " \t")) {
		return false
	}
	return strings.HasPrefix(strings.TrimLeft(stmt[open+1:], " \t"), ")")
}

// isValidIdentifier reports whether s is a shell identifier: an ASCII
// letter or underscore followed by any number of ASCII letters, digits or
// underscores.
func isValidIdentifier(s string) bool {
	if s == "" {
		return false
	}
	for i, c := range s {
		isLetter := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
		isDigit := c >= '0' && c <= '9'
		if !isLetter && (i == 0 || !isDigit) {
			return false
		}
	}
	return true
}
// executeBlock runs one unit produced by parseBlocks.
//
// Units that start with if/for/while/until go to the matching control-flow
// executor, and function definitions go to defineFunction. Anything else is
// run line by line: blank and "#" lines are skipped, and the first error from
// executeLine stops the unit and is returned.
func (s *Shell) executeBlock(block string) error {
	block = strings.TrimSpace(block)
	if block == "" || strings.HasPrefix(block, "#") {
		return nil
	}

	switch kw := firstWord(block); kw {
	case "if":
		return s.executeIf(block)
	case "for":
		return s.executeFor(block)
	case "while", "until":
		return s.executeWhileUntil(block, kw == "until")
	}

	if isFuncDefStart(block) {
		return s.defineFunction(block)
	}

	for _, raw := range strings.Split(block, "\n") {
		ln := strings.TrimSpace(raw)
		if ln == "" || ln[0] == '#' {
			continue
		}
		if err := s.executeLine(ln); err != nil {
			return err
		}
	}
	return nil
}
// executeLine runs a single line of input by delegating to executeChain.
func (s *Shell) executeLine(line string) error {
	return s.executeChain(line)
}

// executeChain runs the ';'-separated segments of line in order, skipping
// empty ones. A failing segment only updates $? and the chain continues;
// break/continue/return signals are returned at once. Otherwise the result
// is nil.
func (s *Shell) executeChain(line string) error {
	segments := splitBySemicolon(line)
	for idx := 0; idx < len(segments); idx++ {
		seg := strings.TrimSpace(segments[idx])
		if seg == "" {
			continue
		}
		err := s.executeAndOrList(seg)
		if err == nil {
			continue
		}
		switch err.(type) {
		case breakErr, continueErr, returnErr:
			return err
		default:
			s.setExitCode(err)
		}
	}
	return nil
}
// splitBySemicolon splits line at semicolons that are not inside quotes,
// inside a (...) group, or inside a comment. Unlike splitStatements it does
// not trim, keeps the empty segment between adjacent semicolons ("a;;b"
// gives "a", "", "b"), and omits only an empty trailing segment.
//
// Paren awareness keeps $(a; b) command substitutions intact; an unmatched
// ')' is ignored. A '#' that starts a word outside quotes begins a comment
// running to the end of the line, and semicolons in it do not split, so
// "echo a # x; rm y" stays one segment. Backslash escapes are not
// interpreted.
func splitBySemicolon(line string) []string {
	var parts []string
	var seg strings.Builder
	inSingle, inDouble := false, false
	parenDepth := 0
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch {
		case c == '\'' && !inDouble:
			inSingle = !inSingle
		case c == '"' && !inSingle:
			inDouble = !inDouble
		case inSingle || inDouble:
			// Quoted text is copied verbatim.
		case c == '#' && (i == 0 || strings.IndexByte(" \t\r\n;", line[i-1]) >= 0):
			end := strings.IndexByte(line[i:], '\n')
			if end < 0 {
				end = len(line) - i
			}
			seg.WriteString(line[i : i+end])
			i += end - 1
			continue
		case c == '(':
			parenDepth++
		case c == ')':
			if parenDepth > 0 {
				parenDepth--
			}
		case c == ';' && parenDepth == 0:
			parts = append(parts, seg.String())
			seg.Reset()
			continue
		}
		seg.WriteByte(c)
	}
	if seg.Len() > 0 {
		parts = append(parts, seg.String())
	}
	return parts
}
// executeAndOrList runs an && / || list such as `a && b || c`.
//
// The line is split at "&&" and "||" only when the operator is:
//   - outside quotes;
//   - outside (...) groups, so `x=$(a && b)` and `$((a || b))` stay intact
//     (an unmatched ')' is ignored);
//   - outside comments. A '#' that starts a word begins a comment running to
//     the end of the line, and its text is copied through unsplit.
//
// The first command always runs. A later command runs after "&&" only if the
// most recently executed command succeeded, and after "||" only if it
// failed. $? is updated after every command that runs. The error of the last
// command that ran is returned (nil when it succeeded or nothing ran).
func (s *Shell) executeAndOrList(line string) error {
	type tok struct {
		text string
		op   string // operator preceding text: "", "&&" or "||"
	}
	var tokens []tok
	var current strings.Builder
	op := ""
	inSingle, inDouble := false, false
	parenDepth := 0
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch {
		case c == '\'' && !inDouble:
			inSingle = !inSingle
		case c == '"' && !inSingle:
			inDouble = !inDouble
		case inSingle || inDouble:
			// Quoted text is copied verbatim.
		case c == '#' && (i == 0 || strings.IndexByte(" \t\r\n;", line[i-1]) >= 0):
			end := strings.IndexByte(line[i:], '\n')
			if end < 0 {
				end = len(line) - i
			}
			current.WriteString(line[i : i+end])
			i += end - 1
			continue
		case c == '(':
			parenDepth++
		case c == ')':
			if parenDepth > 0 {
				parenDepth--
			}
		case (c == '&' || c == '|') && parenDepth == 0 && i+1 < len(line) && line[i+1] == c:
			tokens = append(tokens, tok{current.String(), op})
			current.Reset()
			op = line[i : i+2]
			i++
			continue
		}
		current.WriteByte(c)
	}
	if current.Len() > 0 {
		tokens = append(tokens, tok{current.String(), op})
	}

	var lastErr error
	for i, t := range tokens {
		cmd := strings.TrimSpace(t.text)
		if cmd == "" {
			continue
		}
		run := i == 0 ||
			(t.op == "&&" && lastErr == nil) ||
			(t.op == "||" && lastErr != nil)
		if run {
			lastErr = s.executePipeline(cmd)
			s.setExitCode(lastErr)
		}
	}
	return lastErr
}
// setExitCode records the status implied by err in both $? and lastExit:
// 0 for nil, the carried code for an exitCodeErr, and 1 for any other error.
func (s *Shell) setExitCode(err error) {
	code := 0
	if err != nil {
		code = 1
		if ec, ok := err.(exitCodeErr); ok {
			code = ec.code
		}
	}
	s.lastExit = code
	s.vars["?"] = fmt.Sprint(code)
}
// BuiltinNames returns the names of all registered builtins and user-defined
// shell functions, for tab completion.
//
// The list is sorted, as this comment has always promised; previously it came
// back in random map order. A function that shadows a builtin is listed once
// instead of twice.
func (s *Shell) BuiltinNames() []string {
	names := make([]string, 0, len(s.builtins)+len(s.funcs))
	for name := range s.builtins {
		names = append(names, name)
	}
	for name := range s.funcs {
		if _, dup := s.builtins[name]; !dup {
			names = append(names, name)
		}
	}
	sort.Strings(names)
	return names
}
// withIO temporarily swaps stdin/stdout/stderr, runs fn, then restores the
// previous streams. Pass nil to leave the corresponding stream unchanged.
//
// The restore is deferred, so it also happens when fn panics. If a caller
// recovers such a panic, the shell does not keep writing to a redirection
// target afterwards. fn's error is returned unchanged.
func (s *Shell) withIO(stdin io.Reader, stdout io.Writer, stderr io.Writer, fn func() error) error {
	oldIn, oldOut, oldErr := s.Stdin, s.Stdout, s.Stderr
	defer func() {
		s.Stdin, s.Stdout, s.Stderr = oldIn, oldOut, oldErr
	}()
	if stdin != nil {
		s.Stdin = stdin
	}
	if stdout != nil {
		s.Stdout = stdout
	}
	if stderr != nil {
		s.Stderr = stderr
	}
	return fn()
}