diff --git a/README.md b/README.md index 381d5da..44578de 100644 --- a/README.md +++ b/README.md @@ -55,8 +55,8 @@ The installer: | Shell | `cd`, `pwd`, `echo`, `exit`, `export`, `set`, `unset`, `source`/`.`, `alias`, `unalias`, `type`, `command`, `which`, `env` | | Control | `true`, `false`, `test`/`[`, `break`, `continue`, `return`, `shift`, `read`, `printf` | | Variables | `declare`, `local` | -| Files | `ls`, `cat`, `cp`, `mv`, `rm`, `mkdir`, `touch`, `find`, `stat`, `basename`, `dirname` | -| Text | `grep`, `sed`, `awk`, `sort`, `uniq`, `wc`, `head`, `tail`, `cut`, `tr`, `tee`, `xargs` | +| Files | `ls`, `cat`, `cp`, `mv`, `rm`, `mkdir`, `touch`, `find`, `basename`, `dirname` | +| Text | `grep`, `sed`, `sort`, `uniq`, `wc`, `head`, `tail`, `cut`, `tr`, `tee`, `xargs` | | System | `date`, `sleep`, `clear`, `jobs` | ## Usage @@ -85,8 +85,8 @@ for n in 1 2 3 4 5 6; do done # Pipelines and redirection -cat /etc/hosts | grep localhost | wc -l -find . -name "*.go" | xargs grep -l "TODO" > todo_files.txt +printf "127.0.0.1 localhost\n127.0.0.2 example\n" | grep localhost | wc -l +find . -name "*.go" | xargs grep "TODO" > todo_files.txt echo "result=$(date)" >> log.txt ``` diff --git a/cmd/bash/main.go b/cmd/bash/main.go index 0e244a7..9812697 100644 --- a/cmd/bash/main.go +++ b/cmd/bash/main.go @@ -60,12 +60,7 @@ func historyFile() string { func interactive() error { sh := shell.New() - // Build completer - completer := readline.NewPrefixCompleter( - readline.PcItemDynamic(func(line string) []string { - return dynamicComplete(sh, line) - }), - ) + completer := &shellCompleter{sh: sh} rl, err := readline.NewEx(&readline.Config{ HistoryFile: historyFile(), @@ -183,6 +178,56 @@ func buildPrompt(sh *shell.Shell) string { return pwd + suffix } +// shellCompleter implements readline.AutoCompleter. +// readline's built-in PrefixCompleter matches candidates against the full +// line, so "C:\workspace" never matches "cd C:\w". This implementation +// extracts the last token and returns only the suffix to append. +type shellCompleter struct{ sh *shell.Shell } + +func (c *shellCompleter) Do(line []rune, pos int) (newLine [][]rune, offset int) { + lineStr := string(line[:pos]) + last := lastToken(lineStr) + for _, comp := range dynamicComplete(c.sh, lineStr) { + if !strings.HasPrefix(comp, last) { + continue + } + suffix := []rune(comp[len(last):]) + // Add a trailing space for non-directory completions. + if len(suffix) > 0 && suffix[len(suffix)-1] != '/' && suffix[len(suffix)-1] != '\\' { + suffix = append(suffix, ' ') + } + newLine = append(newLine, suffix) + } + return newLine, len([]rune(last)) +} + +// lastToken returns the last whitespace-delimited token from s, +// respecting single and double quotes. +func lastToken(s string) string { + if strings.HasSuffix(s, " ") || strings.HasSuffix(s, "\t") { + return "" + } + inSingle, inDouble := false, false + start := 0 + for i := 0; i < len(s); i++ { + switch s[i] { + case '\'': + if !inDouble { + inSingle = !inSingle + } + case '"': + if !inSingle { + inDouble = !inDouble + } + case ' ', '\t': + if !inSingle && !inDouble { + start = i + 1 + } + } + } + return s[start:] +} + // dynamicComplete provides tab completion for commands and paths. func dynamicComplete(sh *shell.Shell, line string) []string { line = strings.TrimLeft(line, " \t") @@ -241,7 +286,7 @@ func dynamicComplete(sh *shell.Shell, line string) []string { if strings.HasPrefix(name, base) { p := filepath.Join(dir, name) if e.IsDir() { - p += "/" + p += string(filepath.Separator) } completions = append(completions, p) } diff --git a/install.ps1 b/install.ps1 index 0996862..e5d8dfc 100644 --- a/install.ps1 +++ b/install.ps1 @@ -1,4 +1,4 @@ -#Requires -Version 5.1 +#Requires -Version 5.1 <# .SYNOPSIS Installs bash-for-windows and registers it as a Windows Terminal profile. diff --git a/internal/shell/builtins.go b/internal/shell/builtins.go index 68a0a8f..52793ec 100644 --- a/internal/shell/builtins.go +++ b/internal/shell/builtins.go @@ -6,6 +6,7 @@ import ( "io" "os" "path/filepath" + "regexp" "runtime" "sort" "strconv" @@ -18,59 +19,67 @@ var aliases = make(map[string]string) func (s *Shell) initBuiltins() { s.builtins = map[string]func([]string) error{ // Shell builtins - "cd": s.builtinCd, - "pwd": s.builtinPwd, - "echo": s.builtinEcho, - "exit": s.builtinExit, - "export": s.builtinExport, - "source": s.builtinSource, - ".": s.builtinSource, - "alias": s.builtinAlias, - "unalias": s.builtinUnalias, - "type": s.builtinType, - "test": s.builtinTest, - "[": s.builtinTest, - "read": s.builtinRead, - "printf": s.builtinPrintf, - "true": s.builtinTrue, - "false": s.builtinFalse, - "set": s.builtinSet, - "unset": s.builtinUnset, - "env": s.builtinEnv, - "which": s.builtinWhich, - "return": s.builtinReturn, - "break": s.builtinBreak, - "continue": s.builtinContinue, - "shift": s.builtinShift, - "declare": s.builtinDeclare, - "local": s.builtinDeclare, - "command": s.builtinCommand, - "jobs": s.builtinJobs, + "cd": s.builtinCd, + "pwd": s.builtinPwd, + "echo": s.builtinEcho, + "exit": s.builtinExit, + "export": s.builtinExport, + "source": s.builtinSource, + ".": s.builtinSource, + "alias": s.builtinAlias, + "unalias": s.builtinUnalias, + "type": s.builtinType, + "test": s.builtinTest, + "[": s.builtinTest, + "[[": s.builtinDoubleBracket, + "read": s.builtinRead, + "printf": s.builtinPrintf, + "true": s.builtinTrue, + "false": s.builtinFalse, + "set": s.builtinSet, + "unset": s.builtinUnset, + "env": s.builtinEnv, + "which": s.builtinWhich, + "return": s.builtinReturn, + "break": s.builtinBreak, + "continue": s.builtinContinue, + "shift": s.builtinShift, + "declare": s.builtinDeclare, + "local": s.builtinDeclare, + "command": s.builtinCommand, + "jobs": s.builtinJobs, + "disown": s.builtinDisown, + "mktemp": s.builtinMktemp, + "uname": s.builtinUname, + "whoami": s.builtinWhoami, + "hostname": s.builtinHostname, + "mapfile": s.builtinMapfile, + "readarray": s.builtinMapfile, // Coreutils - "ls": s.cmdLs, - "cat": s.cmdCat, - "grep": s.cmdGrep, - "head": s.cmdHead, - "tail": s.cmdTail, - "sort": s.cmdSort, - "wc": s.cmdWc, - "find": s.cmdFind, - "cp": s.cmdCp, - "mv": s.cmdMv, - "rm": s.cmdRm, - "mkdir": s.cmdMkdir, - "touch": s.cmdTouch, - "clear": s.cmdClear, - "cut": s.cmdCut, - "tr": s.cmdTr, - "uniq": s.cmdUniq, - "tee": s.cmdTee, - "date": s.cmdDate, - "sleep": s.cmdSleep, + "ls": s.cmdLs, + "cat": s.cmdCat, + "grep": s.cmdGrep, + "head": s.cmdHead, + "tail": s.cmdTail, + "sort": s.cmdSort, + "wc": s.cmdWc, + "find": s.cmdFind, + "cp": s.cmdCp, + "mv": s.cmdMv, + "rm": s.cmdRm, + "mkdir": s.cmdMkdir, + "touch": s.cmdTouch, + "clear": s.cmdClear, + "cut": s.cmdCut, + "tr": s.cmdTr, + "uniq": s.cmdUniq, + "tee": s.cmdTee, + "date": s.cmdDate, + "sleep": s.cmdSleep, "basename": s.cmdBasename, - "dirname": s.cmdDirname, - "sed": s.cmdSed, - "xargs": s.cmdXargs, + "dirname": s.cmdDirname, + "sed": s.cmdSed, + "xargs": s.cmdXargs, } } @@ -532,9 +541,58 @@ func (s *Shell) builtinSet(args []string) error { } return nil } - // Handle positional params: set -- a b c - if args[0] == "--" { - s.SetArgs(args[1:]) + i := 0 + for i < len(args) { + switch args[i] { + case "--": + s.SetArgs(args[i+1:]) + return nil + case "-e": + s.errexit = true + case "+e": + s.errexit = false + case "-u": + s.nounset = true + case "+u": + s.nounset = false + case "-x": + // xtrace — ignore + case "+x": + // xtrace off — ignore + case "-o": + if i+1 < len(args) { + switch args[i+1] { + case "pipefail": + s.pipefail = true + i++ + case "errexit": + s.errexit = true + i++ + case "nounset": + s.nounset = true + i++ + } + } + case "+o": + if i+1 < len(args) { + i++ // ignore value + } + default: + // Combined flags like -euo + if strings.HasPrefix(args[i], "-") { + for _, c := range args[i][1:] { + switch c { + case 'e': + s.errexit = true + case 'u': + s.nounset = true + case 'x': + // xtrace — ignore + } + } + } + } + i++ } return nil } @@ -619,14 +677,71 @@ func (s *Shell) builtinShift(args []string) error { } func (s *Shell) builtinDeclare(args []string) error { - for _, arg := range args { + // Parse flags + nameref := false + isArray := false + isExport := false + + nonFlagStart := 0 + for i, arg := range args { + if !strings.HasPrefix(arg, "-") { + nonFlagStart = i + break + } + nonFlagStart = i + 1 + for _, ch := range arg[1:] { + switch ch { + case 'n': + nameref = true + case 'a': + isArray = true + case 'i': + // integer — treat as scalar + case 'r': + // readonly — ignore + case 'x': + isExport = true + case 'g': + // global — ignore + } + } + } + + for _, arg := range args[nonFlagStart:] { if strings.HasPrefix(arg, "-") { continue } if idx := strings.Index(arg, "="); idx > 0 { name := arg[:idx] - if isValidIdentifier(name) { - s.vars[name] = arg[idx+1:] + if !isValidIdentifier(name) { + continue + } + value := s.expandWord(arg[idx+1:]) + if nameref { + s.namerefs[name] = value + } else if isArray { + if strings.HasPrefix(value, "(") && strings.HasSuffix(value, ")") { + inner := value[1 : len(value)-1] + elems := s.tokenize(inner) + s.setArray(name, elems) + } else { + s.setArray(name, []string{value}) + } + } else { + s.vars[name] = value + if isExport { + os.Setenv(name, value) + } + } + } else { + // No = — just declare + if !isValidIdentifier(arg) { + continue + } + if isArray { + if s.arrays[arg] == nil { + s.arrays[arg] = []string{} + } } } } @@ -634,6 +749,23 @@ func (s *Shell) builtinDeclare(args []string) error { } func (s *Shell) builtinCommand(args []string) error { + if len(args) > 0 && args[0] == "-v" { + found := true + for _, name := range args[1:] { + if _, ok := s.builtins[name]; ok { + fmt.Fprintln(s.Stdout, name) + } else if p := findExecutable(name); p != "" { + fmt.Fprintln(s.Stdout, p) + } else { + found = false + fmt.Fprintf(s.Stderr, "bash: command not found: %s\n", name) + } + } + if !found { + return exitCodeErr{1} + } + return nil + } if len(args) == 0 { return nil } @@ -646,6 +778,291 @@ func (s *Shell) builtinJobs(_ []string) error { return nil } +func (s *Shell) builtinDisown(_ []string) error { return nil } + +func (s *Shell) builtinMktemp(args []string) error { + template := "" + for _, a := range args { + if !strings.HasPrefix(a, "-") { + template = a + break + } + } + + // Use Windows temp dir regardless of the /tmp/ prefix in template + tmpDir := os.Getenv("TEMP") + if tmpDir == "" { + tmpDir = os.Getenv("TMP") + } + if tmpDir == "" { + tmpDir = os.TempDir() + } + + // Extract the base pattern (strip directory prefix) + base := filepath.Base(template) + if base == "" || base == "." { + base = "tmp.XXXXXX" + } + + // Replace trailing X's with * for os.CreateTemp + xCount := 0 + for i := len(base) - 1; i >= 0 && base[i] == 'X'; i-- { + xCount++ + } + goPattern := base + if xCount > 0 { + goPattern = base[:len(base)-xCount] + "*" + } + + f, err := os.CreateTemp(tmpDir, goPattern) + if err != nil { + return fmt.Errorf("mktemp: %v", err) + } + f.Close() + fmt.Fprintln(s.Stdout, f.Name()) + return nil +} + +func (s *Shell) builtinUname(args []string) error { + showSys := len(args) == 0 + showRelease := false + for _, a := range args { + switch a { + case "-s": + showSys = true + case "-r": + showRelease = true + case "-a": + showSys = true + showRelease = true + } + } + if showSys { + fmt.Fprintln(s.Stdout, "Windows_NT") + } + if showRelease { + kernelRel := os.Getenv("OS_VERSION") + if kernelRel == "" { + kernelRel = "10.0" + } + fmt.Fprintln(s.Stdout, kernelRel) + } + return nil +} + +func (s *Shell) builtinWhoami(_ []string) error { + user := os.Getenv("USERNAME") + if user == "" { + user = os.Getenv("USER") + } + if user == "" { + user = "unknown" + } + fmt.Fprintln(s.Stdout, user) + return nil +} + +func (s *Shell) builtinHostname(_ []string) error { + h, err := os.Hostname() + if err != nil { + h = os.Getenv("COMPUTERNAME") + } + if h == "" { + h = "unknown" + } + fmt.Fprintln(s.Stdout, h) + return nil +} + +func (s *Shell) builtinMapfile(args []string) error { + varName := "" + trimNewlines := false + for _, a := range args { + if a == "-t" { + trimNewlines = true + } else if !strings.HasPrefix(a, "-") { + varName = a + } + } + if varName == "" { + varName = "MAPFILE" + } + + var lines []string + scanner := bufio.NewScanner(s.Stdin) + for scanner.Scan() { + line := scanner.Text() + if !trimNewlines { + line += "\n" + } + lines = append(lines, line) + } + s.setArray(varName, lines) + return nil +} + +// builtinDoubleBracket implements [[ ... ]] +func (s *Shell) builtinDoubleBracket(args []string) error { + // Strip trailing ]] if present + if len(args) > 0 && args[len(args)-1] == "]]" { + args = args[:len(args)-1] + } + if !s.evalDB(args) { + return exitCodeErr{1} + } + return nil +} + +// evalDB evaluates a [[ ... ]] expression. +func (s *Shell) evalDB(args []string) bool { + if len(args) == 0 { + return false + } + + // 1. || — lowest precedence, scan right-to-left + for i := len(args) - 1; i >= 0; i-- { + if args[i] == "||" { + return s.evalDB(args[:i]) || s.evalDB(args[i+1:]) + } + } + + // 2. && + for i := len(args) - 1; i >= 0; i-- { + if args[i] == "&&" { + return s.evalDB(args[:i]) && s.evalDB(args[i+1:]) + } + } + + // 3. ! prefix + if args[0] == "!" { + return !s.evalDB(args[1:]) + } + + // 4. Unary flags + if len(args) == 2 { + val := args[1] + switch args[0] { + case "-f": + info, err := os.Stat(val) + return err == nil && info.Mode().IsRegular() + case "-e": + _, err := os.Stat(val) + return err == nil + case "-d": + info, err := os.Stat(val) + return err == nil && info.IsDir() + case "-s": + info, err := os.Stat(val) + return err == nil && info.Size() > 0 + case "-r": + f, err := os.Open(val) + if err != nil { + return false + } + f.Close() + return true + case "-w": + f, err := os.OpenFile(val, os.O_WRONLY, 0) + if err != nil { + return false + } + f.Close() + return true + case "-x": + info, err := os.Stat(val) + return err == nil && (info.Mode()&0111 != 0 || runtime.GOOS == "windows") + case "-n": + return val != "" + case "-z": + return val == "" + case "-L": + _, err := os.Lstat(val) + return err == nil + case "-v": + _, ok := s.vars[val] + if !ok { + _, ok = s.arrays[val] + } + return ok + } + } + + // 5. Binary operators + if len(args) == 3 { + lhs, op, rhs := args[0], args[1], args[2] + switch op { + case "==": + matched, err := filepath.Match(rhs, lhs) + if err != nil { + return lhs == rhs + } + return matched + case "!=": + matched, err := filepath.Match(rhs, lhs) + if err != nil { + return lhs != rhs + } + return !matched + case "=~": + matched, err := regexp.MatchString(rhs, lhs) + return err == nil && matched + case "<": + return lhs < rhs + case ">": + return lhs > rhs + case "-eq": + ln, le := strconv.Atoi(lhs) + rn, re := strconv.Atoi(rhs) + if le != nil || re != nil { + return false + } + return ln == rn + case "-ne": + ln, le := strconv.Atoi(lhs) + rn, re := strconv.Atoi(rhs) + if le != nil || re != nil { + return false + } + return ln != rn + case "-lt": + ln, le := strconv.Atoi(lhs) + rn, re := strconv.Atoi(rhs) + if le != nil || re != nil { + return false + } + return ln < rn + case "-le": + ln, le := strconv.Atoi(lhs) + rn, re := strconv.Atoi(rhs) + if le != nil || re != nil { + return false + } + return ln <= rn + case "-gt": + ln, le := strconv.Atoi(lhs) + rn, re := strconv.Atoi(rhs) + if le != nil || re != nil { + return false + } + return ln > rn + case "-ge": + ln, le := strconv.Atoi(lhs) + rn, re := strconv.Atoi(rhs) + if le != nil || re != nil { + return false + } + return ln >= rn + } + } + + // 6. Single arg truthy test + if len(args) == 1 { + return args[0] != "" + } + + return false +} + // ─── Coreutils ──────────────────────────────────────────────────────────────── func (s *Shell) cmdLs(args []string) error { diff --git a/internal/shell/control.go b/internal/shell/control.go index 77c0f48..7518173 100644 --- a/internal/shell/control.go +++ b/internal/shell/control.go @@ -2,6 +2,7 @@ package shell import ( "fmt" + "path/filepath" "strings" ) @@ -85,6 +86,125 @@ func (s *Shell) executeIf(block string) error { return nil } +// executeCase handles: case WORD in PAT1) BODY1 ;; PAT2) BODY2 ;; esac +// splitStatements emits ";;" as a separate token, so we can parse arms directly. +func (s *Shell) executeCase(block string) error { + stmts := splitStatements(block) + if len(stmts) == 0 { + return nil + } + + // First stmt: "case WORD in" + caseHeader := stmts[0] + caseRest := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(caseHeader), "case")) + + var word string + var startIdx int + + // Find trailing " in" to separate word from "in" + inIdx := strings.LastIndex(caseRest, " in") + if inIdx >= 0 && strings.TrimSpace(caseRest[inIdx+3:]) == "" { + word = s.expandWord(strings.TrimSpace(caseRest[:inIdx])) + startIdx = 1 + } else { + word = s.expandWord(caseRest) + startIdx = 1 + if startIdx < len(stmts) && strings.TrimSpace(stmts[startIdx]) == "in" { + startIdx++ + } + } + + // Parse arms from remaining stmts + // stmts now include ";;" as explicit tokens + type arm struct { + patterns []string + body []string + } + var arms []arm + var curArm *arm + + for _, stmt := range stmts[startIdx:] { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + if stmt == "esac" { + break + } + if stmt == ";;" || stmt == ";&" || stmt == ";;&" { + curArm = nil + continue + } + + if curArm == nil { + // Expect a pattern: PAT) [body] + parenIdx := findCasePatternEnd(stmt) + if parenIdx < 0 { + continue + } + patStr := strings.TrimSpace(stmt[:parenIdx]) + bodyStr := strings.TrimSpace(stmt[parenIdx+1:]) + + rawPats := strings.Split(patStr, "|") + var pats []string + for _, p := range rawPats { + p = strings.TrimSpace(p) + if p != "" { + pats = append(pats, p) + } + } + arms = append(arms, arm{patterns: pats}) + curArm = &arms[len(arms)-1] + if bodyStr != "" { + curArm.body = append(curArm.body, bodyStr) + } + } else { + curArm.body = append(curArm.body, stmt) + } + } + + // Execute matching arm + for _, a := range arms { + for _, pat := range a.patterns { + expandedPat := s.expandWord(pat) + matched := false + if expandedPat == "*" { + matched = true + } else { + if m, err := filepath.Match(expandedPat, word); err == nil { + matched = m + } else { + matched = expandedPat == word + } + } + if matched { + body := strings.Join(a.body, "\n") + return s.Execute(body) + } + } + } + return nil +} + +// findCasePatternEnd finds the ) that ends the pattern in a case arm. +// Handles quoted strings. +func findCasePatternEnd(chunk string) int { + inSingle := false + inDouble := false + for i := 0; i < len(chunk); i++ { + c := chunk[i] + switch { + case c == '\'' && !inDouble: + inSingle = !inSingle + case c == '"' && !inSingle: + inDouble = !inDouble + case c == ')' && !inSingle && !inDouble: + return i + } + } + return -1 +} + // executeFor handles: for VAR in WORDS; do BODY; done // (also: for VAR; do BODY; done — iterates positional params) func (s *Shell) executeFor(block string) error { @@ -95,7 +215,8 @@ func (s *Shell) executeFor(block string) error { // Parse "for VAR in WORDS" header := stmts[0] - fields := strings.Fields(header) + // Use tokenize so array expansion works in "for x in ${arr[@]}" + fields := s.tokenize(header) if len(fields) < 2 { return fmt.Errorf("for: bad syntax") } @@ -110,10 +231,8 @@ func (s *Shell) executeFor(block string) error { } } if inIdx >= 0 { - for _, raw := range fields[inIdx+1:] { - expanded := s.expandWord(raw) - items = append(items, s.expandGlob(expanded)...) - } + // Items are already expanded by tokenize + items = fields[inIdx+1:] } else { // for var; do ... → iterate positional params items = s.args @@ -164,6 +283,169 @@ func (s *Shell) executeFor(block string) error { return nil } +// executeForC handles C-style: for ((init; cond; incr)); do BODY; done +func (s *Shell) executeForC(block string) error { + stmts := splitStatements(block) + if len(stmts) == 0 { + return nil + } + + // Extract the ((...)) header + header := stmts[0] + // Find "for" keyword and then "((" + trimmed := strings.TrimSpace(header) + // Strip "for" keyword + trimmed = strings.TrimSpace(trimmed[3:]) // skip "for" + // Expect "((" + if !strings.HasPrefix(trimmed, "((") { + return fmt.Errorf("for: bad C-style syntax") + } + trimmed = trimmed[2:] // skip "((" + // Find closing "))" + endIdx := strings.Index(trimmed, "))") + if endIdx < 0 { + return fmt.Errorf("for: missing '))'") + } + inner := trimmed[:endIdx] + + // Split on ; to get init, cond, incr + parts := strings.SplitN(inner, ";", 3) + init := "" + cond := "" + incr := "" + if len(parts) >= 1 { + init = strings.TrimSpace(parts[0]) + } + if len(parts) >= 2 { + cond = strings.TrimSpace(parts[1]) + } + if len(parts) >= 3 { + incr = strings.TrimSpace(parts[2]) + } + + // Execute init as arithmetic assignment + if init != "" { + s.execArithAssign(init) + } + + // Collect body between "do" and "done" + var bodyStmts []string + inBody := false + for _, stmt := range stmts[1:] { + w := firstWord(stmt) + if !inBody { + if w == "do" { + inBody = true + if rest := afterWord(stmt); rest != "" { + bodyStmts = append(bodyStmts, rest) + } + } + continue + } + if w == "done" { + break + } + bodyStmts = append(bodyStmts, stmt) + } + + body := strings.Join(bodyStmts, "\n") + + for { + // Evaluate condition + if cond != "" { + if s.evalArith(cond) == 0 { + break + } + } + + if err := s.Execute(body); err != nil { + if be, ok := err.(breakErr); ok { + if be.n <= 1 { + break + } + return breakErr{be.n - 1} + } + if ce, ok := err.(continueErr); ok { + if ce.n <= 1 { + // continue — execute incr then re-check cond + if incr != "" { + s.execArithAssign(incr) + } + continue + } + return continueErr{ce.n - 1} + } + if _, ok := err.(returnErr); ok { + return err + } + } + + // Execute increment + if incr != "" { + s.execArithAssign(incr) + } + } + return nil +} + +// execArithAssign handles arithmetic assignment expressions like i=0, i++, i+=1, ((i++)) +func (s *Shell) execArithAssign(expr string) { + expr = strings.TrimSpace(expr) + if expr == "" { + return + } + + // Handle i++ and i-- + if strings.HasSuffix(expr, "++") { + varName := strings.TrimSpace(expr[:len(expr)-2]) + if isValidIdentifier(varName) { + n := s.evalArith(varName) + s.vars[varName] = fmt.Sprintf("%d", n+1) + return + } + } + if strings.HasSuffix(expr, "--") { + varName := strings.TrimSpace(expr[:len(expr)-2]) + if isValidIdentifier(varName) { + n := s.evalArith(varName) + s.vars[varName] = fmt.Sprintf("%d", n-1) + return + } + } + + // Handle i+=N + if idx := strings.Index(expr, "+="); idx > 0 { + varName := strings.TrimSpace(expr[:idx]) + if isValidIdentifier(varName) { + delta := s.evalArith(expr[idx+2:]) + n := s.evalArith(varName) + s.vars[varName] = fmt.Sprintf("%d", n+delta) + return + } + } + + // Handle i-=N + if idx := strings.Index(expr, "-="); idx > 0 { + varName := strings.TrimSpace(expr[:idx]) + if isValidIdentifier(varName) { + delta := s.evalArith(expr[idx+2:]) + n := s.evalArith(varName) + s.vars[varName] = fmt.Sprintf("%d", n-delta) + return + } + } + + // Handle i=expr + if idx := strings.Index(expr, "="); idx > 0 { + varName := strings.TrimSpace(expr[:idx]) + if isValidIdentifier(varName) { + n := s.evalArith(expr[idx+1:]) + s.vars[varName] = fmt.Sprintf("%d", n) + return + } + } +} + // executeWhileUntil handles while/until loops. func (s *Shell) executeWhileUntil(block string, isUntil bool) error { stmts := splitStatements(block) diff --git a/internal/shell/exec.go b/internal/shell/exec.go index acdcb57..17eb639 100644 --- a/internal/shell/exec.go +++ b/internal/shell/exec.go @@ -70,7 +70,10 @@ func splitPipe(input string) []string { current.WriteByte(c) case c == '|' && !inSingle && !inDouble && parenDepth == 0: if i+1 < len(input) && input[i+1] == '|' { - current.WriteByte(c) // part of ||, pass through + // || operator — pass both chars through + current.WriteByte(c) + current.WriteByte(input[i+1]) + i++ } else { parts = append(parts, strings.TrimSpace(current.String())) current.Reset() @@ -87,6 +90,63 @@ func splitPipe(input string) []string { return parts } +// parseArrayAssign detects NAME=(...) or NAME+=(...) at start of input. +func (s *Shell) parseArrayAssign(input string) (name string, appendMode bool, elements []string, ok bool) { + input = strings.TrimSpace(input) + // Read identifier + i := 0 + for i < len(input) && isVarChar(input[i]) { + i++ + } + if i == 0 { + return + } + name = input[:i] + if !isValidIdentifier(name) { + return + } + // Optional + + if i < len(input) && input[i] == '+' { + appendMode = true + i++ + } + // Require = + if i >= len(input) || input[i] != '=' { + return + } + i++ + // Require ( + if i >= len(input) || input[i] != '(' { + return + } + i++ + // Find matching ) + depth := 1 + j := i + inSingle := false + inDouble := false + for j < len(input) && depth > 0 { + c := input[j] + switch { + case c == '\'' && !inDouble: + inSingle = !inSingle + case c == '"' && !inSingle: + inDouble = !inDouble + case c == '(' && !inSingle && !inDouble: + depth++ + case c == ')' && !inSingle && !inDouble: + depth-- + } + if depth > 0 { + j++ + } + } + content := input[i:j] + elements = s.tokenize(content) + ok = true + return +} + // executeCommand executes a single command (no pipes, no &&/||). func (s *Shell) executeCommand(input string) error { input = strings.TrimSpace(input) @@ -94,12 +154,22 @@ func (s *Shell) executeCommand(input string) error { return nil } + // Detect array assignment NAME=(...) or NAME+=(...) + if name, appendMode, elems, ok := s.parseArrayAssign(input); ok { + if appendMode { + s.appendArray(name, elems) + } else { + s.setArray(name, elems) + } + return nil + } + tokens := s.tokenize(input) if len(tokens) == 0 { return nil } - cmdArgs, redirects := extractRedirects(tokens) + cmdArgs, redirects := s.extractRedirects(tokens) if len(cmdArgs) == 0 { // Pure redirection, e.g. "> file" creates/truncates file return s.withRedirects(redirects, func() error { return nil }) @@ -132,7 +202,7 @@ func (s *Shell) executeCommand(input string) error { }) } -func extractRedirects(tokens []string) ([]string, []redirect) { +func (s *Shell) extractRedirects(tokens []string) ([]string, []redirect) { var args []string var redirects []redirect @@ -194,11 +264,24 @@ func extractRedirects(tokens []string) ([]string, []redirect) { case strings.HasPrefix(tok, ">") && len(tok) > 1: redirects = append(redirects, redirect{1, ">", tok[1:]}) i++ - // < (stdin) + // < (stdin) — also handle process substitution <(...) case tok == "<": if i+1 < len(tokens) { - redirects = append(redirects, redirect{0, "<", tokens[i+1]}) - i += 2 + next := tokens[i+1] + // Process substitution: < <(cmd) + if strings.HasPrefix(next, "<(") && strings.HasSuffix(next, ")") { + cmd := next[2 : len(next)-1] + tmpf, err := os.CreateTemp("", "procsub*") + if err == nil { + s.withIO(nil, tmpf, nil, func() error { return s.Execute(cmd) }) + tmpf.Close() + redirects = append(redirects, redirect{0, "<", tmpf.Name()}) + } + i += 2 + } else { + redirects = append(redirects, redirect{0, "<", next}) + i += 2 + } } else { i++ } @@ -296,7 +379,7 @@ func (s *Shell) executeCommandBg(input string) error { if len(tokens) == 0 { return nil } - cmdArgs, _ := extractRedirects(tokens) + cmdArgs, _ := s.extractRedirects(tokens) if len(cmdArgs) == 0 { return nil } diff --git a/internal/shell/expand.go b/internal/shell/expand.go index 9ba70b7..fa22537 100644 --- a/internal/shell/expand.go +++ b/internal/shell/expand.go @@ -60,7 +60,17 @@ func (s *Shell) expandWord(word string) string { i += 2 } } else { - result.WriteByte(next) + // Only consume the backslash when escaping a shell + // metacharacter. Before regular path characters (letters, + // digits, etc.) keep it literal so Windows paths like + // C:\workspace work unquoted. + const metachars = " \t\n$*?[\"'\\|&;()<>{}!#~`" + if strings.ContainsRune(metachars, rune(next)) { + result.WriteByte(next) + } else { + result.WriteByte('\\') + result.WriteByte(next) + } i += 2 } } else { @@ -72,6 +82,82 @@ func (s *Shell) expandWord(word string) string { result.WriteByte('$') break } + + // $'...' ANSI C string + if word[i] == '\'' { + i++ // skip opening ' + for i < len(word) && word[i] != '\'' { + if word[i] == '\\' && i+1 < len(word) { + i++ // skip backslash, now at escape char + switch word[i] { + case 'n': + result.WriteByte('\n') + case 't': + result.WriteByte('\t') + case 'r': + result.WriteByte('\r') + case '\\': + result.WriteByte('\\') + case '\'': + result.WriteByte('\'') + case '"': + result.WriteByte('"') + case 'a': + result.WriteByte('\a') + case 'b': + result.WriteByte('\b') + case 'f': + result.WriteByte('\f') + case 'v': + result.WriteByte('\v') + case 'e', 'E': + result.WriteByte(0x1b) + case '0', '1', '2', '3', '4', '5', '6', '7': + // Octal \NNN — up to 3 digits + oct := 0 + for k := 0; k < 3 && i < len(word) && word[i] >= '0' && word[i] <= '7'; k++ { + oct = oct*8 + int(word[i]-'0') + i++ + } + result.WriteByte(byte(oct)) + continue + case 'x': + // Hex \xNN — up to 2 digits + i++ // skip 'x' + hexv := 0 + for k := 0; k < 2 && i < len(word); k++ { + d := word[i] + if d >= '0' && d <= '9' { + hexv = hexv*16 + int(d-'0') + i++ + } else if d >= 'a' && d <= 'f' { + hexv = hexv*16 + int(d-'a'+10) + i++ + } else if d >= 'A' && d <= 'F' { + hexv = hexv*16 + int(d-'A'+10) + i++ + } else { + break + } + } + result.WriteByte(byte(hexv)) + continue + default: + result.WriteByte('\\') + result.WriteByte(word[i]) + } + i++ + } else { + result.WriteByte(word[i]) + i++ + } + } + if i < len(word) { + i++ // skip closing ' + } + continue + } + switch word[i] { case '(': if i+1 < len(word) && word[i+1] == '(' { @@ -146,7 +232,7 @@ func (s *Shell) expandWord(word string) string { result.WriteString(s.vars["#"]) i++ case '@': - result.WriteString(s.vars["@"]) + result.WriteString(strings.Join(s.args, "\x01")) i++ case '*': result.WriteString(s.vars["*"]) @@ -172,17 +258,52 @@ func (s *Shell) expandWord(word string) string { } func (s *Shell) getVar(name string) string { - if v, ok := s.vars[name]; ok { + // Resolve nameref chain + resolved := s.resolveNR(name) + if v, ok := s.vars[resolved]; ok { return v } - return os.Getenv(name) + return os.Getenv(resolved) } func (s *Shell) evalVarExpr(expr string) string { - // ${#VAR} — string length + // ${#arr[@]} or ${#arr[*]} → length of array if strings.HasPrefix(expr, "#") { - return strconv.Itoa(len(s.getVar(expr[1:]))) + rest := expr[1:] + if strings.HasSuffix(rest, "[@]") || strings.HasSuffix(rest, "[*]") { + arrName := rest[:len(rest)-3] + return strconv.Itoa(len(s.getArray(arrName))) + } + // ${#VAR} — string length + return strconv.Itoa(len(s.getVar(rest))) } + + // Array indexing: ${arr[@]}, ${arr[*]}, ${arr[N]} + if bracketIdx := strings.Index(expr, "["); bracketIdx >= 0 && strings.HasSuffix(expr, "]") { + // Make sure there's no operator before the bracket + prefix := expr[:bracketIdx] + hasOp := strings.ContainsAny(prefix, ":-:=:+%#") + if !hasOp { + arrName := prefix + idx := expr[bracketIdx+1 : len(expr)-1] + if idx == "@" { + arr := s.getArray(arrName) + return strings.Join(arr, "\x01") + } + if idx == "*" { + arr := s.getArray(arrName) + return strings.Join(arr, " ") + } + // Numeric index + n := s.evalArith(idx) + arr := s.getArray(arrName) + if n >= 0 && n < len(arr) { + return arr[n] + } + return "" + } + } + // ${VAR:-default} if idx := strings.Index(expr, ":-"); idx >= 0 { varName := expr[:idx] @@ -285,7 +406,48 @@ func evalArithExpr(expr string) int { if strings.HasPrefix(expr, "(") && strings.HasSuffix(expr, ")") { return evalArithExpr(expr[1 : len(expr)-1]) } - // Operators in precedence order (lowest first so we split on last occurrence) + + // Comparison operators (lowest precedence) — multi-char first + for _, op := range []string{"<=", ">=", "==", "!=", "<", ">"} { + if idx := findBinaryOpStr(expr, op); idx >= 0 { + left := evalArithExpr(expr[:idx]) + right := evalArithExpr(expr[idx+len(op):]) + switch op { + case "<": + if left < right { + return 1 + } + return 0 + case ">": + if left > right { + return 1 + } + return 0 + case "<=": + if left <= right { + return 1 + } + return 0 + case ">=": + if left >= right { + return 1 + } + return 0 + case "==": + if left == right { + return 1 + } + return 0 + case "!=": + if left != right { + return 1 + } + return 0 + } + } + } + + // Arithmetic operators in precedence order (lowest first so we split on last occurrence) for _, op := range []string{"+", "-", "*", "/", "%"} { if idx := findBinaryOp(expr, op); idx >= 0 { left := evalArithExpr(expr[:idx]) @@ -313,6 +475,39 @@ func evalArithExpr(expr string) int { return 0 } +// findBinaryOpStr finds the rightmost occurrence of a multi-character binary operator. +func findBinaryOpStr(expr, op string) int { + depth := 0 + // Search right-to-left + for i := len(expr) - len(op); i >= 0; i-- { + switch expr[i] { + case ')': + depth++ + case '(': + depth-- + } + if depth != 0 { + continue + } + if expr[i:i+len(op)] == op { + // Make sure it's not part of a longer operator + // e.g. don't match < in <= + if len(op) == 1 { + // For < and >, make sure next char is not = + if i+1 < len(expr) && (expr[i+1] == '=' || (op == "<" && expr[i+1] == '<') || (op == ">" && expr[i+1] == '>')) { + continue + } + } + // For single char ops, make sure previous char is not the same op (e.g. << or >>) + if len(op) == 1 && i > 0 && expr[i-1] == expr[i] { + continue + } + return i + } + } + return -1 +} + func findBinaryOp(expr, op string) int { depth := 0 for i := len(expr) - 1; i >= 0; i-- { @@ -353,7 +548,7 @@ func (s *Shell) tokenize(input string) []string { current := strings.Builder{} inSingle := false inDouble := false - parenDepth := 0 // nesting depth inside $(...) or $((...)) + parenDepth := 0 // nesting depth inside $(...) or $((...)) pendingDollar := false // true after $ when next char is ( wasQuoted := false @@ -420,7 +615,7 @@ doneTokenizing: if eqIdx := strings.Index(clean, "="); eqIdx > 0 { name := clean[:eqIdx] if isValidIdentifier(name) && !strings.Contains(clean[:eqIdx], "$") { - value := s.expandWord(clean[eqIdx+1:]) + value := strings.ReplaceAll(s.expandWord(clean[eqIdx+1:]), "\x01", " ") s.vars[name] = value os.Setenv(name, value) rawTokens = rawTokens[1:] @@ -438,6 +633,16 @@ doneTokenizing: tok = tok[2:] } expanded := s.expandWord(tok) + // Handle multi-word expansion from $@ and ${arr[@]} + if strings.Contains(expanded, "\x01") { + parts := strings.Split(expanded, "\x01") + for _, p := range parts { + if p != "" { + result = append(result, p) + } + } + continue + } if !quoted && strings.ContainsAny(expanded, "*?[") { result = append(result, s.expandGlob(expanded)...) } else { diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 0055260..647e78f 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -24,6 +24,8 @@ func (e exitCodeErr) Error() string { return "" } type Shell struct { vars map[string]string + arrays map[string][]string + namerefs map[string]string builtins map[string]func([]string) error funcs map[string]string // function name → body lastExit int @@ -31,15 +33,20 @@ type Shell struct { Stdout io.Writer Stderr io.Writer args []string + errexit bool + nounset bool + pipefail bool } func New() *Shell { s := &Shell{ - vars: map[string]string{}, - funcs: map[string]string{}, - Stdin: os.Stdin, - Stdout: os.Stdout, - Stderr: os.Stderr, + vars: map[string]string{}, + arrays: map[string][]string{}, + namerefs: map[string]string{}, + funcs: map[string]string{}, + Stdin: os.Stdin, + Stdout: os.Stdout, + Stderr: os.Stderr, } s.initBuiltins() s.vars["SHELL"] = "bash-for-windows" @@ -58,6 +65,38 @@ func New() *Shell { return s } +// resolveNR resolves a variable name through the nameref chain (circular-ref safe). +func (s *Shell) resolveNR(name string) string { + seen := map[string]bool{} + for { + if seen[name] { + return name + } + seen[name] = true + target, ok := s.namerefs[name] + if !ok { + return name + } + name = target + } +} + +// getArray returns array (resolving namerefs). +func (s *Shell) getArray(name string) []string { + return s.arrays[s.resolveNR(name)] +} + +// setArray sets array (resolving namerefs). +func (s *Shell) setArray(name string, vals []string) { + s.arrays[s.resolveNR(name)] = vals +} + +// appendArray appends to array (resolving namerefs). +func (s *Shell) appendArray(name string, vals []string) { + n := s.resolveNR(name) + s.arrays[n] = append(s.arrays[n], vals...) +} + func (s *Shell) SetArgs(args []string) { s.args = args s.vars["#"] = fmt.Sprintf("%d", len(args)) @@ -81,7 +120,12 @@ func (s *Shell) SetVar(name, value string) { // Execute runs commands from the given input string. func (s *Shell) Execute(input string) error { + // Normalize CRLF to LF + input = strings.ReplaceAll(input, "\r\n", "\n") + input = strings.ReplaceAll(input, "\r", "\n") input = strings.ReplaceAll(input, "\\\n", " ") + // Pre-process heredocs + input = preprocessHeredocs(input) blocks := parseBlocks(input) for _, block := range blocks { if err := s.executeBlock(block); err != nil { @@ -119,7 +163,7 @@ func IsIncomplete(input string) bool { for _, stmt := range stmts { w := firstWord(stmt) switch w { - case "if", "for", "while", "until": + case "if", "for", "while", "until", "case": depth++ case "fi", "done", "esac": depth-- @@ -162,7 +206,7 @@ func parseBlocks(input string) []string { funcKwDepth = 0 for _, p := range splitStatements(stmt[braceIdx+1:]) { switch firstWord(p) { - case "if", "for", "while", "until": + case "if", "for", "while", "until", "case": funcKwDepth++ case "fi", "done", "esac": funcKwDepth-- @@ -182,7 +226,7 @@ func parseBlocks(input string) []string { } switch w { - case "if", "for", "while", "until": + case "if", "for", "while", "until", "case": kwDepth++ } kwDepth += embeddedKwDepth(stmt) @@ -211,7 +255,7 @@ func parseBlocks(input string) []string { continue } switch w { - case "if", "for", "while", "until": + case "if", "for", "while", "until", "case": funcKwDepth++ case "fi", "done", "esac": funcKwDepth-- @@ -246,23 +290,178 @@ func embeddedKwDepth(stmt string) int { return delta } -// splitStatements splits input on semicolons and newlines, respecting quotes. +// preprocessHeredocs converts heredoc syntax (<= 2 && ((marker[0] == '\'' && marker[len(marker)-1] == '\'') || + (marker[0] == '"' && marker[len(marker)-1] == '"')) { + marker = marker[1 : len(marker)-1] + } + _ = quotedMarker + // Collect body until marker + var bodyLines []string + for i < len(lines) { + bodyLine := lines[i] + check := bodyLine + if stripTabs { + check = strings.TrimLeft(check, "\t") + } + if strings.TrimRight(check, "\r") == marker { + i++ + break + } + bodyLines = append(bodyLines, bodyLine) + i++ + } + // Write to temp file + content := strings.Join(bodyLines, "\n") + if len(bodyLines) > 0 { + content += "\n" + } + f, err := os.CreateTemp("", "heredoc*") + if err == nil { + f.WriteString(content) + f.Close() + newLine += " < " + f.Name() + } + } + result = append(result, newLine) + } + return strings.Join(result, "\n") +} + +// parseHeredocMarkers finds < 0: + parenDepth-- + current.WriteByte(c) + case c == ';' && !inSingle && !inDouble && parenDepth == 0: + if i+1 < len(input) && input[i+1] == ';' { + // Double semicolon — flush current token, then emit ";;" as a token + if s := strings.TrimSpace(current.String()); s != "" { + result = append(result, s) + } + current.Reset() + result = append(result, ";;") + i++ // skip second ; + } else { + // Single semicolon — just a statement separator + if s := strings.TrimSpace(current.String()); s != "" { + result = append(result, s) + } + current.Reset() + } + case c == '\n' && !inSingle && !inDouble && parenDepth == 0: if s := strings.TrimSpace(current.String()); s != "" { result = append(result, s) } @@ -338,11 +537,18 @@ func (s *Shell) executeBlock(block string) error { case "if": return s.executeIf(block) case "for": + // Check for C-style for (( ... )) + trimmed := strings.TrimSpace(block) + if strings.HasPrefix(trimmed, "for ((") || strings.HasPrefix(trimmed, "for((") { + return s.executeForC(block) + } return s.executeFor(block) case "while": return s.executeWhileUntil(block, false) case "until": return s.executeWhileUntil(block, true) + case "case": + return s.executeCase(block) } if isFuncDefStart(block) { return s.defineFunction(block) @@ -418,6 +624,7 @@ func (s *Shell) executeAndOrList(line string) error { op := "" inSingle := false inDouble := false + dbDepth := 0 // double-bracket [[ depth for i := 0; i < len(line); i++ { c := line[i] @@ -428,12 +635,22 @@ func (s *Shell) executeAndOrList(line string) error { case c == '"' && !inSingle: inDouble = !inDouble current.WriteByte(c) - case c == '&' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '&': + case c == '[' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '[': + dbDepth++ + current.WriteByte(c) + current.WriteByte(line[i+1]) + i++ + case c == ']' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == ']' && dbDepth > 0: + dbDepth-- + current.WriteByte(c) + current.WriteByte(line[i+1]) + i++ + case c == '&' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '&' && dbDepth == 0: tokens = append(tokens, tok{current.String(), op}) current.Reset() op = "&&" i++ - case c == '|' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '|': + case c == '|' && !inSingle && !inDouble && i+1 < len(line) && line[i+1] == '|' && dbDepth == 0: tokens = append(tokens, tok{current.String(), op}) current.Reset() op = "||"